Compare commits
104 Commits
rust
..
37949e560b
| Author | SHA1 | Date | |
|---|---|---|---|
| 37949e560b | |||
| f04cb6b8f2 | |||
| 10262c0e5a | |||
| d345f8c41d | |||
| d1f7a57f58 | |||
| dc9af4d79c | |||
| c009d401fb | |||
| e5ef39f327 | |||
| eb67287b56 | |||
| 4aa17b4fd2 | |||
| 79571c6bdc | |||
| d46a333249 | |||
| 7446f3463d | |||
| b1a72f5a10 | |||
| 5ee539d95c | |||
| 14e26a4323 | |||
| 1c3b1c6fe9 | |||
| 5e0c10db01 | |||
| e598150d90 | |||
| 2fa6f0df62 | |||
| db76858072 | |||
| af2c5b95f7 | |||
| 1f574d8134 | |||
| 8a8d8d1477 | |||
| da074f52b4 | |||
| 9b0aa4dbe8 | |||
| 212ac14a1b | |||
| 2929f51556 | |||
| e12418cc4c | |||
| be4ec3482a | |||
| e67aafdac1 | |||
| 21e5204abd | |||
| 4095c68822 | |||
| ef37dc5af0 | |||
| fdbb068a6c | |||
| dbbf48cb14 | |||
| 1e13b0376b | |||
| 1b5cd2815e | |||
| ba4c4af2f8 | |||
| e56a284415 | |||
| cbc9eeb453 | |||
| 2f6b7deb2c | |||
| 9375448087 | |||
| 5be2f6f7aa | |||
| eebcadcba1 | |||
| 6b2bd13903 | |||
| 5dfda0a10c | |||
| a8a02d9e1c | |||
| bd1d17cc4d | |||
| 9207a7231c | |||
| c6efff9034 | |||
| 27fbd8ed15 | |||
| 348341f304 | |||
| 9380580504 | |||
| 08cf5cc1d9 | |||
| 0f0486d8d4 | |||
| 0ea2a3a985 | |||
| 21e5908c35 | |||
| 6f0a159245 | |||
| 4120a83b67 | |||
| 742cd9e921 | |||
| 593971ecb5 | |||
| 03dca998df | |||
| 0ce5f4f490 | |||
| dec4b927dc | |||
| 3f1e6d3407 | |||
| f02fd6c249 | |||
| f23796f0cc | |||
| 3f76a544e0 | |||
| e474549940 | |||
| b7e37b0399 | |||
| 263c0f0dc9 | |||
| 26d8431998 | |||
| 1f3adceda4 | |||
| 9c64a8fe42 | |||
| b04b794705 | |||
| 0f3c5b6eb4 | |||
| 66a1643bca | |||
| edc6445d70 | |||
| 2d8f1a1fd0 | |||
| cd1a1b45aa | |||
| 246a6d88f0 | |||
| 7d43b2c31b | |||
| 45c2d5e643 | |||
| 1d032c6732 | |||
| 2245cca67a | |||
| c7c244992a | |||
| 4f5b55d40f | |||
| 90874a6721 | |||
| 6b10d4249c | |||
| 57aa0aa70e | |||
| 4de457cc5e | |||
| 66e8b114b9 | |||
| 1cac45502a | |||
| 79dc8fe409 | |||
| 24a898c9a7 | |||
| 7c2a317c01 | |||
| cb619f9286 | |||
| 441270317c | |||
| 2e4318d84b | |||
| d0be16d8e3 | |||
| 83e0ad0240 | |||
| 275ce34d05 | |||
| cb5b921550 |
@@ -1,28 +0,0 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=sk-demo-openai-key
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=AIza-demo-gemini-key
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=sk-demo-deepseek-key
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=gk-demo-grok-key
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
|
||||
# Session Secret (for signed tokens)
|
||||
SESSION_SECRET=ki9khXAk9usDkasMrD2UbK4LOgrDRJz0
|
||||
|
||||
# Encryption key (required)
|
||||
LLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
-22
@@ -1,22 +0,0 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=sk-demo-openai-key
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=AIza-demo-gemini-key
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=sk-demo-deepseek-key
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=gk-demo-grok-key
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
+40
-24
@@ -1,31 +1,47 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
# Copy to .env and fill in your API keys
|
||||
# GopherGate Configuration Example
|
||||
# Copy this file to .env and fill in your values
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
# ==============================================================================
|
||||
# MANDATORY: Encryption & Security
|
||||
# ==============================================================================
|
||||
# A 32-byte hex or base64 encoded string used for session signing and
|
||||
# database encryption.
|
||||
# Generate one with: openssl rand -hex 32
|
||||
LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
# ==============================================================================
|
||||
# LLM Provider API Keys
|
||||
# ==============================================================================
|
||||
OPENAI_API_KEY=sk-...
|
||||
GEMINI_API_KEY=AIza...
|
||||
DEEPSEEK_API_KEY=sk-...
|
||||
MOONSHOT_API_KEY=sk-...
|
||||
GROK_API_KEY=xai-...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||
# ==============================================================================
|
||||
# Server Configuration
|
||||
# ==============================================================================
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
LLM_PROXY__SERVER__HOST=0.0.0.0
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=your_grok_api_key_here
|
||||
# Optional: Bearer tokens for client authentication (comma-separated)
|
||||
# If not set, the proxy will look up tokens in the database.
|
||||
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2
|
||||
|
||||
# Ollama (local server)
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1
|
||||
# ==============================================================================
|
||||
# Database Configuration
|
||||
# ==============================================================================
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
||||
|
||||
# ==============================================================================
|
||||
# Provider Overrides (Optional)
|
||||
# ==============================================================================
|
||||
# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1
|
||||
# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__BASE_URL=https://api.moonshot.ai/v1
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__DEFAULT_MODEL=kimi-k2.5
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
|
||||
# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes)
|
||||
SESSION_SECRET=your_session_secret_here_32_bytes
|
||||
+25
-37
@@ -6,56 +6,44 @@ on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo check --all-targets
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo clippy --all-targets -- -D warnings
|
||||
|
||||
fmt:
|
||||
name: Formatting
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
components: rustfmt
|
||||
- run: cargo fmt --all -- --check
|
||||
version: latest
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo test --all-targets
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: Run Tests
|
||||
run: go test -v ./...
|
||||
|
||||
build-release:
|
||||
name: Release Build
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: [check, clippy, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo build --release
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: Build
|
||||
run: go build -v -o gophergate ./cmd/gophergate
|
||||
|
||||
+15
-4
@@ -1,5 +1,16 @@
|
||||
/data/
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
/gophergate
|
||||
/llm-proxy
|
||||
/llm-proxy-go
|
||||
*.log
|
||||
.opencode/
|
||||
.pi-lens/
|
||||
.pi-lens/cache/
|
||||
server.pid
|
||||
/target
|
||||
/.env
|
||||
/*.db
|
||||
/*.db-shm
|
||||
/*.db-wal
|
||||
|
||||
@@ -1,566 +0,0 @@
|
||||
# LLM Proxy - Comprehensive Fix Plan
|
||||
|
||||
## Project Overview
|
||||
Rust-based unified LLM proxy gateway (Axum + SQLite + Tokio) exposing an OpenAI-compatible API that routes to OpenAI, Gemini, DeepSeek, Grok, and Ollama. Includes dashboard with WebSocket monitoring. ~4,354 lines of Rust across 25 source files.
|
||||
|
||||
## Design Decisions
|
||||
- **Session management**: In-memory HashMap with expiry (no new dependencies)
|
||||
- **Provider deduplication**: Shared helper functions approach
|
||||
- **Dashboard refactor**: Full split into sub-modules (auth, usage, clients, providers, system, websocket)
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Fix Compilation & Test Issues
|
||||
|
||||
### 1.1 Fix config_path type mismatch
|
||||
**Files**: `src/config/mod.rs:98`, `src/lib.rs:99`
|
||||
|
||||
The `AppConfig.config_path` field is `PathBuf` but `test_utils::create_test_state` sets it to `None`.
|
||||
|
||||
**Fix**: Change `src/config/mod.rs:98` from `pub config_path: PathBuf` to `pub config_path: Option<PathBuf>`. Update `src/config/mod.rs:177` to wrap in `Some()`:
|
||||
```rust
|
||||
config_path: Some(config_path),
|
||||
```
|
||||
|
||||
### 1.2 Fix streaming test compilation errors
|
||||
**File**: `src/utils/streaming.rs:195-201`
|
||||
|
||||
Two issues in the test (affecting three lines):
|
||||
1. Lines 195-196: `ProviderStreamChunk` missing `reasoning_content` field
|
||||
2. Line 201: `RequestLogger::new()` called with 1 arg but needs 2 (pool + dashboard_tx)
|
||||
|
||||
**Fix**:
|
||||
```rust
|
||||
// Line 195-196: Add reasoning_content field
|
||||
Ok(ProviderStreamChunk { content: "Hello".to_string(), reasoning_content: None, finish_reason: None, model: "test".to_string() }),
|
||||
Ok(ProviderStreamChunk { content: " World".to_string(), reasoning_content: None, finish_reason: Some("stop".to_string()), model: "test".to_string() }),
|
||||
|
||||
// Line 200-201: Add dashboard_tx argument
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
|
||||
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
|
||||
```
|
||||
|
||||
### 1.3 Fix multimodal test assertion
|
||||
**File**: `src/multimodal/mod.rs:283`
|
||||
|
||||
Line 283 asserts `!model_supports_multimodal("gemini-pro")` but the function at lines 187-189 returns `true` for ALL models starting with "gemini".
|
||||
|
||||
**Fix**: Either:
|
||||
- (a) Update the function to exclude non-vision Gemini models (more correct):
|
||||
```rust
|
||||
if model.starts_with("gemini") {
|
||||
// gemini-pro (text-only) doesn't support multimodal, but gemini-pro-vision and gemini-1.5+ do
|
||||
return model.contains("vision") || model.contains("1.5") || model.contains("2.0") || model.contains("flash") || model.contains("ultra");
|
||||
}
|
||||
```
|
||||
- (b) Or remove the failing assertion if all Gemini models actually support vision now.
|
||||
|
||||
**Recommendation**: Option (b) - remove line 283, since modern Gemini models all support multimodal. Replace with a non-multimodal model test like `assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"))`.
|
||||
|
||||
### 1.4 Clean up empty/stale test files
|
||||
**Files**: `tests/streaming_test.rs`, `tests/integration_tests.rs.bak`
|
||||
|
||||
**Fix**:
|
||||
- Delete `tests/streaming_test.rs` (empty file)
|
||||
- Delete `tests/integration_tests.rs.bak` (stale backup referencing old APIs)
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Fix Critical Bugs
|
||||
|
||||
### 2.1 Replace `futures::executor::block_on` with async
|
||||
**Files**:
|
||||
- `src/providers/openai.rs:63,151`
|
||||
- `src/providers/deepseek.rs:65`
|
||||
- `src/providers/grok.rs:63,151`
|
||||
- `src/providers/ollama.rs:58`
|
||||
|
||||
Calling `futures::executor::block_on()` inside a Tokio async context blocks the runtime's worker thread and can deadlock. The underlying issue is that `image_input.to_base64()` is async, but it is called inside a sync `.map()` closure within `serde_json::json!{}`.
|
||||
|
||||
**Fix**: Pre-process messages before building the JSON body. Create a helper function in a new file `src/providers/helpers.rs`:
|
||||
|
||||
```rust
|
||||
use crate::models::{ChatMessage, ContentPart};
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously
|
||||
pub async fn messages_to_openai_json(messages: &[ChatMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input.to_base64().await
|
||||
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
||||
parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push(serde_json::json!({
|
||||
"role": m.role,
|
||||
"content": parts
|
||||
}));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
```
|
||||
|
||||
Then update each provider's `chat_completion` and `chat_completion_stream` to call:
|
||||
```rust
|
||||
let messages_json = crate::providers::helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"messages": messages_json,
|
||||
"stream": false,
|
||||
});
|
||||
```
|
||||
|
||||
Remove all `futures::executor::block_on` calls.
|
||||
|
||||
### 2.2 Fix broken update_client query builder
|
||||
**File**: `src/client/mod.rs:129-163`
|
||||
|
||||
The `updates` vec collects column name strings like `"name = "` but they are **never used** in the actual query. The `query_builder` receives `.push_bind()` values without corresponding column names, producing malformed SQL.
|
||||
|
||||
**Fix**: Replace the broken pattern with proper QueryBuilder usage:
|
||||
```rust
|
||||
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(name) = &request.name {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("name = ");
|
||||
query_builder.push_bind(name);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(description) = &request.description {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("description = ");
|
||||
query_builder.push_bind(description);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(is_active) = request.is_active {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("is_active = ");
|
||||
query_builder.push_bind(is_active);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(rate_limit) = request.rate_limit_per_minute {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("rate_limit_per_minute = ");
|
||||
query_builder.push_bind(rate_limit);
|
||||
has_updates = true;
|
||||
}
|
||||
```
|
||||
|
||||
Remove the `updates` vec entirely - it serves no purpose.
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Security Hardening
|
||||
|
||||
### 3.1 Implement in-memory session management
|
||||
**New file**: `src/dashboard/sessions.rs`
|
||||
|
||||
```rust
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use chrono::{DateTime, Utc, Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>,
|
||||
ttl_hours: i64,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
}
|
||||
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
self.sessions.write().await.remove(token);
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add `SessionManager` to `DashboardState`. Add it to `AppState` or initialize it in dashboard `router()`.
|
||||
|
||||
### 3.2 Fix handle_auth_status to validate sessions
|
||||
**File**: `src/dashboard/mod.rs:191-199`
|
||||
|
||||
Extract the session token from the `Authorization` header and validate it:
|
||||
|
||||
```rust
|
||||
async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let token = headers.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token {
|
||||
if let Some(session) = state.session_manager.validate_session(token).await {
|
||||
return Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": "Administrator",
|
||||
"role": session.role
|
||||
}
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||
}
|
||||
```
|
||||
|
||||
### 3.3 Add middleware to protect dashboard API routes
|
||||
Create an Axum middleware that validates session tokens on all `/api/` routes except `/api/auth/login`.
|
||||
|
||||
### 3.4 Force password change for default admin
|
||||
**File**: `src/database/mod.rs:138-148`
|
||||
|
||||
Add a `must_change_password` column to the `users` table. Set it to `true` for the default admin. Return `must_change_password: true` in the login response so the frontend can prompt.
|
||||
|
||||
### 3.5 Mask auth tokens in settings API response
|
||||
**File**: `src/dashboard/mod.rs:1048`
|
||||
|
||||
Use the existing `mask_token` function (currently `#[allow(dead_code)]` at line 1066):
|
||||
```rust
|
||||
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||
```
|
||||
Remove the `#[allow(dead_code)]` attribute.
|
||||
|
||||
### 3.6 Move Gemini API key from URL to header
|
||||
**File**: `src/providers/gemini.rs:172-176,301-305`
|
||||
|
||||
Change from:
|
||||
```rust
|
||||
let url = format!("{}/models/{}:generateContent?key={}", self.config.base_url, request.model, self.api_key);
|
||||
```
|
||||
To:
|
||||
```rust
|
||||
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model);
|
||||
// ...
|
||||
let response = self.client.post(&url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&gemini_request)
|
||||
.send()
|
||||
.await
|
||||
```
|
||||
|
||||
Same for the streaming URL at line 301-305.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Implement Stubs & Missing Features
|
||||
|
||||
### 4.1 Implement handle_test_provider
|
||||
**File**: `src/dashboard/mod.rs:840-849`
|
||||
|
||||
Actually test the provider by sending a minimal chat completion:
|
||||
```rust
|
||||
async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
axum::extract::Path(name): axum::extract::Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if let Some(provider) = state.app_state.provider_manager.get_provider(&name).await {
|
||||
let test_request = UnifiedRequest {
|
||||
model: "test".to_string(), // Provider will use default
|
||||
messages: vec![ChatMessage { role: "user".to_string(), content: vec![ContentPart::Text { text: "Hi".to_string() }] }],
|
||||
temperature: None,
|
||||
max_tokens: Some(5),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
match provider.chat_completion(test_request).await {
|
||||
Ok(_) => {
|
||||
let latency = start.elapsed().as_millis();
|
||||
Json(ApiResponse::success(json!({ "success": true, "latency": latency, "message": "Connection test successful" })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e)))
|
||||
}
|
||||
} else {
|
||||
Json(ApiResponse::error(format!("Provider '{}' not found or not enabled", name)))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 Implement real system health metrics
|
||||
**File**: `src/dashboard/mod.rs:969-978`
|
||||
|
||||
Read resident memory (`VmRSS`) from `/proc/self/status` (available on Linux only; falls back to 0 elsewhere) and calculate the remaining metrics from pool stats:
|
||||
```rust
|
||||
// Memory: read RSS from /proc/self/status
|
||||
let memory_kb = std::fs::read_to_string("/proc/self/status")
|
||||
.ok()
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.unwrap_or(0.0);
|
||||
let memory_mb = memory_kb / 1024.0;
|
||||
```
|
||||
|
||||
### 4.3 Implement handle_get_client
|
||||
**File**: `src/dashboard/mod.rs:647-651`
|
||||
|
||||
Query client by ID from the `clients` table and return full details.
|
||||
|
||||
### 4.4 Implement handle_client_usage
|
||||
**File**: `src/dashboard/mod.rs:676-680`
|
||||
|
||||
Query `llm_requests` aggregated by the given client_id.
|
||||
|
||||
### 4.5 Implement handle_get_provider
|
||||
**File**: `src/dashboard/mod.rs:776-780`
|
||||
|
||||
Return individual provider details (reuse logic from `handle_get_providers`).
|
||||
|
||||
### 4.6 Implement handle_system_backup
|
||||
**File**: `src/dashboard/mod.rs:1033-1039`
|
||||
|
||||
Use SQLite's `VACUUM INTO` statement via raw SQL to write a consistent snapshot of the database:
|
||||
```rust
|
||||
let backup_path = format!("data/backup-{}.db", chrono::Utc::now().timestamp());
|
||||
sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
|
||||
.execute(pool)
|
||||
.await?;
|
||||
```
|
||||
|
||||
### 4.7 Address TODO items
|
||||
- `src/server/mod.rs:211` - Check if request messages contain `ContentPart::Image` to set `has_images: true`
|
||||
- `src/logging/mod.rs:80-81` - Add optional request/response body storage (can remain None for now, just note in code)
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Code Quality
|
||||
|
||||
### 5.1 Extract shared provider logic
|
||||
**New file**: `src/providers/helpers.rs`
|
||||
|
||||
Create shared helper functions:
|
||||
- `messages_to_openai_json()` (from Phase 2)
|
||||
- `build_openai_compatible_body()` - builds the full JSON body with model, messages, stream, temperature, max_tokens
|
||||
- `parse_openai_response()` - extracts content, reasoning_content, usage from response JSON
|
||||
- `create_openai_stream()` - creates SSE stream with standard parsing
|
||||
- `calculate_cost_with_registry()` - shared cost calculation logic
|
||||
|
||||
Update `openai.rs`, `deepseek.rs`, `grok.rs`, `ollama.rs` to call these helpers. Each provider file should shrink from ~210 lines to ~50-80 lines.
|
||||
|
||||
Add `pub mod helpers;` to `src/providers/mod.rs`.
|
||||
|
||||
### 5.2 Replace wildcard re-exports
|
||||
**File**: `src/lib.rs:22-30`
|
||||
|
||||
Replace:
|
||||
```rust
|
||||
pub use auth::*;
|
||||
pub use client::*;
|
||||
// etc.
|
||||
```
|
||||
With explicit re-exports:
|
||||
```rust
|
||||
pub use auth::AuthenticatedClient;
|
||||
pub use client::ClientManager;
|
||||
pub use config::AppConfig;
|
||||
// etc.
|
||||
```
|
||||
|
||||
### 5.3 Fix all Clippy warnings (19 total)
|
||||
|
||||
1. `src/auth/mod.rs:19` - `manual_async_fn`: Use `async fn` instead of returning a future manually
|
||||
2. `src/database/mod.rs:12` - `collapsible_if`: Merge nested if statements
|
||||
3. `src/dashboard/mod.rs:139` - `collapsible_if`: Merge nested if
|
||||
4. `src/dashboard/mod.rs:616` - `to_string_in_format_args`: Remove redundant `.to_string()`
|
||||
5. `src/multimodal/mod.rs:211,220` - `collapsible_if` x2
|
||||
6. `src/providers/openai.rs:123`, `gemini.rs:225`, `deepseek.rs:125`, `grok.rs:123`, `ollama.rs:117` - `collapsible_if` x5 in calculate_cost (will be fixed by deduplication)
|
||||
7. `src/providers/mod.rs:80` - `new_without_default`: Add `impl Default for ProviderManager`
|
||||
8. `src/providers/mod.rs:193,200` - `redundant_closure` x2: Use `Arc::clone` directly instead of `|p| Arc::clone(p)`
|
||||
9. `src/rate_limiting/mod.rs:180,333,334` - `collapsible_if` x3
|
||||
10. `src/rate_limiting/mod.rs:336` - `manual_strip`: Use `.strip_prefix()` pattern
|
||||
11. `src/utils/streaming.rs:33` - `too_many_arguments`: Wrap params in a config struct
|
||||
|
||||
### 5.4 Replace unwrap() in production paths
|
||||
|
||||
1. `src/database/mod.rs:140` - `bcrypt::hash("admin", 12).unwrap()` → Use `?` with proper error propagation
|
||||
2. `src/dashboard/mod.rs:116` - `serde_json::to_string(&event).unwrap()` → Use `unwrap_or_default()` or log error
|
||||
3. `src/server/mod.rs:168` - `.json_data(response).unwrap()` → Handle error with fallback
|
||||
4. `src/config/mod.rs:139` - `std::env::current_dir().unwrap()` → Use `?` or provide a sensible default
|
||||
|
||||
### 5.5 Remove unused dependencies
|
||||
**File**: `Cargo.toml`
|
||||
|
||||
Remove or comment out:
|
||||
- `governor = "0.6"` - Custom TokenBucket is used instead
|
||||
- `async-openai` - Raw reqwest is used for all providers
|
||||
- `once_cell = "1.19"` - Redundant with the standard library's `std::sync::LazyLock` (stable since Rust 1.80)
|
||||
|
||||
Verify each is actually unused by checking imports with `rg 'use governor' src/` etc. before removing.
|
||||
|
||||
### 5.6 Split dashboard/mod.rs into sub-modules
|
||||
**Current**: 1077-line monolith at `src/dashboard/mod.rs`
|
||||
|
||||
**Target structure**:
|
||||
```
|
||||
src/dashboard/
|
||||
├── mod.rs (~80 lines) - Module declarations, router(), DashboardState, ApiResponse
|
||||
├── sessions.rs (~80 lines) - SessionManager (new from Phase 3)
|
||||
├── auth.rs (~80 lines) - handle_login, handle_auth_status, handle_change_password
|
||||
├── usage.rs (~200 lines) - handle_usage_summary, handle_time_series, handle_clients_usage, handle_providers_usage, handle_detailed_usage, handle_analytics_breakdown
|
||||
├── clients.rs (~100 lines) - handle_get_clients, handle_create_client, handle_get_client, handle_delete_client, handle_client_usage
|
||||
├── providers.rs (~150 lines) - handle_get_providers, handle_get_provider, handle_update_provider, handle_test_provider
|
||||
├── models.rs (~100 lines) - handle_get_models, handle_update_model
|
||||
├── system.rs (~120 lines) - handle_system_health, handle_system_logs, handle_system_backup, handle_get_settings, handle_update_settings
|
||||
└── websocket.rs (~60 lines) - handle_websocket, handle_websocket_connection, handle_websocket_message
|
||||
```
|
||||
|
||||
The `mod.rs` will declare sub-modules and re-export the `router()` function. All handlers use `DashboardState` which stays in `mod.rs`.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Infrastructure
|
||||
|
||||
### 6.1 Add rustfmt.toml
|
||||
```toml
|
||||
max_width = 120
|
||||
tab_spaces = 4
|
||||
edition = "2024"
|
||||
```
|
||||
|
||||
### 6.2 Add clippy.toml
|
||||
```toml
|
||||
too-many-arguments-threshold = 10
|
||||
```
|
||||
|
||||
### 6.3 Add GitHub Actions CI workflow
|
||||
**New file**: `.github/workflows/ci.yml`
|
||||
|
||||
```yaml
|
||||
name: CI
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo fmt --check
|
||||
- run: cargo clippy -- -D warnings
|
||||
- run: cargo test
|
||||
- run: cargo build --release
|
||||
```
|
||||
|
||||
### 6.4 Fix test_dashboard.sh
|
||||
**File**: `test_dashboard.sh:33`
|
||||
|
||||
Change `"admin123"` to `"admin"` to match the actual default password.
|
||||
|
||||
### 6.5 Add Dockerfile
|
||||
**New file**: `Dockerfile`
|
||||
|
||||
Multi-stage build for minimal image size:
|
||||
```dockerfile
|
||||
FROM rust:1.85-bookworm AS builder
|
||||
WORKDIR /app
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir src && echo "fn main() {}" > src/main.rs && cargo build --release && rm -rf src
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
|
||||
COPY --from=builder /app/static /app/static
|
||||
WORKDIR /app
|
||||
EXPOSE 8080
|
||||
CMD ["llm-proxy"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
After all phases, run:
|
||||
```bash
|
||||
cargo fmt --check
|
||||
cargo clippy -- -D warnings
|
||||
cargo test
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
All must pass with zero warnings and zero errors.
|
||||
|
||||
---
|
||||
|
||||
## Issue Summary
|
||||
|
||||
| Severity | Count | Phase |
|
||||
|----------|-------|-------|
|
||||
| Critical | 7 | 1-3 |
|
||||
| High | 5 | 2-3 |
|
||||
| Medium | 14 | 4-5 |
|
||||
| Low | 4 | 6 |
|
||||
| **Total** | **30** | |
|
||||
|
||||
Estimated effort: ~4-6 hours of focused implementation.
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"gopls": {
|
||||
"choice": "yes",
|
||||
"timestamp": 1775750416837
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"version": 1,
|
||||
"files": {
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/index.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:14.025Z",
|
||||
"mi": 12.6,
|
||||
"cognitive": 335,
|
||||
"nesting": 6,
|
||||
"lines": 910,
|
||||
"maxCyclomatic": 36,
|
||||
"entropy": 6.97
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:14.025Z",
|
||||
"mi": 12.6,
|
||||
"cognitive": 335,
|
||||
"nesting": 6,
|
||||
"lines": 910,
|
||||
"maxCyclomatic": 36,
|
||||
"entropy": 6.97
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
},
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/config.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:32.901Z",
|
||||
"mi": 37.7,
|
||||
"cognitive": 49,
|
||||
"nesting": 6,
|
||||
"lines": 173,
|
||||
"maxCyclomatic": 8,
|
||||
"entropy": 6.39
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:32.901Z",
|
||||
"mi": 37.7,
|
||||
"cognitive": 49,
|
||||
"nesting": 6,
|
||||
"lines": 173,
|
||||
"maxCyclomatic": 8,
|
||||
"entropy": 6.39
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
},
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/server.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:38.756Z",
|
||||
"mi": 3.9,
|
||||
"cognitive": 322,
|
||||
"nesting": 7,
|
||||
"lines": 1506,
|
||||
"maxCyclomatic": 28,
|
||||
"entropy": 7.47
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:38.756Z",
|
||||
"mi": 3.9,
|
||||
"cognitive": 322,
|
||||
"nesting": 7,
|
||||
"lines": 1506,
|
||||
"maxCyclomatic": 28,
|
||||
"entropy": 7.47
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
}
|
||||
},
|
||||
"capturedAt": "2026-04-26T03:45:43.756Z"
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"files": {},
|
||||
"turnCycles": 0,
|
||||
"maxCycles": 3,
|
||||
"lastUpdated": "2026-04-27T14:41:46.671Z"
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
# Backend Architecture (Go)
|
||||
|
||||
The GopherGate backend is implemented in Go, focusing on high performance, clear concurrency patterns, and maintainability.
|
||||
|
||||
## Core Technologies
|
||||
|
||||
- **Runtime:** Go 1.22+
|
||||
- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing.
|
||||
- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`.
|
||||
- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation.
|
||||
- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files.
|
||||
- **Metrics:** [gopsutil](https://github.com/shirou/gopsutil) - System-level resource monitoring.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
├── cmd/
|
||||
│ └── gophergate/ # Entry point (main.go)
|
||||
├── internal/
|
||||
│ ├── config/ # Configuration loading and validation
|
||||
│ ├── db/ # Database schema, migrations, and models
|
||||
│ ├── middleware/ # Auth and logging middleware
|
||||
│ ├── models/ # Unified request/response structs
|
||||
│ ├── providers/ # LLM provider implementations (OpenAI, Gemini, etc.)
|
||||
│ ├── server/ # HTTP server, dashboard handlers, and WebSocket hub
|
||||
│ └── utils/ # Common utilities (registry, pricing, etc.)
|
||||
└── static/ # Frontend assets (served by the backend)
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. Provider Interface (`internal/providers/provider.go`)
|
||||
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok, Moonshot, Ollama).
|
||||
|
||||
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
|
||||
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
|
||||
- **Fuzzy Matching:** Supports matching versioned model IDs (e.g., `gpt-4o-2024-08-06`) to base registry entries.
|
||||
- **Automatic Refreshes:** The registry is fetched at startup and refreshed every 24 hours via a background goroutine.
|
||||
|
||||
### 3. Asynchronous Logging (`internal/server/logging.go`)
|
||||
Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates.
|
||||
|
||||
### 4. Session Management (`internal/server/sessions.go`)
|
||||
Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure the management interface while standard Bearer tokens are used for LLM API access.
|
||||
|
||||
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
Go's goroutines and channels are used extensively:
|
||||
- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel for SSE delivery.
|
||||
- **Logging:** A single background worker processes the `logChan` to perform serial database writes.
|
||||
- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting.
|
||||
- **Maintenance:** Background tasks handle registry refreshes and status monitoring.
|
||||
|
||||
## Security
|
||||
|
||||
- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data.
|
||||
- **Auth Middleware:** Scoped to `/v1` routes to verify client API keys against the database.
|
||||
- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12.
|
||||
- **Database Hardening:** Automatic migrations ensure the schema is always current with the code.
|
||||
@@ -1,65 +0,0 @@
|
||||
# LLM Proxy Code Review Plan
|
||||
|
||||
## Overview
|
||||
The **LLM Proxy** project is a Rust-based middleware designed to provide a unified interface for multiple Large Language Models (LLMs). Based on the repository structure, the project aims to implement a high-performance proxy server (`src/`) that handles request routing, usage tracking, and billing logic. A static dashboard (`static/`) provides a management interface for monitoring consumption and managing API keys. The architecture leverages Rust's async capabilities for efficient request handling and SQLite for persistent state management.
|
||||
|
||||
## Review Phases
|
||||
|
||||
### Phase 1: Backend Architecture & Rust Logic (@code-reviewer)
|
||||
- **Focus on:**
|
||||
- **Core Proxy Logic:** Efficiency of the request/response pipeline and streaming support.
|
||||
- **State Management:** Thread-safety and shared state patterns using `Arc` and `Mutex`/`RwLock`.
|
||||
- **Error Handling:** Use of idiomatic Rust error types and propagation.
|
||||
- **Async Performance:** Proper use of `tokio` or similar runtimes to avoid blocking the executor.
|
||||
- **Rust Idioms:** Adherence to Clippy suggestions and standard Rust naming conventions.
|
||||
|
||||
### Phase 2: Security & Authentication Audit (@security-auditor)
|
||||
- **Focus on:**
|
||||
- **API Key Management:** Secure storage, masking in logs, and rotation mechanisms.
|
||||
- **JWT Handling:** Validation logic, signature verification, and expiration checks.
|
||||
- **Input Validation:** Sanitization of prompts and configuration parameters to prevent injection.
|
||||
- **Dependency Audit:** Scanning for known vulnerabilities in the `Cargo.lock` using `cargo-audit`.
|
||||
|
||||
### Phase 3: Database & Data Integrity Review (@database-optimizer)
|
||||
- **Focus on:**
|
||||
- **Schema Design:** Efficiency of the SQLite schema for usage tracking and billing.
|
||||
- **Migration Strategy:** Robustness of the migration scripts to prevent data loss.
|
||||
- **Usage Tracking:** Accuracy of token counting and concurrency handling during increments.
|
||||
- **Query Optimization:** Identifying potential bottlenecks in reporting queries.
|
||||
|
||||
### Phase 4: Frontend & Dashboard Review (@frontend-developer)
|
||||
- **Focus on:**
|
||||
- **Vanilla JS Patterns:** Review of Web Components and modular JS in `static/js`.
|
||||
- **Security:** Protection against XSS in the dashboard and secure handling of local storage.
|
||||
- **UI/UX Consistency:** Ensuring the management interface is intuitive and responsive.
|
||||
- **API Integration:** Robustness of the frontend's communication with the Rust backend.
|
||||
|
||||
### Phase 5: Infrastructure & Deployment Review (@devops-engineer)
|
||||
- **Focus on:**
|
||||
- **Dockerfile Optimization:** Multi-stage builds to minimize image size and attack surface.
|
||||
- **Resource Limits:** Configuration of CPU/Memory limits for the proxy container.
|
||||
- **Deployment Docs:** Clarity of the setup process and environment variable documentation.
|
||||
|
||||
## Timeline (Gantt)
|
||||
|
||||
```mermaid
|
||||
gantt
|
||||
title LLM Proxy Code Review Timeline (March 2026)
|
||||
dateFormat YYYY-MM-DD
|
||||
section Backend & Security
|
||||
Architecture & Rust Logic (Phase 1) :active, p1, 2026-03-06, 1d
|
||||
Security & Auth Audit (Phase 2) :p2, 2026-03-07, 1d
|
||||
section Data & Frontend
|
||||
Database & Integrity (Phase 3) :p3, 2026-03-07, 1d
|
||||
Frontend & Dashboard (Phase 4) :p4, 2026-03-08, 1d
|
||||
section DevOps
|
||||
Infra & Deployment (Phase 5) :p5, 2026-03-08, 1d
|
||||
Final Review & Sign-off :2026-03-08, 4h
|
||||
```
|
||||
|
||||
## Success Criteria
|
||||
- **Security:** Zero high-priority vulnerabilities identified; all API keys masked in logs.
|
||||
- **Performance:** Proxy overhead is minimal (<10ms latency addition); queries are indexed.
|
||||
- **Maintainability:** Code passes all linting (`cargo clippy`) and formatting (`cargo fmt`) checks.
|
||||
- **Documentation:** README and deployment guides are up-to-date and accurate.
|
||||
- **Reliability:** Usage tracking matches actual API consumption with 99.9% accuracy.
|
||||
Generated
-4139
File diff suppressed because it is too large
Load Diff
-75
@@ -1,75 +0,0 @@
|
||||
[package]
|
||||
name = "llm-proxy"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
rust-version = "1.87"
|
||||
description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation"
|
||||
authors = ["newkirk"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = ""
|
||||
|
||||
[dependencies]
|
||||
# ========== Web Framework & Async Runtime ==========
|
||||
axum = { version = "0.8", features = ["macros", "ws"] }
|
||||
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
|
||||
governor = "0.7"
|
||||
|
||||
# ========== HTTP Clients ==========
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||
tiktoken-rs = "0.9"
|
||||
|
||||
# ========== Database & ORM ==========
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] }
|
||||
|
||||
# ========== Authentication & Middleware ==========
|
||||
axum-extra = { version = "0.12", features = ["typed-header"] }
|
||||
headers = "0.4"
|
||||
|
||||
# ========== Configuration Management ==========
|
||||
config = "0.13"
|
||||
dotenvy = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
toml = "0.8"
|
||||
|
||||
# ========== Logging & Monitoring ==========
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
|
||||
# ========== Multimodal & Image Processing ==========
|
||||
base64 = "0.21"
|
||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
|
||||
mime = "0.3"
|
||||
|
||||
# ========== Error Handling & Utilities ==========
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
bcrypt = "0.15"
|
||||
aes-gcm = "0.10"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
||||
futures = "0.3"
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
reqwest-eventsource = "0.6"
|
||||
rand = "0.9"
|
||||
hex = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
mockito = "1.0"
|
||||
tempfile = "3.10"
|
||||
assert_cmd = "2.0"
|
||||
insta = "1.39"
|
||||
anyhow = "1.0"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
strip = true
|
||||
panic = "abort"
|
||||
@@ -1,220 +0,0 @@
|
||||
# LLM Proxy Gateway - Admin Dashboard
|
||||
|
||||
## Overview
|
||||
|
||||
This is a comprehensive admin dashboard for the LLM Proxy Gateway, providing real-time monitoring, analytics, and management capabilities for the proxy service.
|
||||
|
||||
## Features
|
||||
|
||||
### 1. Dashboard Overview
|
||||
- Real-time request counters and statistics
|
||||
- System health indicators
|
||||
- Provider status monitoring
|
||||
- Recent requests stream
|
||||
|
||||
### 2. Usage Analytics
|
||||
- Time series charts for requests, tokens, and costs
|
||||
- Filter by date range, client, provider, and model
|
||||
- Top clients and models analysis
|
||||
- Export functionality to CSV/JSON
|
||||
|
||||
### 3. Cost Management
|
||||
- Cost breakdown by provider, client, and model
|
||||
- Budget tracking with alerts
|
||||
- Cost projections
|
||||
- Pricing configuration management
|
||||
|
||||
### 4. Client Management
|
||||
- List, create, revoke, and rotate API tokens
|
||||
- Client-specific rate limits
|
||||
- Usage statistics per client
|
||||
- Token management interface
|
||||
|
||||
### 5. Provider Configuration
|
||||
- Enable/disable LLM providers
|
||||
- Configure API keys (masked display)
|
||||
- Test provider connections
|
||||
- Model availability management
|
||||
|
||||
### 6. User Management (RBAC)
|
||||
- **Admin Role:** Full access to all dashboard features, user management, system configuration
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring
|
||||
- Create/manage dashboard users with role assignment
|
||||
- Secure password management
|
||||
|
||||
### 7. Real-time Monitoring
|
||||
- Live request stream via WebSocket
|
||||
- System metrics dashboard
|
||||
- Response time and error rate tracking
|
||||
- Live system logs
|
||||
|
||||
### 8. System Settings
|
||||
- General configuration
|
||||
- Database management
|
||||
- Logging settings
|
||||
- Security settings
|
||||
|
||||
## Technology Stack
|
||||
|
||||
### Frontend
|
||||
- **HTML5/CSS3**: Modern, responsive design with CSS Grid/Flexbox
|
||||
- **JavaScript (ES6+)**: Vanilla JavaScript with modular architecture
|
||||
- **Chart.js**: Interactive data visualizations
|
||||
- **Luxon**: Date/time manipulation
|
||||
- **WebSocket API**: Real-time updates
|
||||
|
||||
### Backend (Rust/Axum)
|
||||
- **Axum**: Web framework with WebSocket support
|
||||
- **Tokio**: Async runtime
|
||||
- **Serde**: JSON serialization/deserialization
|
||||
- **Broadcast channels**: Real-time event distribution
|
||||
|
||||
## Installation & Setup
|
||||
|
||||
### 1. Build and Run the Server
|
||||
```bash
|
||||
# Build the project
|
||||
cargo build --release
|
||||
|
||||
# Run the server
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
### 2. Access the Dashboard
|
||||
Once the server is running, access the dashboard at:
|
||||
```
|
||||
http://localhost:8080
|
||||
```
|
||||
|
||||
### 3. Default Login Credentials
|
||||
- **Username**: `admin`
|
||||
- **Password**: `admin`
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Authentication
|
||||
- `POST /api/auth/login` - Dashboard login
|
||||
- `GET /api/auth/status` - Authentication status
|
||||
|
||||
### Analytics
|
||||
- `GET /api/usage/summary` - Overall usage summary
|
||||
- `GET /api/usage/time-series` - Time series data
|
||||
- `GET /api/usage/clients` - Client breakdown
|
||||
- `GET /api/usage/providers` - Provider breakdown
|
||||
|
||||
### Clients
|
||||
- `GET /api/clients` - List all clients
|
||||
- `POST /api/clients` - Create new client
|
||||
- `PUT /api/clients/{id}` - Update client
|
||||
- `DELETE /api/clients/{id}` - Revoke client
|
||||
- `GET /api/clients/{id}/usage` - Client-specific usage
|
||||
|
||||
### Users (RBAC)
|
||||
- `GET /api/users` - List all dashboard users
|
||||
- `POST /api/users` - Create new user
|
||||
- `PUT /api/users/{id}` - Update user (admin only)
|
||||
- `DELETE /api/users/{id}` - Delete user (admin only)
|
||||
|
||||
### Providers
|
||||
- `GET /api/providers` - List providers and status
|
||||
- `PUT /api/providers/{name}` - Update provider config
|
||||
- `POST /api/providers/{name}/test` - Test provider connection
|
||||
|
||||
### System
|
||||
- `GET /api/system/health` - System health
|
||||
- `GET /api/system/logs` - Recent logs
|
||||
- `POST /api/system/backup` - Trigger backup
|
||||
|
||||
### WebSocket
|
||||
- `GET /ws` - WebSocket endpoint for real-time updates
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
llm-proxy/
|
||||
├── src/
|
||||
│ ├── dashboard/ # Dashboard backend module
|
||||
│ │ └── mod.rs # Dashboard routes and handlers
|
||||
│ ├── server/ # Main proxy server
|
||||
│ ├── providers/ # LLM provider implementations
|
||||
│ └── ... # Other modules
|
||||
├── static/ # Frontend dashboard files
|
||||
│ ├── index.html # Main dashboard HTML
|
||||
│ ├── css/
|
||||
│ │ └── dashboard.css # Dashboard styles
|
||||
│ ├── js/
|
||||
│ │ ├── auth.js # Authentication module
|
||||
│ │ ├── dashboard.js # Main dashboard controller
|
||||
│ │ ├── websocket.js # WebSocket manager
|
||||
│ │ ├── charts.js # Chart.js utilities
|
||||
│ │ └── pages/ # Page-specific modules
|
||||
│ │ ├── overview.js
|
||||
│ │ ├── analytics.js
|
||||
│ │ ├── costs.js
|
||||
│ │ ├── clients.js
|
||||
│ │ ├── providers.js
|
||||
│ │ ├── monitoring.js
|
||||
│ │ ├── settings.js
|
||||
│ │ └── logs.js
|
||||
│ ├── img/ # Images and icons
|
||||
│ └── fonts/ # Font files
|
||||
└── Cargo.toml # Rust dependencies
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Adding New Pages
|
||||
1. Create a new JavaScript module in `static/js/pages/`
|
||||
2. Implement the page class with `init()` method
|
||||
3. Register the page in `dashboard.js`
|
||||
4. Add menu item in `index.html`
|
||||
|
||||
### Adding New API Endpoints
|
||||
1. Add route in `src/dashboard/mod.rs`
|
||||
2. Implement handler function
|
||||
3. Update frontend JavaScript to call the endpoint
|
||||
|
||||
### Styling Guidelines
|
||||
- Use CSS custom properties (variables) from `:root`
|
||||
- Follow mobile-first responsive design
|
||||
- Use BEM-like naming convention for CSS classes
|
||||
- Maintain consistent spacing with CSS variables
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Authentication**: Simple password-based auth for demo; replace with proper auth in production
|
||||
2. **API Keys**: Tokens are masked in the UI (only last 4 characters shown)
|
||||
3. **CORS**: Configure appropriate CORS headers for production
|
||||
4. **Rate Limiting**: Implement rate limiting for API endpoints
|
||||
5. **HTTPS**: Always use HTTPS in production
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **Code Splitting**: JavaScript modules are loaded on-demand
|
||||
2. **Caching**: Static assets are served with cache headers
|
||||
3. **WebSocket**: Real-time updates reduce polling overhead
|
||||
4. **Lazy Loading**: Charts and tables load data as needed
|
||||
5. **Compression**: Enable gzip/brotli compression for static files
|
||||
|
||||
## Browser Support
|
||||
|
||||
- Chrome 60+
|
||||
- Firefox 55+
|
||||
- Safari 11+
|
||||
- Edge 79+
|
||||
|
||||
## License
|
||||
|
||||
MIT License - See LICENSE file for details.
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests if applicable
|
||||
5. Submit a pull request
|
||||
|
||||
## Support
|
||||
|
||||
For issues and feature requests, please use the GitHub issue tracker.
|
||||
@@ -1,480 +0,0 @@
|
||||
# Database Review Report for LLM-Proxy Repository
|
||||
|
||||
**Review Date:** 2025-03-06
|
||||
**Reviewer:** Database Optimization Expert
|
||||
**Repository:** llm-proxy
|
||||
**Focus Areas:** Schema Design, Query Optimization, Migration Strategy, Data Integrity, Usage Tracking Accuracy
|
||||
|
||||
## Executive Summary
|
||||
|
||||
The llm-proxy database implementation demonstrates a solid foundation, with appropriate table structures and a clear separation of concerns. However, several areas require improvement to ensure scalability, data consistency, and performance as usage grows. Key findings include:
|
||||
|
||||
1. **Schema Design**: Generally normalized but missing foreign key enforcement and some critical indexes.
|
||||
2. **Query Optimization**: Well-optimized for most queries but missing composite indexes for common filtering patterns.
|
||||
3. **Migration Strategy**: Ad-hoc migration approach that may cause issues with schema evolution.
|
||||
4. **Data Integrity**: Potential race conditions in usage tracking and missing transaction boundaries.
|
||||
5. **Usage Tracking**: Generally accurate but risk of inconsistent state between related tables.
|
||||
|
||||
This report provides detailed analysis and actionable recommendations for each area.
|
||||
|
||||
## 1. Schema Design Review
|
||||
|
||||
### Tables Overview
|
||||
|
||||
The database consists of 6 main tables:
|
||||
|
||||
1. **clients**: Client management with usage aggregates
|
||||
2. **llm_requests**: Request logging with token counts and costs
|
||||
3. **provider_configs**: Provider configuration and credit balances
|
||||
4. **model_configs**: Model-specific configuration and cost overrides
|
||||
5. **users**: Dashboard user authentication
|
||||
6. **client_tokens**: API token storage for client authentication
|
||||
|
||||
### Normalization Assessment
|
||||
|
||||
**Strengths:**
|
||||
- Tables follow 3rd Normal Form (3NF) with appropriate separation
|
||||
- Foreign key relationships properly defined
|
||||
- No obvious data duplication across tables
|
||||
|
||||
**Areas for Improvement:**
|
||||
- **Denormalized aggregates**: `clients.total_requests`, `total_tokens`, `total_cost` are derived from `llm_requests`. This introduces risk of inconsistency.
|
||||
- **Provider credit balance**: Stored in `provider_configs` but also updated based on `llm_requests`. No audit trail for balance changes.
|
||||
|
||||
### Data Type Analysis
|
||||
|
||||
**Appropriate Choices:**
|
||||
- INTEGER for token counts (cast from u32 to i64)
|
||||
- REAL for monetary values
|
||||
- DATETIME for timestamps using SQLite's CURRENT_TIMESTAMP
|
||||
- TEXT for identifiers with appropriate length
|
||||
|
||||
**Potential Issues:**
|
||||
- `llm_requests.request_body` and `llm_requests.response_body` are defined as TEXT but are always set to NULL - consider removing these columns or making their population optional.
|
||||
- `provider_configs.billing_mode` added via migration but default value not consistently applied to existing rows.
|
||||
|
||||
### Constraints and Foreign Keys
|
||||
|
||||
**Current Constraints:**
|
||||
- Primary keys defined for all tables
|
||||
- UNIQUE constraints on `clients.client_id`, `users.username`, `client_tokens.token`
|
||||
- Foreign key definitions present but **not enforced** (SQLite default)
|
||||
|
||||
**Missing Constraints:**
|
||||
- NOT NULL constraints are missing on several columns where nullability is not intended
|
||||
- CHECK constraints for positive values (`credit_balance >= 0`)
|
||||
- Foreign key enforcement not enabled
|
||||
|
||||
## 2. Query Optimization Analysis
|
||||
|
||||
### Indexing Strategy
|
||||
|
||||
**Existing Indexes:**
|
||||
- `idx_clients_client_id` - Essential for client lookups
|
||||
- `idx_clients_created_at` - Useful for chronological listing
|
||||
- `idx_llm_requests_timestamp` - Critical for time-based queries
|
||||
- `idx_llm_requests_client_id` - Supports client-specific queries
|
||||
- `idx_llm_requests_provider` - Good for provider breakdowns
|
||||
- `idx_llm_requests_status` - Low cardinality but acceptable
|
||||
- `idx_client_tokens_token` UNIQUE - Essential for authentication
|
||||
- `idx_client_tokens_client_id` - Supports token management
|
||||
|
||||
**Missing Critical Indexes:**
|
||||
1. `model_configs.provider_id` - Foreign key column used in JOINs
|
||||
2. `llm_requests(client_id, timestamp)` - Composite index for client time-series queries
|
||||
3. `llm_requests(provider, timestamp)` - For provider performance analysis
|
||||
4. `llm_requests(status, timestamp)` - For error trend analysis
|
||||
|
||||
### N+1 Query Detection
|
||||
|
||||
**Well-Optimized Areas:**
|
||||
- Model configuration caching prevents repeated database hits
|
||||
- Provider configs loaded in batch for dashboard display
|
||||
- Client listing uses single efficient query
|
||||
|
||||
**Potential N+1 Patterns:**
|
||||
- The `list_models` function in `server/mod.rs` performs a cache lookup per model, but the lookup is in-memory
|
||||
- No significant database N+1 issues identified
|
||||
|
||||
### Inefficient Query Patterns
|
||||
|
||||
**Query 1: Time-series aggregation with strftime()**
|
||||
```sql
|
||||
SELECT strftime('%Y-%m-%d', timestamp) as date, ...
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY date, client_id, provider, model
|
||||
ORDER BY date DESC
|
||||
LIMIT 200
|
||||
```
|
||||
**Issue:** Function on indexed column prevents index utilization for the WHERE clause when filtering by timestamp range.
|
||||
|
||||
**Recommendation:** Store computed date column or use range queries on timestamp directly.
|
||||
|
||||
**Query 2: Today's stats using strftime()**
|
||||
```sql
|
||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||
```
|
||||
**Issue:** Non-sargable query prevents index usage.
|
||||
|
||||
**Recommendation:** Use range query:
|
||||
```sql
|
||||
WHERE timestamp >= date(?) AND timestamp < date(?, '+1 day')
|
||||
```
|
||||
|
||||
### Recommended Index Additions
|
||||
|
||||
```sql
|
||||
-- Composite indexes for common query patterns
|
||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
||||
CREATE INDEX idx_llm_requests_status_timestamp ON llm_requests(status, timestamp);
|
||||
|
||||
-- Foreign key index
|
||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
||||
|
||||
-- Optional: Covering index for client usage queries
|
||||
CREATE INDEX idx_clients_usage ON clients(client_id, total_requests, total_tokens, total_cost);
|
||||
```
|
||||
|
||||
## 3. Migration Strategy Assessment
|
||||
|
||||
### Current Approach
|
||||
|
||||
The migration system uses a hybrid approach:
|
||||
|
||||
1. **Schema synchronization**: `CREATE TABLE IF NOT EXISTS` on startup
|
||||
2. **Ad-hoc migrations**: `ALTER TABLE` statements with error suppression
|
||||
3. **Single migration file**: `migrations/001-add-billing-mode.sql` with transaction wrapper
|
||||
|
||||
**Pros:**
|
||||
- Simple to understand and maintain
|
||||
- Automatic schema creation for new deployments
|
||||
- Error suppression prevents crashes on column existence
|
||||
|
||||
**Cons:**
|
||||
- No version tracking of applied migrations
|
||||
- Potential for inconsistent schema across deployments
|
||||
- `ALTER TABLE` error suppression hides genuine schema issues
|
||||
- No rollback capability
|
||||
|
||||
### Risks and Limitations
|
||||
|
||||
1. **Schema Drift**: Different instances may have different schemas if migrations are applied out of order
|
||||
2. **Data Loss Risk**: No backup/verification before schema changes
|
||||
3. **Production Issues**: Error suppression could mask migration failures until runtime
|
||||
|
||||
### Recommendations
|
||||
|
||||
1. **Implement Proper Migration Tooling**: Use `sqlx migrate` or similar versioned migration system
|
||||
2. **Add Migration Version Table**: Track applied migrations and checksum verification
|
||||
3. **Separate Migration Scripts**: One file per migration with up/down directions
|
||||
4. **Pre-deployment Validation**: Schema checks in CI/CD pipeline
|
||||
5. **Backup Strategy**: Automatic backups before migration execution
|
||||
|
||||
## 4. Data Integrity Evaluation
|
||||
|
||||
### Foreign Key Enforcement
|
||||
|
||||
**Critical Issue:** Foreign key constraints are defined but **not enforced** in SQLite.
|
||||
|
||||
**Impact:** Orphaned records, inconsistent referential integrity.
|
||||
|
||||
**Solution:** Enable foreign key support in the connection options:
|
||||
```rust
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
||||
.create_if_missing(true)
|
||||
.pragma("foreign_keys", "ON");
|
||||
```
|
||||
|
||||
### Transaction Usage
|
||||
|
||||
**Good Patterns:**
|
||||
- Request logging uses transactions for insert + provider balance update
|
||||
- Atomic UPDATE for client usage statistics
|
||||
|
||||
**Problematic Areas:**
|
||||
|
||||
1. **Split Transactions**: Client usage update and request logging are in separate transactions
|
||||
- In `logging/mod.rs`: `insert_log` transaction includes provider balance update
|
||||
- In `utils/streaming.rs`: Client usage updated separately after logging
|
||||
- **Risk**: Partial updates if one transaction fails
|
||||
|
||||
2. **No Transaction for Client Creation**: Client and token creation not atomic
|
||||
|
||||
**Recommendations:**
|
||||
- Wrap client usage update within the same transaction as request logging
|
||||
- Use transaction for client + token creation
|
||||
- Consider using savepoints for complex operations
|
||||
|
||||
### Race Conditions and Consistency
|
||||
|
||||
**Potential Race Conditions:**
|
||||
1. **Provider credit balance**: Concurrent requests may cause lost updates
|
||||
- Current: `UPDATE provider_configs SET credit_balance = credit_balance - ?`
|
||||
- SQLite provides serializable isolation, but negative balances are not prevented
|
||||
|
||||
2. **Client usage aggregates**: Concurrent updates to `total_requests`, `total_tokens`, `total_cost`
|
||||
- Similar UPDATE pattern, generally safe but consider idempotency
|
||||
|
||||
**Recommendations:**
|
||||
- Add check constraint: `CHECK (credit_balance >= 0)`
|
||||
- Implement idempotent request logging with unique request IDs
|
||||
- Consider optimistic concurrency control for critical balances
|
||||
|
||||
## 5. Usage Tracking Accuracy
|
||||
|
||||
### Token Counting Methodology
|
||||
|
||||
**Current Approach:**
|
||||
- Prompt tokens: Estimated using provider-specific estimators
|
||||
- Completion tokens: Estimated or from provider real usage data
|
||||
- Cache tokens: Separately tracked for cache-aware pricing
|
||||
|
||||
**Strengths:**
|
||||
- Fallback to estimation when provider doesn't report usage
|
||||
- Cache token differentiation for accurate pricing
|
||||
|
||||
**Weaknesses:**
|
||||
- Estimation may differ from actual provider counts
|
||||
- No validation of provider-reported token counts
|
||||
|
||||
### Cost Calculation
|
||||
|
||||
**Well Implemented:**
|
||||
- Model-specific cost overrides via `model_configs`
|
||||
- Cache-aware pricing when supported by registry
|
||||
- Provider fallback calculations
|
||||
|
||||
**Potential Issues:**
|
||||
- Floating-point precision for monetary calculations
|
||||
- No rounding strategy for fractional cents
|
||||
|
||||
### Update Consistency
|
||||
|
||||
**Inconsistency Risk:** Client aggregates updated separately from request logging.
|
||||
|
||||
**Example Flow:**
|
||||
1. Request log inserted and provider balance updated (transaction)
|
||||
2. Client usage updated (separate operation)
|
||||
3. If step 2 fails, client stats undercount usage
|
||||
|
||||
**Solution:** Include client update in the same transaction:
|
||||
```sql
|
||||
-- In the insert_log function, add this statement to the existing transaction:
|
||||
UPDATE clients
|
||||
SET total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?
|
||||
WHERE client_id = ?;
|
||||
```
|
||||
|
||||
### Financial Accuracy
|
||||
|
||||
**Good Practices:**
|
||||
- Token-level granularity for cost calculation
|
||||
- Separation of prompt/completion/cache pricing
|
||||
- Database persistence for audit trail
|
||||
|
||||
**Recommendations:**
|
||||
1. **Audit Trail**: Add `balance_transactions` table for provider credit changes
|
||||
2. **Rounding Policy**: Define rounding strategy (e.g., to 6 decimal places)
|
||||
3. **Validation**: Periodic reconciliation of aggregates vs. detail records
|
||||
|
||||
## 6. Performance Recommendations
|
||||
|
||||
### Schema Improvements
|
||||
|
||||
1. **Partitioning Strategy**: For high-volume `llm_requests`, consider:
|
||||
- Monthly partitioning by timestamp
|
||||
- Archive old data to separate tables
|
||||
|
||||
2. **Data Retention Policy**: Implement automatic cleanup of old request logs
|
||||
```sql
|
||||
DELETE FROM llm_requests WHERE timestamp < date('now', '-90 days');
|
||||
```
|
||||
|
||||
3. **Column Optimization**: Remove unused `request_body`, `response_body` columns or implement compression
|
||||
|
||||
### Query Optimizations
|
||||
|
||||
1. **Avoid Functions on Indexed Columns**: Rewrite date queries as range queries
|
||||
2. **Batch Updates**: Consider batch updates for client usage instead of per-request
|
||||
3. **Read Replicas**: For dashboard queries, consider separate read connection
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
**Current:** SQLx connection pool with default settings
|
||||
|
||||
**Recommendations:**
|
||||
- Configure pool size based on expected concurrency
|
||||
- Implement connection health checks
|
||||
- Monitor pool utilization metrics
|
||||
|
||||
### Monitoring Setup
|
||||
|
||||
**Essential Metrics:**
|
||||
- Query execution times (slow query logging)
|
||||
- Index usage statistics
|
||||
- Table growth trends
|
||||
- Connection pool utilization
|
||||
|
||||
**Implementation:**
|
||||
- Add `sqlx::metrics` integration
|
||||
- Regular `ANALYZE` execution for query planner
|
||||
- Dashboard for database health monitoring
|
||||
|
||||
## 7. Security Considerations
|
||||
|
||||
### Data Protection
|
||||
|
||||
**Sensitive Data:**
|
||||
- `provider_configs.api_key` - Should be encrypted at rest
|
||||
- `users.password_hash` - Already hashed with bcrypt
|
||||
- `client_tokens.token` - Plain text storage
|
||||
|
||||
**Recommendations:**
|
||||
- Encrypt API keys using libsodium or similar
|
||||
- Implement token hashing (similar to password hashing)
|
||||
- Regular security audits of authentication flows
|
||||
|
||||
### SQL Injection Prevention
|
||||
|
||||
**Good Practices:**
|
||||
- Use sqlx query builder with parameter binding
|
||||
- No raw SQL concatenation observed in code review
|
||||
|
||||
**Verification Needed:** Ensure all dynamic SQL uses parameterized queries
|
||||
|
||||
### Access Controls
|
||||
|
||||
**Database Level:**
|
||||
- SQLite lacks built-in user management
|
||||
- Consider file system permissions for database file
|
||||
- Application-level authentication is primary control
|
||||
|
||||
## 8. Summary of Critical Issues
|
||||
|
||||
**Priority 1 (Critical):**
|
||||
1. Foreign key constraints not enabled
|
||||
2. Split transactions risking data inconsistency
|
||||
3. Missing composite indexes for common queries
|
||||
|
||||
**Priority 2 (High):**
|
||||
1. No proper migration versioning system
|
||||
2. Potential race conditions in balance updates
|
||||
3. Non-sargable date queries impacting performance
|
||||
|
||||
**Priority 3 (Medium):**
|
||||
1. Denormalized aggregates without consistency guarantees
|
||||
2. No data retention policy for request logs
|
||||
3. Missing check constraints for data validation
|
||||
|
||||
## 9. Recommended Action Plan
|
||||
|
||||
### Phase 1: Immediate Fixes (1-2 weeks)
|
||||
1. Enable foreign key constraints in database connection
|
||||
2. Add composite indexes for common query patterns
|
||||
3. Fix transaction boundaries for client usage updates
|
||||
4. Rewrite non-sargable date queries
|
||||
|
||||
### Phase 2: Short-term Improvements (3-4 weeks)
|
||||
1. Implement proper migration system with version tracking
|
||||
2. Add check constraints for data validation
|
||||
3. Implement connection pooling configuration
|
||||
4. Create database monitoring dashboard
|
||||
|
||||
### Phase 3: Long-term Enhancements (2-3 months)
|
||||
1. Implement data retention and archiving strategy
|
||||
2. Add audit trail for provider balance changes
|
||||
3. Consider partitioning for high-volume tables
|
||||
4. Implement encryption for sensitive data
|
||||
|
||||
### Phase 4: Ongoing Maintenance
|
||||
1. Regular index maintenance and query plan analysis
|
||||
2. Periodic reconciliation of aggregate vs. detail data
|
||||
3. Security audits and dependency updates
|
||||
4. Performance benchmarking and optimization
|
||||
|
||||
---
|
||||
|
||||
## Appendices
|
||||
|
||||
### A. Sample Migration Implementation
|
||||
|
||||
```sql
|
||||
-- migrations/002-enable-foreign-keys.sql
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- migrations/003-add-composite-indexes.sql
|
||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
||||
```
|
||||
|
||||
### B. Transaction Fix Example
|
||||
|
||||
```rust
|
||||
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// Insert or ignore client
|
||||
sqlx::query("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')")
|
||||
.bind(&log.client_id)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Insert request log
|
||||
sqlx::query("INSERT INTO llm_requests ...")
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Update provider balance
|
||||
if log.cost > 0.0 {
|
||||
sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')")
|
||||
.bind(log.cost)
|
||||
.bind(&log.provider)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update client aggregates within same transaction
|
||||
sqlx::query("UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ? WHERE client_id = ?")
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
### C. Monitoring Query Examples
|
||||
|
||||
```sql
|
||||
-- Identify unused indexes
|
||||
SELECT * FROM sqlite_master
|
||||
WHERE type = 'index'
|
||||
AND name NOT IN (
|
||||
SELECT DISTINCT name
|
||||
FROM sqlite_stat1
|
||||
WHERE tbl = 'llm_requests'
|
||||
);
|
||||
|
||||
-- Table size analysis
|
||||
SELECT name, (pgsize * page_count) / 1024 / 1024 as size_mb
|
||||
FROM dbstat
|
||||
WHERE name = 'llm_requests';
|
||||
|
||||
-- Query performance analysis (requires EXPLAIN QUERY PLAN)
|
||||
EXPLAIN QUERY PLAN
|
||||
SELECT * FROM llm_requests
|
||||
WHERE client_id = ? AND timestamp >= ?;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*This report provides a comprehensive analysis of the current database implementation and actionable recommendations for improvement. Regular review and iteration will ensure the database continues to meet performance, consistency, and scalability requirements as the application grows.*
|
||||
+21
-22
@@ -1,35 +1,34 @@
|
||||
# ── Build stage ──────────────────────────────────────────────
|
||||
FROM rust:1-bookworm AS builder
|
||||
# Build stage
|
||||
FROM golang:1.22-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Cache dependency build
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
|
||||
cargo build --release && \
|
||||
rm -rf src
|
||||
# Copy go mod and sum files
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Build the actual binary
|
||||
COPY src/ src/
|
||||
RUN touch src/main.rs && cargo build --release
|
||||
# Copy the source code
|
||||
COPY . .
|
||||
|
||||
# ── Runtime stage ────────────────────────────────────────────
|
||||
FROM debian:bookworm-slim
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o gophergate ./cmd/gophergate
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Final stage
|
||||
FROM alpine:latest
|
||||
|
||||
RUN apk --no-cache add ca-certificates tzdata
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
|
||||
COPY static/ /app/static/
|
||||
# Copy the binary from the builder stage
|
||||
COPY --from=builder /app/gophergate .
|
||||
COPY --from=builder /app/static ./static
|
||||
|
||||
# Default config location
|
||||
VOLUME ["/app/config", "/app/data"]
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8080
|
||||
|
||||
ENV RUST_LOG=info
|
||||
|
||||
ENTRYPOINT ["/app/llm-proxy"]
|
||||
# Run the application
|
||||
CMD ["./gophergate"]
|
||||
|
||||
-232
@@ -1,232 +0,0 @@
|
||||
# Optimization for 512MB RAM Environment
|
||||
|
||||
This document provides guidance for optimizing the LLM Proxy Gateway for deployment in resource-constrained environments (512MB RAM).
|
||||
|
||||
## Memory Optimization Strategies
|
||||
|
||||
### 1. Build Optimization
|
||||
|
||||
The project is already configured with optimized build settings in `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[profile.release]
|
||||
opt-level = 3 # Maximum optimization
|
||||
lto = true # Link-time optimization
|
||||
codegen-units = 1 # Single codegen unit for better optimization
|
||||
strip = true # Strip debug symbols
|
||||
```
|
||||
|
||||
**Additional optimizations you can apply:**
|
||||
|
||||
```bash
|
||||
# Build with specific target for better optimization
|
||||
cargo build --release --target x86_64-unknown-linux-musl
|
||||
|
||||
# Or for ARM (Raspberry Pi, etc.)
|
||||
cargo build --release --target aarch64-unknown-linux-musl
|
||||
```
|
||||
|
||||
### 2. Runtime Memory Management
|
||||
|
||||
#### Database Connection Pool
|
||||
- Default: 10 connections
|
||||
- Recommended for 512MB: 5 connections
|
||||
|
||||
Update `config.toml`:
|
||||
```toml
|
||||
[database]
|
||||
max_connections = 5
|
||||
```
|
||||
|
||||
#### Rate Limiting Memory Usage
|
||||
- Client rate limit buckets: Store in memory
|
||||
- Circuit breakers: Minimal memory usage
|
||||
- Consider reducing burst capacity if memory is critical
|
||||
|
||||
#### Provider Management
|
||||
- Only enable providers you actually use
|
||||
- Disable unused providers in configuration
|
||||
|
||||
### 3. Configuration for Low Memory
|
||||
|
||||
Create a `config-low-memory.toml`:
|
||||
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
max_connections = 3 # Reduced from default 10
|
||||
|
||||
[providers]
|
||||
# Only enable providers you need
|
||||
openai.enabled = true
|
||||
gemini.enabled = false # Disable if not used
|
||||
deepseek.enabled = false # Disable if not used
|
||||
grok.enabled = false # Disable if not used
|
||||
|
||||
[rate_limiting]
|
||||
# Reduce memory usage for rate limiting
|
||||
client_requests_per_minute = 30 # Reduced from 60
|
||||
client_burst_size = 5 # Reduced from 10
|
||||
global_requests_per_minute = 300 # Reduced from 600
|
||||
```
|
||||
|
||||
### 4. System-Level Optimizations
|
||||
|
||||
#### Linux Kernel Parameters
|
||||
Add to `/etc/sysctl.conf`:
|
||||
```bash
|
||||
# Reduce TCP buffer sizes
|
||||
net.ipv4.tcp_rmem = 4096 87380 174760
|
||||
net.ipv4.tcp_wmem = 4096 65536 131072
|
||||
|
||||
# Reduce connection tracking
|
||||
net.netfilter.nf_conntrack_max = 65536
|
||||
net.netfilter.nf_conntrack_tcp_timeout_established = 1200
|
||||
|
||||
# Reduce socket buffer sizes
|
||||
net.core.rmem_max = 131072
|
||||
net.core.wmem_max = 131072
|
||||
net.core.rmem_default = 65536
|
||||
net.core.wmem_default = 65536
|
||||
```
|
||||
|
||||
#### Systemd Service Configuration
|
||||
Create `/etc/systemd/system/llm-proxy.service`:
|
||||
```ini
|
||||
[Unit]
|
||||
Description=LLM Proxy Gateway
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=llmproxy
|
||||
Group=llmproxy
|
||||
WorkingDirectory=/opt/llm-proxy
|
||||
ExecStart=/opt/llm-proxy/llm-proxy
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
||||
# Memory limits
|
||||
MemoryMax=400M
|
||||
MemorySwapMax=100M
|
||||
|
||||
# CPU limits
|
||||
CPUQuota=50%
|
||||
|
||||
# Process limits
|
||||
LimitNOFILE=65536
|
||||
LimitNPROC=512
|
||||
|
||||
Environment="RUST_LOG=info"
|
||||
Environment="LLM_PROXY__DATABASE__MAX_CONNECTIONS=3"
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
### 5. Application-Specific Optimizations
|
||||
|
||||
#### Disable Unused Features
|
||||
- **Multimodal support**: If not using images, disable image processing dependencies
|
||||
- **Dashboard**: The dashboard uses WebSockets and additional memory. Consider disabling if not needed.
|
||||
- **Detailed logging**: Reduce log verbosity in production
|
||||
|
||||
#### Memory Pool Sizes
|
||||
The application uses several memory pools:
|
||||
1. **Database connection pool**: Configured via `max_connections`
|
||||
2. **HTTP client pool**: Reqwest client pool (defaults to reasonable values)
|
||||
3. **Async runtime**: Tokio worker threads
|
||||
|
||||
Reduce Tokio worker threads for low-core systems:
|
||||
```rust
|
||||
// In main.rs, modify tokio runtime creation
|
||||
#[tokio::main(flavor = "current_thread")] // Single-threaded runtime
|
||||
async fn main() -> Result<()> {
|
||||
// Or for multi-threaded with limited threads:
|
||||
// #[tokio::main(worker_threads = 2)]
|
||||
```
|
||||
|
||||
### 6. Monitoring and Profiling
|
||||
|
||||
#### Memory Usage Monitoring
|
||||
```bash
|
||||
# Install heaptrack for memory profiling
|
||||
cargo install heaptrack
|
||||
|
||||
# Profile memory usage
|
||||
heaptrack ./target/release/llm-proxy
|
||||
|
||||
# Monitor with ps
|
||||
ps aux --sort=-%mem | head -10
|
||||
|
||||
# Monitor with top
|
||||
top -p $(pgrep llm-proxy)
|
||||
```
|
||||
|
||||
#### Performance Benchmarks
|
||||
Test with different configurations:
|
||||
```bash
|
||||
# Test with 100 concurrent connections
|
||||
wrk -t4 -c100 -d30s http://localhost:8080/health
|
||||
|
||||
# Test chat completion endpoint
|
||||
ab -n 1000 -c 10 -p test_request.json -T application/json http://localhost:8080/v1/chat/completions
|
||||
```
|
||||
|
||||
### 7. Deployment Checklist for 512MB RAM
|
||||
|
||||
- [ ] Build with release profile: `cargo build --release`
|
||||
- [ ] Configure database with `max_connections = 3`
|
||||
- [ ] Disable unused providers in configuration
|
||||
- [ ] Set appropriate rate limiting limits
|
||||
- [ ] Configure systemd with memory limits
|
||||
- [ ] Set up log rotation to prevent disk space issues
|
||||
- [ ] Monitor memory usage during initial deployment
|
||||
- [ ] Consider using swap space (512MB-1GB) for safety
|
||||
|
||||
### 8. Troubleshooting High Memory Usage
|
||||
|
||||
#### Common Issues and Solutions:
|
||||
|
||||
1. **Database connection leaks**: Ensure connections are properly closed
|
||||
2. **Memory fragmentation**: Use jemalloc or mimalloc as allocator
|
||||
3. **Unbounded queues**: Check WebSocket message queues
|
||||
4. **Cache growth**: Implement cache limits or TTL
|
||||
|
||||
#### Add to Cargo.toml for alternative allocator:
|
||||
```toml
|
||||
[dependencies]
|
||||
mimalloc = { version = "0.1", default-features = false }
|
||||
|
||||
[features]
|
||||
default = ["mimalloc"]
|
||||
```
|
||||
|
||||
#### In main.rs:
|
||||
```rust
|
||||
#[global_allocator]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
```
|
||||
|
||||
### 9. Expected Memory Usage
|
||||
|
||||
| Component | Baseline | With 10 clients | With 100 clients |
|
||||
|-----------|----------|-----------------|------------------|
|
||||
| Base executable | 15MB | 15MB | 15MB |
|
||||
| Database connections | 5MB | 8MB | 15MB |
|
||||
| Rate limiting | 2MB | 5MB | 20MB |
|
||||
| HTTP clients | 3MB | 5MB | 10MB |
|
||||
| **Total** | **25MB** | **33MB** | **60MB** |
|
||||
|
||||
**Note**: These are estimates. Actual usage depends on request volume, payload sizes, and configuration.
|
||||
|
||||
### 10. Further Reading
|
||||
|
||||
- [Tokio performance guide](https://tokio.rs/tokio/topics/performance)
|
||||
- [Rust performance book](https://nnethercote.github.io/perf-book/)
|
||||
- [Linux memory management](https://www.kernel.org/doc/html/latest/admin-guide/mm/)
|
||||
- [SQLite performance tips](https://www.sqlite.org/faq.html#q19)
|
||||
@@ -1,99 +1,202 @@
|
||||
# Project Plan: LLM Proxy Enhancements & Security Upgrade
|
||||
# GopherGate — Remediation Plan
|
||||
|
||||
This document outlines the roadmap for standardizing frontend security, cleaning up the codebase, upgrading session management to HMAC-signed tokens, and extending integration testing.
|
||||
|
||||
## Phase 1: Frontend Security Standardization
|
||||
**Primary Agent:** `frontend-developer`
|
||||
|
||||
- [x] Audit `static/js/pages/users.js` for manual HTML string concatenation.
|
||||
- [x] Replace custom escaping or unescaped injections with `window.api.escapeHtml`.
|
||||
- [x] Verify user list and user detail rendering for XSS vulnerabilities.
|
||||
|
||||
## Phase 2: Codebase Cleanup
|
||||
**Primary Agent:** `backend-developer`
|
||||
|
||||
- [x] Identify and remove unused imports in `src/config/mod.rs`.
|
||||
- [x] Identify and remove unused imports in `src/providers/mod.rs`.
|
||||
- [x] Run `cargo clippy` and `cargo fmt` to ensure adherence to standards.
|
||||
|
||||
## Phase 3: HMAC Architectural Upgrade
|
||||
**Primary Agents:** `fullstack-developer`, `security-auditor`, `backend-developer`
|
||||
|
||||
### 3.1 Design (Security Auditor)
|
||||
- [x] Define Token Structure: `base64(payload).signature`.
|
||||
- Payload: `{ "session_id": "...", "username": "...", "role": "...", "exp": ... }`
|
||||
- [x] Select HMAC algorithm (HMAC-SHA256).
|
||||
- [x] Define environment variable for secret key: `SESSION_SECRET`.
|
||||
|
||||
### 3.2 Implementation (Backend Developer)
|
||||
- [x] Refactor `src/dashboard/sessions.rs`:
|
||||
- Integrate `hmac` and `sha2` crates (or similar).
|
||||
- Update `create_session` to return signed tokens.
|
||||
- Update `validate_session` to verify signature before checking store.
|
||||
- [x] Implement activity-based session refresh:
|
||||
- If session is valid and >50% through its TTL, extend `expires_at` and issue new signed token.
|
||||
|
||||
### 3.3 Integration (Fullstack Developer)
|
||||
- [x] Update dashboard API handlers to handle new token format.
|
||||
- [x] Update frontend session storage/retrieval if necessary.
|
||||
|
||||
## Phase 4: Extended Integration Testing
|
||||
**Primary Agent:** `qa-automation`
|
||||
|
||||
- [ ] Setup test environment with encrypted key storage enabled.
|
||||
- [ ] Implement end-to-end flow:
|
||||
1. Store encrypted provider key via API.
|
||||
2. Authenticate through Proxy.
|
||||
3. Make proxied LLM request (verifying decryption and usage).
|
||||
- [ ] Validate HMAC token expiration and refresh logic in automated tests.
|
||||
|
||||
## Phase 5: Code Quality & Refactoring
|
||||
**Primary Agent:** `fullstack-developer`
|
||||
|
||||
- [x] Refactor dashboard monolith into modular sub-modules (`auth.rs`, `usage.rs`, etc.).
|
||||
- [x] Standardize error handling and remove `unwrap()` in production paths.
|
||||
- [x] Implement system health metrics and backup functionality.
|
||||
> 3 phases, 6 weeks total. Each phase independently shippable.
|
||||
|
||||
---
|
||||
|
||||
# Phase 6: Cache Cost & Provider Audit (ACTIVE)
|
||||
**Primary Agents:** `frontend-developer`, `backend-developer`, `database-optimizer`, `lab-assistant`
|
||||
## Phase 1 — Security & Stability (Weeks 1-2)
|
||||
|
||||
## 6.1 Dashboard UI Updates (@frontend-developer)
|
||||
- [ ] **Update Models Page Modal:** Add input fields for `Cache Read Cost` and `Cache Write Cost` in `static/js/pages/models.js`.
|
||||
- [ ] **API Integration:** Ensure `window.api.put` includes these new cost fields in the request body.
|
||||
- [ ] **Verify Costs Page:** Confirm `static/js/pages/costs.js` displays these rates correctly in the pricing table.
|
||||
**Goal:** Patch auth bypass, data races, debug leaks. No new features.
|
||||
|
||||
## 6.2 Provider Audit & Stream Fixes (@backend-developer)
|
||||
- [ ] **Standard DeepSeek Fix:** Modify `src/providers/deepseek.rs` to stop stripping `stream_options` for `deepseek-chat`.
|
||||
- [ ] **Grok Audit:** Verify if Grok correctly returns usage in streaming; it uses `build_openai_body` and doesn't seem to strip it.
|
||||
- [ ] **Gemini Audit:** Confirm Gemini returns `usage_metadata` reliably in the final chunk.
|
||||
- [ ] **Anthropic Audit:** Check if Anthropic streaming requires `include_usage` or similar flags.
|
||||
### 1.1 Fix auth bypass
|
||||
|
||||
## 6.3 Database & Migration Validation (@database-optimizer)
|
||||
- [ ] **Test Migrations:** Run the server to ensure `ALTER TABLE` logic in `src/database/mod.rs` applies the new columns correctly.
|
||||
- [ ] **Schema Verification:** Verify `model_configs` has `cache_read_cost_per_m` and `cache_write_cost_per_m` columns.
|
||||
- [ ] `middleware/auth.go`: Return 401 instead of `c.Next()` when no auth header on `/v1/*`
|
||||
- [ ] Add `requireAuth` param to `AuthMiddleware` constructor: `AuthMiddleware(db, requireAuth bool)`
|
||||
- [ ] `/v1/*` routes → `requireAuth=true`, leave `/health` unauthed
|
||||
- [ ] Add tests: curl request without token → 401
|
||||
|
||||
## 6.4 Token Estimation Refinement (@lab-assistant)
|
||||
- [ ] **Analyze Heuristic:** Review `chars / 4` in `src/utils/tokens.rs`.
|
||||
- [ ] **Background Precise Recount:** Propose a mechanism for a precise token count (using Tiktoken) after the response is finalized.
|
||||
### 1.2 Fix WebSocket origin
|
||||
|
||||
## Critical Path
|
||||
Migration Validation → UI Fields → Provider Stream Usage Reporting.
|
||||
- [ ] `websocket.go`: Replace `return true` with origin check against configured `Server.Host`
|
||||
- [ ] Config option `websocket.allowed_origins []string` (default: same origin)
|
||||
- [ ] Add `xsrf` check on WS upgrade endpoint if behind proxy
|
||||
|
||||
### 1.3 Strip debug prints
|
||||
|
||||
- [ ] `config.go`: Remove `fmt.Printf("Debug Config:...")` and `fmt.Printf("Debug Env:...")`
|
||||
- [ ] `server.go` `logRequest()`: Remove `fmt.Printf("[DEBUG] Request logged:...")`
|
||||
- [ ] `config.go`: Remove `fmt.Printf("[DEBUG] Final Ollama Config:...")`
|
||||
- [ ] `providers/ollama.go`: Remove `fmt.Printf("[Ollama]...")` debug logs or gate behind `LLM_PROXY_DEBUG=1`
|
||||
- [ ] Replace all `fmt.Printf` with structured logger (slog from stdlib)
|
||||
|
||||
### 1.4 Fix registry data race
|
||||
|
||||
- [ ] `server.go`: Add `sync.RWMutex` around `s.registry`
|
||||
- [ ] `handleListModels()`: Lock read
|
||||
- [ ] `logRequest()`: Lock read
|
||||
- [ ] Background refresh goroutines: Lock write
|
||||
- [ ] Verify with `go run -race`
|
||||
|
||||
### 1.5 Session cleanup
|
||||
|
||||
- [ ] `sessions.go`: Add periodic cleanup goroutine for expired sessions
|
||||
- [ ] Cleanup interval: every 15 minutes
|
||||
- [ ] `RevokeSession`: Return error instead of silent no-op
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Reliability & Observability (Weeks 3-4)
|
||||
|
||||
**Goal:** Error handling, timeouts, logging maturity, concurrency hardening.
|
||||
|
||||
### 2.1 Provider HTTP timeouts
|
||||
|
||||
- [ ] Each provider `New*Provider()`: Set `client.SetTimeout(30 * time.Second)` for non-stream
|
||||
- [ ] Streaming: No timeout, but add `context.Context` cancellation from request
|
||||
- [ ] `circuit_breaker.go`: Configure real thresholds
|
||||
- `MaxRequests: 5`
|
||||
- `Interval: 60 * time.Second`
|
||||
- `Timeout: 30 * time.Second`
|
||||
- `ReadyToTrip: func(counts) bool { return counts.ConsecutiveFailures > 3 }`
|
||||
- [ ] Test: Stop Ollama, hit endpoint → circuit opens on the 4th consecutive failure (`ConsecutiveFailures > 3`) → auto-recovers after 30s
|
||||
|
||||
### 2.2 Structured logging (slog)
|
||||
|
||||
- [ ] Create `internal/logger/logger.go` — `slog.NewJSONHandler`
|
||||
- [ ] Log levels: error/warn/info/debug
|
||||
- [ ] Replace all `fmt.Printf` in: server, providers, config, logging
|
||||
- [ ] `RequestLogger`: Use slog structured fields, remove manual JSON building
|
||||
- [ ] Log channel: increase buffer from 100 to 10000 or use batch insert every 5s
|
||||
|
||||
### 2.3 Stream error propagation
|
||||
|
||||
- [ ] `ChatCompletionStream`: Send error chunks as SSE events, not just `fmt.Printf`
|
||||
- [ ] Format: `data: {"error":"..."}\n\n`
|
||||
- [ ] Client sees full error in stream instead of silent truncation
|
||||
|
||||
### 2.4 Registry fetch retry
|
||||
|
||||
- [ ] `FetchRegistry()`: Add retry with backoff (3 tries, 1s/2s/4s)
|
||||
- [ ] Cache last-known-good registry so startup works offline
|
||||
|
||||
### 2.5 Token truncation safety
|
||||
|
||||
- [ ] `helpers.go`: Deep-copy ToolCall before truncation, don't mutate original
|
||||
- [ ] Same pattern across all providers that sanitize IDs
|
||||
|
||||
### 2.6 RevokeSession error handling
|
||||
|
||||
- [ ] `RevokeSession(token)` → `RevokeSession(token) error`
|
||||
- [ ] Update all callers to handle error
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — Architecture & Maintainability (Weeks 5-6)
|
||||
|
||||
**Goal:** Code splitting, test coverage, billing integrity.
|
||||
|
||||
### 3.1 Split dashboard.go
|
||||
|
||||
- [ ] Create `internal/server/clients.go` — client CRUD handlers
|
||||
- [ ] Create `internal/server/providers.go` — provider handlers
|
||||
- [ ] Create `internal/server/users.go` — user handlers
|
||||
- [ ] Create `internal/server/analytics.go` — usage/analytics handlers
|
||||
- [ ] Create `internal/server/system.go` — health, metrics, logs, backup
|
||||
- [ ] `dashboard.go` shrinks to imports + route wiring only
|
||||
|
||||
### 3.2 Provider routing via config
|
||||
|
||||
- [ ] Replace `strings.Contains` routing table with config-driven model→provider map
|
||||
- [ ] `config.go`: Add `server.model_routing` map (e.g. `"llama-*": "ollama"`)
|
||||
- [ ] Fallback chain: explicit match → prefix match → glob match → default
|
||||
- [ ] Backward-compat: keep old prefix logic as fallback
|
||||
|
||||
### 3.3 Billing integrity
|
||||
|
||||
- [ ] `logging.go`: Add idempotency key to log entries (unique request ID)
|
||||
- [ ] Before deducting balance, check if `request_id` already processed
|
||||
- [ ] `processLog`: Wrap in retry on serialization failure (SQLite busy)
|
||||
- [ ] Credit deduction: move to separate async worker with replay protection
|
||||
|
||||
### 3.4 Add tests
|
||||
|
||||
- [ ] `internal/models/`: Unit tests for `FindModel()`, message conversion
|
||||
- [ ] `internal/providers/helpers_test.go`: Unit tests for `MessagesToOpenAIJSON`, `ParseOpenAIResponse`
|
||||
- [ ] `internal/utils/`: Tests for `Encrypt`/`Decrypt`, `CalculateCost`
|
||||
- [ ] `internal/server/`: Integration test for auth flow (token → chat completion)
|
||||
- [ ] `internal/middleware/`: Test auth bypass fix
|
||||
- [ ] Goal: ≥40% coverage on non-UI packages
|
||||
|
||||
### 3.5 go.mod hygiene
|
||||
|
||||
- [ ] `go mod tidy` (done)
|
||||
- [ ] Add `go vet ./...` to CI/pre-commit hook
|
||||
- [ ] Verify dependency integrity with `go mod verify` (versions are already pinned by `go.mod`/`go.sum`)
|
||||
|
||||
---
|
||||
|
||||
## Dependency Map
|
||||
|
||||
```
|
||||
Phase 1 ──────────────────────────▶ Phase 2 ──────────────────────────▶ Phase 3
|
||||
│ │ │
|
||||
├─ 1.1 Auth bypass ──────────▶ 2.3 Stream errors (depends on auth) │
|
||||
├─ 1.2 WS origin │ │
|
||||
├─ 1.3 Debug prints │ │
|
||||
├─ 1.4 Registry race │ │
|
||||
├─ 1.5 Session cleanup │ │
|
||||
│ ├─ 2.1 HTTP timeouts │
|
||||
│ ├─ 2.2 Structured logging ───────────▶ 3.3 Billing (depends on good logs)
|
||||
│ ├─ 2.4 Registry retry │
|
||||
│ ├─ 2.5 Token truncation │
|
||||
│ ├─ 2.6 RevokeSession errors │
|
||||
│ │
|
||||
│ ├─ 3.1 Split dashboard.go
|
||||
│ ├─ 3.2 Config routing
|
||||
│ ├─ 3.4 Tests
|
||||
│ ├─ 3.5 go.mod hygiene
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Mermaid Gantt
|
||||
|
||||
```mermaid
|
||||
gantt
|
||||
title Phase 6 Timeline
|
||||
title GopherGate Remediation
|
||||
dateFormat YYYY-MM-DD
|
||||
section Frontend
|
||||
Models Page UI :2026-03-06, 1d
|
||||
Costs Table Update:after Models Page UI, 1d
|
||||
section Backend
|
||||
DeepSeek Fix :2026-03-06, 1d
|
||||
Provider Audit (Grok/Gemini):after DeepSeek Fix, 2d
|
||||
section Database
|
||||
Migration Test :2026-03-06, 1d
|
||||
section Optimization
|
||||
Token Heuristic Review :2026-03-06, 1d
|
||||
axisFormat %b %d
|
||||
|
||||
section Phase 1 — Security
|
||||
Auth bypass fix :p1a, 2026-05-04, 2d
|
||||
WS origin lock :p1b, after p1a, 1d
|
||||
Strip debug prints :p1c, 2026-05-04, 2d
|
||||
Registry race fix :p1d, after p1c, 1d
|
||||
Session cleanup :p1e, after p1d, 2d
|
||||
|
||||
section Phase 2 — Reliability
|
||||
HTTP timeouts + CB :p2a, 2026-05-11, 3d
|
||||
Structured logging :p2b, 2026-05-11, 3d
|
||||
Stream error propagation :p2c, after p2a, 1d
|
||||
Registry retry :p2d, after p2b, 1d
|
||||
Token truncation fix :p2e, after p2a, 1d
|
||||
RevokeSession errors :p2f, after p2b, 1d
|
||||
|
||||
section Phase 3 — Architecture
|
||||
Split dashboard.go :p3a, 2026-05-25, 4d
|
||||
Config-driven routing :p3b, 2026-05-25, 3d
|
||||
Billing integrity :p3c, after p3a, 3d
|
||||
Add tests :p3d, 2026-06-01, 5d
|
||||
go.mod hygiene :p3e, after p3d, 1d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Immediate Next Action
|
||||
|
||||
**Start 1.1 — Fix auth bypass:**
|
||||
|
||||
- Edit `middleware/auth.go` → change `c.Next()` to `c.AbortWithStatusJSON(401, ...)` when no header
|
||||
- Add `requireAuth bool` param to the `AuthMiddleware` constructor
|
||||
- Update `server.go` `setupRoutes()` to pass `requireAuth=true` for `/v1/*`
|
||||
- `curl localhost:8080/v1/chat/completions -d '{}'` → 401
|
||||
|
||||
@@ -1,120 +1,140 @@
|
||||
# LLM Proxy Gateway
|
||||
# GopherGate
|
||||
|
||||
A unified, high-performance LLM proxy gateway built in Rust. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
|
||||
## Features
|
||||
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints.
|
||||
- The `/v1/responses` endpoint (OpenAI Responses API) is currently supported for OpenAI models only. Non-OpenAI providers (Gemini, DeepSeek, Moonshot, Grok, Ollama) return a "not supported" response.
|
||||
- **Multi-Provider Support:**
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models.
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner models.
|
||||
- **xAI Grok:** Grok-beta models.
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models, DALL-E 2/3 image generation.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support), Imagen 3 image generation.
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
|
||||
- **Moonshot:** Kimi K2.5 and other Kimi models.
|
||||
- **xAI Grok:** Grok-4 models.
|
||||
- **Ollama:** Local LLMs running on your network.
|
||||
- **Observability & Tracking:**
|
||||
- **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
|
||||
- **Token Counting:** Precise estimation using `tiktoken-rs`.
|
||||
- **Database Logging:** Every request logged to SQLite for historical analysis.
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) with `[DONE]` termination for client compatibility.
|
||||
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
|
||||
- **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens.
|
||||
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
|
||||
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
||||
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
|
||||
- **Multi-User Access Control:**
|
||||
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
||||
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
||||
- **Reliability:**
|
||||
- **Circuit Breaking:** Automatically protects when providers are down.
|
||||
- **Rate Limiting:** Per-client and global rate limits.
|
||||
- **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
|
||||
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
|
||||
- **Rate Limiting:** Per-client and global rate limits (coming soon).
|
||||
|
||||
## Security
|
||||
|
||||
LLM Proxy is designed with security in mind:
|
||||
GopherGate is designed with security in mind:
|
||||
|
||||
- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
||||
- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database.
|
||||
- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience.
|
||||
- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`.
|
||||
- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
||||
- **Encrypted Storage:** Support for encrypted provider API keys in the database.
|
||||
- **Auth Middleware:** Secure client authentication via database-backed API keys.
|
||||
|
||||
**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing.
|
||||
**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Runtime:** Rust with Tokio.
|
||||
- **Web Framework:** Axum.
|
||||
- **Database:** SQLx with SQLite.
|
||||
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations.
|
||||
- **Runtime:** Go 1.22+
|
||||
- **Web Framework:** Gin Gonic
|
||||
- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`)
|
||||
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust (1.80+)
|
||||
- SQLite3
|
||||
- Go (1.22+)
|
||||
- SQLite3 (optional, driver is built-in)
|
||||
- Docker (optional, for containerized deployment)
|
||||
|
||||
### Quick Start
|
||||
|
||||
1. Clone and build:
|
||||
|
||||
```bash
|
||||
git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git
|
||||
cd llm-proxy
|
||||
cargo build --release
|
||||
git clone <repository-url>
|
||||
cd gophergate
|
||||
go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
2. Configure environment:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your API keys:
|
||||
# SESSION_SECRET=... (Generate a strong random secret)
|
||||
# Edit .env and add your configuration:
|
||||
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=AIza...
|
||||
# MOONSHOT_API_KEY=...
|
||||
# For Ollama (optional): Set base URL and enable
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,gemma2,mistral
|
||||
```
|
||||
|
||||
3. Run the proxy:
|
||||
```bash
|
||||
cargo run --release
|
||||
./gophergate
|
||||
```
|
||||
|
||||
The server starts on `http://localhost:8080` by default.
|
||||
The server starts on `http://0.0.0.0:8080` by default.
|
||||
|
||||
### Deployment (Docker)
|
||||
|
||||
A multi-stage `Dockerfile` is provided for efficient deployment:
|
||||
|
||||
```bash
|
||||
# Build the container
|
||||
docker build -t llm-proxy .
|
||||
docker build -t gophergate .
|
||||
|
||||
# Run the container
|
||||
docker run -p 8080:8080 \
|
||||
-e SESSION_SECRET=your-secure-secret \
|
||||
-e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \
|
||||
-v ./data:/app/data \
|
||||
llm-proxy
|
||||
gophergate
|
||||
```
|
||||
|
||||
## Management Dashboard
|
||||
|
||||
Access the dashboard at `http://localhost:8080`. The dashboard architecture has been refactored into modular sub-components for better maintainability:
|
||||
Access the dashboard at `http://localhost:8080`.
|
||||
|
||||
- **Auth (`/api/auth`):** Login, session management, and password changes.
|
||||
- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients (`/api/clients`):** API key management and per-client usage tracking.
|
||||
- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing.
|
||||
- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings.
|
||||
- **Auth:** Login, session management, and status tracking.
|
||||
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients:** API key management and per-client usage tracking.
|
||||
- **Providers:** Provider configuration and status monitoring.
|
||||
- **Users:** Admin-only user management for dashboard access.
|
||||
- **Monitoring:** Live request stream via WebSocket.
|
||||
|
||||
### Default Credentials
|
||||
|
||||
- **Username:** `admin`
|
||||
- **Password:** `admin123`
|
||||
- **Password:** `admin123` (You will be prompted to change this on first login)
|
||||
|
||||
Change the admin password in the dashboard after first login!
|
||||
**Forgot Password?**
|
||||
You can reset the admin password to default by running:
|
||||
|
||||
```bash
|
||||
./gophergate -reset-admin
|
||||
```
|
||||
|
||||
## API Usage
|
||||
|
||||
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
||||
|
||||
Moonshot models are available through the same OpenAI-compatible endpoint. For
|
||||
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
|
||||
your environment.
|
||||
|
||||
Ollama models (like `llama3`, `gemma2`, `mistral`) are also available through the same
|
||||
endpoint after enabling Ollama in configuration and setting the base URL to your
|
||||
Ollama server (default: `http://localhost:11434/v1`).
|
||||
|
||||
### Python
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -129,6 +149,58 @@ response = client.chat.completions.create(
|
||||
)
|
||||
```
|
||||
|
||||
### Responses API
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# OpenAI Responses API (supported for OpenAI models only)
|
||||
response = client.responses.create(
|
||||
model="gpt-4o",
|
||||
input="Explain quantum computing in one paragraph.",
|
||||
instructions="You are a helpful assistant.",
|
||||
temperature=0.7,
|
||||
max_output_tokens=500
|
||||
)
|
||||
print(response.output_text)
|
||||
```
|
||||
|
||||
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only. Requests routed to Gemini, DeepSeek, Moonshot, Grok, or Ollama models return a "not supported" error.
|
||||
|
||||
### Image Generation (DALL-E / Imagen)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# DALL-E 3 (OpenAI)
|
||||
resp = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt="A cute gopher wearing a top hat",
|
||||
n=1,
|
||||
size="1024x1024"
|
||||
)
|
||||
print(resp.data[0].url)
|
||||
|
||||
# Imagen 3 (Gemini) — uses same endpoint
|
||||
resp = client.images.generate(
|
||||
model="imagen-3.0-generate-001",
|
||||
prompt="A gopher coding in Go",
|
||||
n=1,
|
||||
size="1024x1024"
|
||||
)
|
||||
print(resp.data[0].url) # Returns data URI (Gemini returns base64)
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
MIT
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
# LLM Proxy Rust Backend Code Review Report
|
||||
|
||||
## Executive Summary
|
||||
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
|
||||
|
||||
## 1. Core Proxy Logic Review
|
||||
### Strengths
|
||||
- Clean provider abstraction (`Provider` trait).
|
||||
- Streaming support with `AggregatingStream` for token counting.
|
||||
- Model mapping and caching system.
|
||||
|
||||
### Issues Found
|
||||
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
|
||||
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
|
||||
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
|
||||
|
||||
## 2. State Management Review
|
||||
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
|
||||
- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages.
|
||||
- **Database Connection Pool Contention:** SQLite connections shared without differentiation.
|
||||
|
||||
## 3. Error Handling Review
|
||||
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
|
||||
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
|
||||
|
||||
## 4. Rust Idioms and Performance
|
||||
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
|
||||
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
|
||||
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
|
||||
|
||||
## 5. Async Performance
|
||||
- **Blocking Calls:** Token estimation may block async runtime.
|
||||
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
|
||||
- **Unbounded Task Spawn:** For client usage updates under load.
|
||||
|
||||
## 6. Security Considerations
|
||||
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
|
||||
- **No Request Size Limits:** Vulnerable to memory exhaustion.
|
||||
|
||||
## 7. Testability
|
||||
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
|
||||
- **Missing Integration Tests:** No E2E tests for streaming.
|
||||
|
||||
## 8. Summary of Critical Actions
|
||||
### High Priority
|
||||
1. Replace custom token bucket with `governor` crate.
|
||||
2. Fix provider lookup O(n) scaling issue.
|
||||
3. Implement proper error recovery with retries.
|
||||
4. Add request size limits and timeout configurations.
|
||||
|
||||
### Medium Priority
|
||||
1. Reduce string cloning in hot paths.
|
||||
2. Implement cache invalidation for model configs.
|
||||
3. Add connection pooling separation.
|
||||
4. Improve streaming memory efficiency.
|
||||
|
||||
---
|
||||
*Review conducted by: Senior Principal Engineer (Code Reviewer)*
|
||||
@@ -1,58 +0,0 @@
|
||||
# LLM Proxy Security Audit Report
|
||||
|
||||
## Executive Summary
|
||||
A comprehensive security audit of the `llm-proxy` repository was conducted. The audit identified **2 critical vulnerabilities**, **3 high-risk issues**, **4 medium-risk issues**, and **3 low-risk issues**. The most severe findings are Cross-Site Scripting (XSS) in the dashboard interface and insecure storage of provider API keys in the database.
|
||||
|
||||
## Detailed Findings
|
||||
|
||||
### Critical Risk Vulnerabilities
|
||||
#### **CRITICAL-01: Cross-Site Scripting (XSS) in Dashboard Interface**
|
||||
- **Location**: `static/js/pages/clients.js` (multiple locations).
|
||||
- **Description**: User-controlled data (e.g., `client.id`) inserted directly into HTML or `onclick` handlers without escaping.
|
||||
- **Impact**: Arbitrary JavaScript execution in admin context, potentially stealing session tokens.
|
||||
|
||||
#### **CRITICAL-02: Insecure API Key Storage in Database**
|
||||
- **Location**: `src/database/mod.rs`, `src/providers/mod.rs`, `src/dashboard/providers.rs`.
|
||||
- **Description**: Provider API keys are stored in **plaintext** in the SQLite database.
|
||||
- **Impact**: Compromised database file exposes all provider API keys.
|
||||
|
||||
### High Risk Vulnerabilities
|
||||
#### **HIGH-01: Missing Input Validation and Size Limits**
|
||||
- **Location**: `src/server/mod.rs`, `src/models/mod.rs`.
|
||||
- **Impact**: Denial of Service via large payloads.
|
||||
|
||||
#### **HIGH-02: Sensitive Data Logging Without Encryption**
|
||||
- **Location**: `src/database/mod.rs`, `src/logging/mod.rs`.
|
||||
- **Description**: Full request and response bodies stored in `llm_requests` table without encryption or redaction.
|
||||
|
||||
#### **HIGH-03: Weak Default Credentials and Password Policy**
|
||||
- **Description**: The default admin password is 'admin', and the minimum password length is only 4 characters.
|
||||
|
||||
### Medium Risk Vulnerabilities
|
||||
#### **MEDIUM-01: Missing CSRF Protection**
|
||||
- No CSRF tokens or SameSite cookie attributes for state-changing dashboard endpoints.
|
||||
|
||||
#### **MEDIUM-02: Insecure Session Management**
|
||||
- Session tokens stored in localStorage without HttpOnly flag.
|
||||
- Tokens use simple `session-{uuid}` format.
|
||||
|
||||
#### **MEDIUM-03: Error Information Leakage**
|
||||
- Internal error details exposed to clients in some cases.
|
||||
|
||||
#### **MEDIUM-04: Outdated Dependencies**
|
||||
- Outdated versions of `chrono`, `tokio`, and `reqwest`.
|
||||
|
||||
### Low Risk Vulnerabilities
|
||||
- Missing security headers (CSP, HSTS, X-Frame-Options).
|
||||
- Insufficient rate limiting on dashboard authentication.
|
||||
- No database encryption at rest.
|
||||
|
||||
## Recommendations
|
||||
### Immediate Actions
|
||||
1. **Fix XSS Vulnerabilities:** Implement proper HTML escaping for all user-controlled data.
|
||||
2. **Secure API Key Storage:** Encrypt API keys in database using a library like `ring`.
|
||||
3. **Implement Input Validation:** Add maximum payload size limits (e.g., 10MB).
|
||||
4. **Improve Data Protection:** Add option to disable request/response body logging.
|
||||
|
||||
---
|
||||
*Report generated by Security Auditor Agent on March 6, 2026*
|
||||
@@ -0,0 +1,70 @@
|
||||
# Migration TODO List
|
||||
|
||||
## Completed Tasks
|
||||
- [x] Initial Go project setup
|
||||
- [x] Database schema & migrations (hardcoded in `db.go`)
|
||||
- [x] Configuration loader (Viper)
|
||||
- [x] Auth Middleware (scoped to `/v1`)
|
||||
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok, Ollama)
|
||||
- [x] Streaming Support (SSE & Gemini custom streaming)
|
||||
- [x] Archive Rust files to `rust` branch
|
||||
- [x] Clean root and set Go version as `main`
|
||||
- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible)
|
||||
- [x] Enhanced `server.go` for robust request conversion
|
||||
- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers)
|
||||
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
|
||||
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
|
||||
- [x] Asynchronous Request Logging to SQLite
|
||||
- [x] Update documentation (README, deployment, architecture)
|
||||
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
|
||||
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
|
||||
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
|
||||
- [x] Fixed dashboard 404s and 500s
|
||||
|
||||
## Planned Resolutions (High Priority)
|
||||
|
||||
### Security Fixes
|
||||
- [x] **Critical:** Fix `AuthMiddleware` to reject invalid tokens instead of falling back to insecure prefix derivation.
|
||||
|
||||
### Feature Parity Checklist (High Priority)
|
||||
|
||||
### OpenAI Provider
|
||||
- [x] Tool Calling
|
||||
- [x] Multimodal (Images) support
|
||||
- [x] Accurate usage parsing (cached & reasoning tokens)
|
||||
### Feature Parity: OpenAI Provider Enhancements
|
||||
- [x] **Reasoning Content (CoT) Support (`o1`/`o3`):**
|
||||
- [x] Infrastructure verified. `reasoning_content` is mapped in request/response structures.
|
||||
- [x] **Support for `/v1/responses` API:**
|
||||
- [x] Implemented new route in `internal/server/server.go`.
|
||||
|
||||
### Gemini Provider
|
||||
- [x] Tool Calling (mapping to Gemini format)
|
||||
- [x] Multimodal (Images) support
|
||||
- [x] Reasoning/Thought support
|
||||
- [x] Handle Tool Response role in unified format
|
||||
|
||||
### DeepSeek Provider
|
||||
- [x] Reasoning Content (CoT) support
|
||||
- [x] Parameter sanitization for `deepseek-reasoner`
|
||||
- [x] Tool Calling support
|
||||
- [x] Accurate usage parsing (cache hits & reasoning)
|
||||
|
||||
### Grok Provider
|
||||
- [x] Tool Calling support
|
||||
- [x] Multimodal support
|
||||
- [x] Accurate usage parsing (via OpenAI helper)
|
||||
|
||||
### Ollama Provider
|
||||
- [x] OpenAI-compatible API integration
|
||||
- [x] Streaming support
|
||||
- [x] Model pattern detection for routing
|
||||
- [x] Zero cost calculation (local/free models)
|
||||
|
||||
## Infrastructure & Middleware
|
||||
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
|
||||
- [x] Implement Circuit Breaker (`github.com/sony/gobreaker`)
|
||||
|
||||
## Verification
|
||||
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
|
||||
- [ ] Integration tests with live LLM APIs
|
||||
@@ -1 +0,0 @@
|
||||
too-many-arguments-threshold = 8
|
||||
@@ -0,0 +1,55 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/server"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
resetAdmin := flag.Bool("reset-admin", false, "Reset admin password to admin123")
|
||||
flag.Parse()
|
||||
|
||||
// Load environment variables
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Println("No .env file found")
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load configuration: %v", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
database, err := db.Init(cfg.Database.Path)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
|
||||
if *resetAdmin {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
|
||||
_, err = database.Exec("UPDATE users SET password_hash = ?, must_change_password = 1 WHERE username = 'admin'", string(hash))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to reset admin password: %v", err)
|
||||
}
|
||||
log.Println("Admin password has been reset to 'admin123'")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Initialize server
|
||||
s := server.NewServer(cfg, database)
|
||||
|
||||
// Run server
|
||||
log.Printf("Starting GopherGate on %s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||
if err := s.Run(); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// MyNullTime is a diagnostic scan target: instead of converting the scanned
// value it keeps it untouched together with the name of its dynamic Go type.
type MyNullTime struct {
	Time interface{} // the value exactly as it was passed to Scan
	Type string      // that value's Go type, formatted with %T
}

// Scan stores value and its %T type name on the receiver. It always
// returns nil.
func (n *MyNullTime) Scan(value interface{}) error {
	n.Time, n.Type = value, fmt.Sprintf("%T", value)
	return nil
}
|
||||
|
||||
func main() {
|
||||
db, err := sqlx.Connect("sqlite", "/home/newkirk/Documents/projects/web_projects/gophergate/data/backups/llm_proxy.db.20260303T205057Z")
|
||||
if err != nil {
|
||||
fmt.Println("connect err:", err)
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test 1: Direct column scan type
|
||||
var d MyNullTime
|
||||
db.Get(&d, "SELECT last_used_at FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
|
||||
fmt.Printf("direct SELECT: GoType=%s value=%v\n", d.Type, d.Time)
|
||||
|
||||
// Test 2: MAX aggregate scan type
|
||||
var m MyNullTime
|
||||
db.Get(&m, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", "sk-opencode")
|
||||
fmt.Printf("MAX SELECT: GoType=%s value=%v\n", m.Type, m.Time)
|
||||
|
||||
// Test 3: peek at the raw driver types
|
||||
row := db.QueryRow("SELECT last_used_at, MAX(last_used_at) FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
|
||||
var a, b interface{}
|
||||
row.Scan(&a, &b)
|
||||
fmt.Printf("\nRaw Scan:\n")
|
||||
fmt.Printf(" last_used_at: type=%T val=%v\n", a, a)
|
||||
fmt.Printf(" MAX(last_used_at): type=%T val=%v\n", b, b)
|
||||
}
|
||||
Binary file not shown.
BIN
Binary file not shown.
@@ -1,667 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# LLM Proxy Gateway Deployment Script
|
||||
# This script automates the deployment of the LLM Proxy Gateway on a Linux server
|
||||
|
||||
# Abort on the first failing command (-e) and on any unset variable (-u).
set -eu

# Deployment settings: service identity, source repository and target paths.
APP_NAME="llm-proxy"
APP_USER="llmproxy"
APP_GROUP="llmproxy"
GIT_REPO="ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git"
INSTALL_DIR="/opt/${APP_NAME}"
CONFIG_DIR="/etc/${APP_NAME}"
DATA_DIR="/var/lib/${APP_NAME}"
LOG_DIR="/var/log/${APP_NAME}"
SERVICE_FILE="/etc/systemd/system/${APP_NAME}.service"
ENV_FILE="${CONFIG_DIR}/.env"

# ANSI escape sequences used by the log helpers (NC resets the color).
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
|
||||
|
||||
# Logging functions
|
||||
log_info() {
    # Print "$1" on stdout behind an [INFO] tag wrapped in $GREEN ... $NC.
    local message="$1"
    echo -e "${GREEN}[INFO]${NC} ${message}"
}
|
||||
|
||||
log_warn() {
    # Print "$1" on stdout behind a [WARN] tag wrapped in $YELLOW ... $NC.
    local message="$1"
    echo -e "${YELLOW}[WARN]${NC} ${message}"
}
|
||||
|
||||
log_error() {
    # Print "$1" behind an [ERROR] tag wrapped in $RED ... $NC.
    # Errors go to stderr so they stay visible when stdout is redirected
    # or captured.
    echo -e "${RED}[ERROR]${NC} $1" >&2
}
|
||||
|
||||
# Check if running as root
|
||||
check_root() {
    # Continue only when the effective UID is 0; otherwise report and exit 1.
    if (( EUID == 0 )); then
        return 0
    fi

    log_error "This script must be run as root"
    exit 1
}
|
||||
|
||||
# Install system dependencies
|
||||
install_dependencies() {
    # Install the build toolchain, OpenSSL headers, SQLite, curl and git with
    # the first package manager found (apt-get, yum, dnf, pacman, in that
    # order). When none is found, only a warning is printed.
    log_info "Installing system dependencies..."

    if command -v apt-get >/dev/null 2>&1; then
        # Debian/Ubuntu
        apt-get update
        apt-get install -y build-essential pkg-config libssl-dev sqlite3 curl git
    elif command -v yum >/dev/null 2>&1; then
        # RHEL/CentOS
        yum groupinstall -y "Development Tools"
        yum install -y openssl-devel sqlite curl git
    elif command -v dnf >/dev/null 2>&1; then
        # Fedora
        dnf groupinstall -y "Development Tools"
        dnf install -y openssl-devel sqlite curl git
    elif command -v pacman >/dev/null 2>&1; then
        # Arch Linux
        pacman -Syu --noconfirm base-devel openssl sqlite curl git
    else
        log_warn "Could not detect package manager. Please install dependencies manually."
    fi
}
|
||||
|
||||
# Install Rust if not present
|
||||
install_rust() {
    # Make sure a Rust toolchain is available: run the rustup installer when
    # rustc is not on PATH, then print the rustc and cargo versions.
    log_info "Checking for Rust installation..."

    if command -v rustc >/dev/null 2>&1; then
        log_info "Rust is already installed"
    else
        log_info "Installing Rust..."
        curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
        # Load the cargo environment file into the current shell.
        . "$HOME/.cargo/env"
    fi

    # Verify installation
    rustc --version
    cargo --version
}
|
||||
|
||||
# Create system user and directories
|
||||
setup_directories() {
    # Create the service group and user (if missing) and the install, config,
    # data and log directories, owned by $APP_USER:$APP_GROUP with mode 750.
    log_info "Creating system user and directories..."

    # Create the group explicitly. Whether useradd also creates a group named
    # after the user depends on USERGROUPS_ENAB in /etc/login.defs, and the
    # chown calls below fail if "$APP_GROUP" does not exist.
    if ! getent group "$APP_GROUP" &>/dev/null; then
        groupadd -r "$APP_GROUP"
    fi

    # Create the user if it doesn't exist, with the service group as its
    # primary group.
    if ! id "$APP_USER" &>/dev/null; then
        # Arch uses /usr/bin/nologin, Debian/Ubuntu use /usr/sbin/nologin
        NOLOGIN=$(command -v nologin 2>/dev/null || echo "/usr/bin/nologin")
        useradd -r -g "$APP_GROUP" -s "$NOLOGIN" -M "$APP_USER"
    fi

    # Create each directory, then set its ownership and permissions.
    local dir
    for dir in "$INSTALL_DIR" "$CONFIG_DIR" "$DATA_DIR" "$LOG_DIR"; do
        mkdir -p "$dir"
        chown -R "$APP_USER:$APP_GROUP" "$dir"
        chmod 750 "$dir"
    done
}
|
||||
|
||||
# Build the application
|
||||
build_application() {
    # Fetch the sources into $INSTALL_DIR (pull when a checkout already
    # exists, clone otherwise), build the release binary and confirm that
    # target/release/$APP_NAME was produced.
    log_info "Building the application..."

    if [[ -d "$INSTALL_DIR/.git" ]]; then
        log_info "Updating repository..."
        cd "$INSTALL_DIR"
        git pull
    else
        log_info "Cloning repository..."
        git clone "$GIT_REPO" "$INSTALL_DIR"
    fi

    # Build in release mode
    cd "$INSTALL_DIR"
    log_info "Building release binary..."
    cargo build --release

    # Verify build
    if [[ ! -f "target/release/$APP_NAME" ]]; then
        log_error "Build failed"
        exit 1
    fi
    log_info "Build successful"
}
|
||||
|
||||
# Create configuration files
|
||||
create_configuration() {
    # Write the .env template and config.toml into $CONFIG_DIR, then hand both
    # files to $APP_USER:$APP_GROUP with mode 640.
    log_info "Creating configuration files..."

    local config_file="$CONFIG_DIR/config.toml"

    # .env template; the text contains nothing to expand, so the delimiter is quoted.
    cat > "$ENV_FILE" << 'EOF'
# LLM Proxy Gateway Environment Variables
# Add your API keys here

# OpenAI API Key
# OPENAI_API_KEY=sk-your-key-here

# Google Gemini API Key
# GEMINI_API_KEY=AIza-your-key-here

# DeepSeek API Key
# DEEPSEEK_API_KEY=sk-your-key-here

# xAI Grok API Key
# GROK_API_KEY=gk-your-key-here

# Authentication tokens (comma-separated)
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2,token3
EOF

    # config.toml; the delimiter stays unquoted so $DATA_DIR is expanded.
    cat > "$config_file" << EOF
# LLM Proxy Gateway Configuration

[server]
port = 8080
host = "0.0.0.0"
# auth_tokens = ["token1", "token2", "token3"] # Uncomment to enable authentication

[database]
path = "$DATA_DIR/llm_proxy.db"
max_connections = 5

[providers.openai]
enabled = true
api_key_env = "OPENAI_API_KEY"
base_url = "https://api.openai.com/v1"
default_model = "gpt-4o"

[providers.gemini]
enabled = true
api_key_env = "GEMINI_API_KEY"
base_url = "https://generativelanguage.googleapis.com/v1"
default_model = "gemini-2.0-flash"

[providers.deepseek]
enabled = true
api_key_env = "DEEPSEEK_API_KEY"
base_url = "https://api.deepseek.com"
default_model = "deepseek-reasoner"

[providers.grok]
enabled = false # Disabled by default until API is researched
api_key_env = "GROK_API_KEY"
base_url = "https://api.x.ai/v1"
default_model = "grok-beta"

[model_mapping]
"gpt-*" = "openai"
"gemini-*" = "gemini"
"deepseek-*" = "deepseek"
"grok-*" = "grok"

[pricing]
openai = { input = 0.01, output = 0.03 }
gemini = { input = 0.0005, output = 0.0015 }
deepseek = { input = 0.00014, output = 0.00028 }
grok = { input = 0.001, output = 0.003 }
EOF

    # Set ownership and permissions on both files
    chown "$APP_USER:$APP_GROUP" "$ENV_FILE" "$config_file"
    chmod 640 "$ENV_FILE" "$config_file"
}
|
||||
|
||||
# Create systemd service
|
||||
create_systemd_service() {
    # Generate the systemd unit at $SERVICE_FILE and reload systemd.
    # The heredoc delimiter is unquoted so the $APP_*/$*_DIR/$ENV_FILE
    # variables are expanded into the unit text.
    log_info "Creating systemd service..."

    cat << EOF > "$SERVICE_FILE"
[Unit]
Description=LLM Proxy Gateway
Documentation=https://git.dustin.coffee/hobokenchicken/llm-proxy
After=network.target
Wants=network.target

[Service]
Type=simple
User=$APP_USER
Group=$APP_GROUP
WorkingDirectory=$INSTALL_DIR
EnvironmentFile=$ENV_FILE
Environment="RUST_LOG=info"
Environment="LLM_PROXY__CONFIG_PATH=$CONFIG_DIR/config.toml"
Environment="LLM_PROXY__DATABASE__PATH=$DATA_DIR/llm_proxy.db"
ExecStart=$INSTALL_DIR/target/release/$APP_NAME
Restart=on-failure
RestartSec=5

# Security hardening
NoNewPrivileges=true
PrivateTmp=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=$DATA_DIR $LOG_DIR

# Resource limits (adjust based on your server)
MemoryMax=400M
MemorySwapMax=100M
CPUQuota=50%
LimitNOFILE=65536

# Logging
StandardOutput=journal
StandardError=journal
SyslogIdentifier=$APP_NAME

[Install]
WantedBy=multi-user.target
EOF

    # Reload systemd
    systemctl daemon-reload
}
|
||||
|
||||
# Setup nginx reverse proxy (optional)
#
# Writes an nginx site for the gateway to /etc/nginx/sites-available/$APP_NAME,
# symlinks it into sites-enabled and runs `nginx -t`. Returns early (with a
# warning) when nginx is not installed.
#
# The generated site is a TEMPLATE: the domain and certificate paths are
# placeholders, so `nginx -t` is expected to fail until they are edited. A
# failing test is therefore reported as a warning instead of being left to
# abort the script or to be followed by a misleading success message.
setup_nginx_proxy() {
    if ! command -v nginx &> /dev/null; then
        log_warn "nginx not installed. Skipping reverse proxy setup."
        return
    fi

    log_info "Setting up nginx reverse proxy..."

    # Unquoted heredoc: $APP_NAME-style variables expand now, while nginx's own
    # variables are written as \$name so they reach the file literally.
    cat > "/etc/nginx/sites-available/$APP_NAME" << EOF
server {
    listen 80;
    server_name your-domain.com; # Change to your domain

    # Redirect to HTTPS (recommended)
    return 301 https://\$server_name\$request_uri;
}

server {
    listen 443 ssl http2;
    server_name your-domain.com; # Change to your domain

    # SSL certificates (adjust paths)
    ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;

    # SSL configuration
    ssl_protocols TLSv1.2 TLSv1.3;
    ssl_ciphers ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384;
    ssl_prefer_server_ciphers off;

    # Proxy to LLM Proxy Gateway
    location / {
        proxy_pass http://127.0.0.1:8080;
        proxy_http_version 1.1;
        proxy_set_header Upgrade \$http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host \$host;
        proxy_set_header X-Real-IP \$remote_addr;
        proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto \$scheme;

        # SSE/streaming support: deliver tokens as they are produced
        proxy_buffering off;

        # Timeouts (read/send are long so streamed completions are not cut off)
        proxy_connect_timeout 60s;
        proxy_send_timeout 7200s;
        proxy_read_timeout 7200s;
    }

    # Health check endpoint
    location /health {
        proxy_pass http://127.0.0.1:8080/health;
        access_log off;
    }

    # Dashboard
    location /dashboard {
        proxy_pass http://127.0.0.1:8080/dashboard;
    }
}
EOF

    # Enable site
    ln -sf "/etc/nginx/sites-available/$APP_NAME" "/etc/nginx/sites-enabled/"

    # Test nginx configuration; a failure is reported, not fatal, because the
    # placeholder certificate paths above do not exist on a fresh machine.
    if nginx -t; then
        log_info "nginx configuration created. Please update the domain and SSL certificate paths."
    else
        log_warn "nginx configuration written, but 'nginx -t' failed. Update the domain and SSL certificate paths in /etc/nginx/sites-available/$APP_NAME, then re-run 'nginx -t'."
    fi
}
|
||||
|
||||
# Open SSH, HTTP and HTTPS in whichever supported firewall front-end is
# installed. Both branches are independent: if ufw and firewalld are both
# present, both are configured.
setup_firewall() {
    local tcp_port fw_service

    log_info "Configuring firewall..."

    # ufw (Ubuntu): 22 = SSH, 80 = HTTP, 443 = HTTPS
    if command -v ufw &> /dev/null; then
        for tcp_port in 22 80 443; do
            ufw allow "${tcp_port}/tcp"
        done
        ufw --force enable
        log_info "UFW firewall configured"
    fi

    # firewalld (RHEL/CentOS)
    if command -v firewall-cmd &> /dev/null; then
        for fw_service in ssh http https; do
            firewall-cmd --permanent --add-service="${fw_service}"
        done
        firewall-cmd --reload
        log_info "Firewalld configured"
    fi
}
|
||||
|
||||
# Run the release binary once as $APP_USER so it can create its database.
#
# The invocation is best-effort: its output is discarded and a non-zero exit
# status is ignored.
# NOTE(review): the binary is started with --help; whether that code path
# actually creates $DATA_DIR/llm_proxy.db is not visible here — confirm.
initialize_database() {
    log_info "Initializing database..."

    local gateway_bin="$INSTALL_DIR/target/release/$APP_NAME"
    sudo -u "$APP_USER" "$gateway_bin" --help &> /dev/null || true

    log_info "Database initialized at $DATA_DIR/llm_proxy.db"
}
|
||||
|
||||
# Enable the unit at boot, start it now, then print its status after a short
# pause so the first status report reflects the freshly started process.
start_service() {
    local unit="$APP_NAME"

    log_info "Starting $APP_NAME service..."

    systemctl enable "$unit"
    systemctl start "$unit"

    # Give the process a moment to come up before reporting status.
    sleep 2
    systemctl status "$unit" --no-pager
}
|
||||
|
||||
# Verify installation
#
# Checks, in order:
#   1. the systemd unit is active          (fatal: prints last 20 journal lines)
#   2. GET /health on localhost:8080 contains "OK"   (fatal after retries)
#   3. GET /dashboard on localhost:8080 returns 200  (warning only)
#
# The health probe is retried a few times with a per-request timeout: the
# service may still be binding its port right after start, and a single
# un-timed curl could either fail spuriously or hang the deployment.
verify_installation() {
    local health_url="http://localhost:8080/health"
    local dashboard_url="http://localhost:8080/dashboard"
    local max_attempts=5
    local attempt
    local health_ok=false

    log_info "Verifying installation..."

    # Check if service is running
    if systemctl is-active --quiet "$APP_NAME"; then
        log_info "Service is running"
    else
        log_error "Service is not running"
        journalctl -u "$APP_NAME" -n 20 --no-pager
        exit 1
    fi

    # Test health endpoint (bounded retries, 1s apart)
    for (( attempt = 1; attempt <= max_attempts; attempt++ )); do
        if curl -s --max-time 5 "$health_url" | grep -q "OK"; then
            health_ok=true
            break
        fi
        if (( attempt < max_attempts )); then
            sleep 1
        fi
    done

    if [[ "$health_ok" == true ]]; then
        log_info "Health check passed"
    else
        log_error "Health check failed"
        # Show recent logs to explain the failure; never mask the exit below.
        journalctl -u "$APP_NAME" -n 20 --no-pager || true
        exit 1
    fi

    # Test dashboard
    if curl -s --max-time 5 -o /dev/null -w "%{http_code}" "$dashboard_url" | grep -q "200"; then
        log_info "Dashboard is accessible"
    else
        log_warn "Dashboard may not be accessible (this is normal if not configured)"
    fi

    log_info "Installation verified successfully!"
}
|
||||
|
||||
# Print the post-install checklist (keys, auth, nginx, smoke test, dashboard)
# plus a summary of useful commands and file locations.
#
# The heredoc delimiter is unquoted, so path/name variables are expanded and
# each "\\" is emitted as a single backslash (line continuation in the curl
# example).
# NOTE(review): GREEN/YELLOW/NC are defined outside this block; if they hold
# backslash escape text rather than real escape bytes, `cat` will print them
# literally — confirm against their definitions.
print_next_steps() {
    cat <<EOF

${GREEN}=== LLM Proxy Gateway Installation Complete ===${NC}

${YELLOW}Next steps:${NC}

1. ${GREEN}Configure API keys${NC}
Edit: $ENV_FILE
Add your API keys for the providers you want to use

2. ${GREEN}Configure authentication${NC}
Edit: $CONFIG_DIR/config.toml
Uncomment and set auth_tokens for client authentication

3. ${GREEN}Configure nginx${NC}
Edit: /etc/nginx/sites-available/$APP_NAME
Update domain name and SSL certificate paths

4. ${GREEN}Test the API${NC}
curl -X POST http://localhost:8080/v1/chat/completions \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer your-token" \\
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello!"}]
}'

5. ${GREEN}Access the dashboard${NC}
Open: http://your-server-ip:8080/dashboard
Or: https://your-domain.com/dashboard (if nginx configured)

${YELLOW}Useful commands:${NC}
systemctl status $APP_NAME # Check service status
journalctl -u $APP_NAME -f # View logs
systemctl restart $APP_NAME # Restart service

${YELLOW}Configuration files:${NC}
Service: $SERVICE_FILE
Config: $CONFIG_DIR/config.toml
Environment: $ENV_FILE
Database: $DATA_DIR/llm_proxy.db
Logs: $LOG_DIR/

${GREEN}For more information, see:${NC}
https://git.dustin.coffee/hobokenchicken/llm-proxy
$INSTALL_DIR/README.md
$INSTALL_DIR/deployment.md

EOF
}
|
||||
|
||||
# Full installation: run every deployment step in order, from the root check
# through verification and the closing instructions.
deploy() {
    local step
    local -a pipeline=(
        check_root
        install_dependencies
        install_rust
        setup_directories
        build_application
        create_configuration
        create_systemd_service
        initialize_database
        start_service
        verify_installation
        print_next_steps
    )

    log_info "Starting LLM Proxy Gateway deployment..."

    for step in "${pipeline[@]}"; do
        "$step"
    done

    # Optional steps (uncomment if needed)
    # setup_nginx_proxy
    # setup_firewall

    log_info "Deployment completed successfully!"
}
|
||||
|
||||
# Update an existing installation in place: pull, rebuild, then restart.
#
# The service keeps running on the old binary while the pull and build happen;
# it is only restarted once a new binary has been built and found on disk.
# Exits 1 if the build fails, the binary is missing, or the unit is not active
# two seconds after the restart.
update() {
    log_info "Updating LLM Proxy Gateway..."

    check_root

    # Fetch the latest sources (service untouched).
    cd "$INSTALL_DIR"
    log_info "Pulling latest changes..."
    git pull

    # Rebuild (service untouched).
    log_info "Building release binary (service still running)..."
    cargo build --release || {
        log_error "Build failed — service was NOT interrupted. Fix the error and try again."
        exit 1
    }

    # The build must have produced the binary the unit executes.
    [[ -f "target/release/$APP_NAME" ]] || {
        log_error "Binary not found after build — aborting."
        exit 1
    }

    # Swap over to the new binary.
    log_info "Build succeeded. Restarting service..."
    systemctl restart "$APP_NAME"

    sleep 2
    if ! systemctl is-active --quiet "$APP_NAME"; then
        log_error "Service failed to start after update. Check logs:"
        journalctl -u "$APP_NAME" -n 20 --no-pager
        exit 1
    fi

    log_info "Update completed successfully!"
    systemctl status "$APP_NAME" --no-pager
}
|
||||
|
||||
# Uninstall function
#
# Stops and disables the unit, removes the unit file, the install directory
# and the config directory. The data and log directories are deliberately
# left in place. Finally asks whether the service user/group should be
# removed as well.
#
# The confirmation prompt treats end-of-input (e.g. stdin is closed or not a
# terminal) as "No": `read` returns non-zero in that case, which must not
# terminate the uninstall before it reports completion.
uninstall() {
    local answer=""

    log_info "Uninstalling LLM Proxy Gateway..."

    check_root

    # Stop and disable service (best-effort: the unit may already be gone)
    systemctl stop "$APP_NAME" 2>/dev/null || true
    systemctl disable "$APP_NAME" 2>/dev/null || true
    rm -f "$SERVICE_FILE"
    systemctl daemon-reload

    # Remove application files
    rm -rf "$INSTALL_DIR"
    rm -rf "$CONFIG_DIR"

    # Data and logs are preserved on purpose; tell the operator how to remove them.
    log_warn "Data directory $DATA_DIR and logs $LOG_DIR have been preserved"
    log_warn "Remove manually if desired:"
    log_warn " rm -rf $DATA_DIR $LOG_DIR"

    # Remove user (optional); no answer / EOF means "No".
    read -p "Remove user $APP_USER? [y/N]: " -n 1 -r answer || answer=""
    echo
    if [[ $answer =~ ^[Yy]$ ]]; then
        userdel "$APP_USER" 2>/dev/null || true
        groupdel "$APP_GROUP" 2>/dev/null || true
    fi

    log_info "Uninstallation completed!"
}
|
||||
|
||||
# Print the command-line help: available sub-commands and example
# invocations, using the script's own name ($0) in the text.
usage() {
    cat <<EOF
LLM Proxy Gateway Deployment Script

Usage: $0 [command]

Commands:
deploy - Install and configure LLM Proxy Gateway
update - Pull latest changes, rebuild, and restart
status - Show service status and health check
logs - Tail the service logs (Ctrl+C to stop)
uninstall - Remove LLM Proxy Gateway
help - Show this help message

Examples:
$0 deploy # Full installation
$0 update # Update to latest version
$0 status # Check if service is healthy
$0 logs # Follow live logs

EOF
}
|
||||
|
||||
# Report on a deployed instance: systemd unit status, a probe of the local
# /health endpoint, and (when the install directory is a git checkout) the
# commit that is currently installed.
status() {
    echo ""
    log_info "Service status:"
    if ! systemctl status "$APP_NAME" --no-pager 2>/dev/null; then
        log_warn "Service not found"
    fi
    echo ""

    # Health check: -f makes curl fail on HTTP error responses.
    if ! curl -sf http://localhost:8080/health &>/dev/null; then
        log_warn "Health check: FAILED (service may not be running or port 8080 not responding)"
    else
        log_info "Health check: OK"
    fi

    # Show current git commit
    if [[ -d "$INSTALL_DIR/.git" ]]; then
        echo ""
        log_info "Installed version:"
        git -C "$INSTALL_DIR" log -1 --format=" %h %s (%cr)" 2>/dev/null
    fi
}
|
||||
|
||||
# Follow the unit's journal until interrupted.
logs() {
    local unit="$APP_NAME"

    log_info "Tailing $APP_NAME logs (Ctrl+C to stop)..."
    journalctl -u "$unit" -f
}
|
||||
|
||||
# Dispatch on the first command-line argument.
# Each sub-command is implemented by a function of the same name; anything
# unrecognised (including no argument at all) prints the help and exits 1.
case "${1:-}" in
    deploy|update|status|logs|uninstall)
        "$1"
        ;;
    help|--help|-h)
        usage
        ;;
    *)
        usage
        exit 1
        ;;
esac
|
||||
+33
-303
@@ -1,322 +1,52 @@
|
||||
# LLM Proxy Gateway - Deployment Guide
|
||||
# Deployment Guide (Go)
|
||||
|
||||
## Overview
|
||||
A unified LLM proxy gateway supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and admin dashboard.
|
||||
This guide covers deploying the Go-based GopherGate.
|
||||
|
||||
## System Requirements
|
||||
- **CPU**: 2 cores minimum
|
||||
- **RAM**: 512MB minimum (1GB recommended)
|
||||
- **Storage**: 10GB minimum
|
||||
- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian)
|
||||
- **Runtime**: Rust 1.70+ with Cargo
|
||||
## Environment Setup
|
||||
|
||||
## Deployment Options
|
||||
|
||||
### Option 1: Docker (Recommended)
|
||||
```dockerfile
|
||||
FROM rust:1.70-alpine as builder
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
FROM alpine:latest
|
||||
RUN apk add --no-cache libgcc
|
||||
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
|
||||
COPY --from=builder /app/static /app/static
|
||||
WORKDIR /app
|
||||
EXPOSE 8080
|
||||
CMD ["llm-proxy"]
|
||||
```
|
||||
|
||||
### Option 2: Systemd Service (Bare Metal/LXC)
|
||||
```ini
|
||||
# /etc/systemd/system/llm-proxy.service
|
||||
[Unit]
|
||||
Description=LLM Proxy Gateway
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=llmproxy
|
||||
Group=llmproxy
|
||||
WorkingDirectory=/opt/llm-proxy
|
||||
ExecStart=/opt/llm-proxy/llm-proxy
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
Environment="RUST_LOG=info"
|
||||
Environment="LLM_PROXY__SERVER__PORT=8080"
|
||||
Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456"
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
### Option 3: LXC Container (Proxmox)
|
||||
1. Create Alpine Linux LXC container
|
||||
2. Install Rust: `apk add rust cargo`
|
||||
3. Copy application files
|
||||
4. Build: `cargo build --release`
|
||||
5. Run: `./target/release/llm-proxy`
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```bash
|
||||
# Required API Keys
|
||||
OPENAI_API_KEY=sk-...
|
||||
GEMINI_API_KEY=AIza...
|
||||
DEEPSEEK_API_KEY=sk-...
|
||||
GROK_API_KEY=gk-... # Optional
|
||||
|
||||
# Server Configuration (with LLM_PROXY__ prefix)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
LLM_PROXY__SERVER__HOST=0.0.0.0
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
|
||||
|
||||
# Database Configuration
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
||||
|
||||
# Provider Configuration
|
||||
LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__GROK__ENABLED=false
|
||||
```
|
||||
|
||||
### Configuration File (config.toml)
|
||||
Create `config.toml` in the application directory:
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
auth_tokens = ["sk-test-123", "sk-test-456"]
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
max_connections = 10
|
||||
|
||||
[providers.openai]
|
||||
enabled = true
|
||||
base_url = "https://api.openai.com/v1"
|
||||
default_model = "gpt-4o"
|
||||
|
||||
[providers.gemini]
|
||||
enabled = true
|
||||
base_url = "https://generativelanguage.googleapis.com/v1"
|
||||
default_model = "gemini-2.0-flash"
|
||||
|
||||
[providers.deepseek]
|
||||
enabled = true
|
||||
base_url = "https://api.deepseek.com"
|
||||
default_model = "deepseek-reasoner"
|
||||
|
||||
[providers.grok]
|
||||
enabled = false
|
||||
base_url = "https://api.x.ai/v1"
|
||||
default_model = "grok-beta"
|
||||
```
|
||||
|
||||
## Nginx Reverse Proxy Configuration
|
||||
|
||||
**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name llm-proxy.yourdomain.com;
|
||||
|
||||
location / {
|
||||
proxy_pass http://localhost:8080;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# SSE/Streaming support
|
||||
proxy_buffering off;
|
||||
chunked_transfer_encoding on;
|
||||
proxy_set_header Connection '';
|
||||
|
||||
# Timeouts for long-running streams
|
||||
proxy_connect_timeout 7200s;
|
||||
proxy_read_timeout 7200s;
|
||||
proxy_send_timeout 7200s;
|
||||
|
||||
# Disable gzip for streaming
|
||||
gzip off;
|
||||
|
||||
# Headers
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
|
||||
# SSL configuration (recommended)
|
||||
listen 443 ssl http2;
|
||||
ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
|
||||
}
|
||||
```
|
||||
|
||||
### NGINX Proxy Manager
|
||||
|
||||
If using NGINX Proxy Manager, add this to **Advanced Settings**:
|
||||
|
||||
```nginx
|
||||
proxy_buffering off;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection '';
|
||||
chunked_transfer_encoding on;
|
||||
proxy_connect_timeout 7200s;
|
||||
proxy_read_timeout 7200s;
|
||||
proxy_send_timeout 7200s;
|
||||
gzip off;
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### 1. Authentication
|
||||
- Use strong Bearer tokens
|
||||
- Rotate tokens regularly
|
||||
- Consider implementing JWT for production
|
||||
|
||||
### 2. Rate Limiting
|
||||
- Implement per-client rate limiting
|
||||
- Consider using the `governor` crate for advanced rate limiting
|
||||
|
||||
### 3. Network Security
|
||||
- Run behind reverse proxy (nginx)
|
||||
- Enable HTTPS
|
||||
- Restrict access by IP if needed
|
||||
- Use firewall rules
|
||||
|
||||
### 4. Data Security
|
||||
- Database encryption (SQLCipher for SQLite)
|
||||
- Secure API key storage
|
||||
- Regular backups
|
||||
|
||||
## Monitoring & Maintenance
|
||||
|
||||
### Logging
|
||||
- Application logs: `RUST_LOG=info` (or `debug` for troubleshooting)
|
||||
- Access logs via nginx
|
||||
- Database logs for audit trail
|
||||
|
||||
### Health Checks
|
||||
```bash
|
||||
# Health endpoint
|
||||
curl http://localhost:8080/health
|
||||
|
||||
# Database check
|
||||
sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
|
||||
```
|
||||
|
||||
### Backup Strategy
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# backup.sh
|
||||
BACKUP_DIR="/backups/llm-proxy"
|
||||
DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Backup database
|
||||
sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
|
||||
|
||||
# Backup configuration
|
||||
cp config.toml $BACKUP_DIR/config_$DATE.toml
|
||||
|
||||
# Rotate old backups (keep 30 days)
|
||||
find $BACKUP_DIR -name "*.db" -mtime +30 -delete
|
||||
find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Database Optimization
|
||||
```sql
|
||||
-- Run these SQL commands periodically
|
||||
VACUUM;
|
||||
ANALYZE;
|
||||
```
|
||||
|
||||
### Memory Management
|
||||
- Monitor memory usage with `htop` or `ps aux`
|
||||
- Adjust `max_connections` based on load
|
||||
- Consider connection pooling for high traffic
|
||||
|
||||
### Scaling
|
||||
1. **Vertical Scaling**: Increase container resources
|
||||
2. **Horizontal Scaling**: Deploy multiple instances behind load balancer
|
||||
3. **Database**: Migrate to PostgreSQL for high-volume usage
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Port already in use**
|
||||
1. **Mandatory Configuration:**
|
||||
Create a `.env` file from the example:
|
||||
```bash
|
||||
netstat -tulpn | grep :8080
|
||||
kill <PID> # or change port in config
|
||||
cp .env.example .env
|
||||
```
|
||||
Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
|
||||
|
||||
2. **Database permissions**
|
||||
```bash
|
||||
chown -R llmproxy:llmproxy /opt/llm-proxy/data
|
||||
chmod 600 /opt/llm-proxy/data/llm_proxy.db
|
||||
```
|
||||
2. **Data Directory:**
|
||||
The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
|
||||
|
||||
3. **API key errors**
|
||||
- Verify environment variables are set
|
||||
- Check provider status (dashboard)
|
||||
- Test connectivity: `curl https://api.openai.com/v1/models`
|
||||
## Binary Deployment
|
||||
|
||||
4. **High memory usage**
|
||||
- Check for memory leaks
|
||||
- Reduce `max_connections`
|
||||
- Implement connection timeouts
|
||||
|
||||
### Debug Mode
|
||||
### 1. Build
|
||||
```bash
|
||||
# Run with debug logging
|
||||
RUST_LOG=debug ./llm-proxy
|
||||
|
||||
# Check system logs
|
||||
journalctl -u llm-proxy -f
|
||||
go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
## Integration
|
||||
|
||||
### Open-WebUI Compatibility
|
||||
The proxy provides OpenAI-compatible API, so configure Open-WebUI:
|
||||
```
|
||||
API Base URL: http://your-proxy-address:8080
|
||||
API Key: sk-test-123 (or your configured token)
|
||||
### 2. Run
|
||||
```bash
|
||||
./gophergate
|
||||
```
|
||||
|
||||
### Custom Clients
|
||||
```python
|
||||
import openai
|
||||
## Docker Deployment
|
||||
|
||||
client = openai.OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="sk-test-123"
|
||||
)
|
||||
The project includes a multi-stage `Dockerfile` for minimal image size.
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
### 1. Build Image
|
||||
```bash
|
||||
docker build -t gophergate .
|
||||
```
|
||||
|
||||
## Updates & Upgrades
|
||||
### 2. Run Container
|
||||
```bash
|
||||
docker run -d \
|
||||
--name gophergate \
|
||||
-p 8080:8080 \
|
||||
-v $(pwd)/data:/app/data \
|
||||
--env-file .env \
|
||||
gophergate
|
||||
```
|
||||
|
||||
1. **Backup** current configuration and database
|
||||
2. **Stop** the service: `systemctl stop llm-proxy`
|
||||
3. **Update** code: `git pull` or copy new binaries
|
||||
4. **Migrate** database if needed (check migrations/)
|
||||
5. **Restart**: `systemctl start llm-proxy`
|
||||
6. **Verify**: Check logs and test endpoints
|
||||
## Production Considerations
|
||||
|
||||
## Support
|
||||
- Check logs in `/var/log/llm-proxy/`
|
||||
- Monitor dashboard at `http://your-server:8080`
|
||||
- Review database metrics in dashboard
|
||||
- Enable debug logging for troubleshooting
|
||||
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
|
||||
- **Backups:** Regularly back up the `data/llm_proxy.db` file.
|
||||
- **Monitoring:** Monitor the `/health` endpoint for system status.
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
module gophergate
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-resty/resty/v2 v2.17.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
github.com/sony/gobreaker v1.0.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.47.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
modernc.org/libc v1.70.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
@@ -0,0 +1,209 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
|
||||
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
|
||||
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
@@ -0,0 +1,222 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Providers ProviderConfig `mapstructure:"providers"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
KeyBytes []byte
|
||||
}
|
||||
|
||||
// ServerConfig groups the settings decoded from the "server" section.
type ServerConfig struct {
	// Port is the listen port (Load defaults it to 8080).
	Port int `mapstructure:"port"`
	// Host is the listen address (Load defaults it to "0.0.0.0").
	Host string `mapstructure:"host"`
	// AuthTokens defaults to an empty list in Load.
	// NOTE(review): presumably the static tokens accepted by the server;
	// confirm against the auth middleware.
	AuthTokens []string `mapstructure:"auth_tokens"`
	// WSAllowedOrigin has no default in Load.
	// NOTE(review): presumably the Origin allowed for WebSocket upgrades;
	// confirm against the WebSocket handler.
	WSAllowedOrigin string `mapstructure:"ws_allowed_origin"`
}
|
||||
|
||||
// DatabaseConfig groups the settings decoded from the "database" section.
type DatabaseConfig struct {
	// Path is the database file location (Load defaults it to
	// "./data/llm_proxy.db").
	Path string `mapstructure:"path"`
	// MaxConnections defaults to 10 in Load.
	// NOTE(review): where this limit is applied is not visible here; verify
	// against the database setup code.
	MaxConnections int `mapstructure:"max_connections"`
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
|
||||
Moonshot MoonshotConfig `mapstructure:"moonshot"`
|
||||
Grok GrokConfig `mapstructure:"grok"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
}
|
||||
|
||||
// OpenAIConfig holds the settings decoded from "providers.openai".
type OpenAIConfig struct {
	// APIKeyEnv is the NAME of the environment variable that holds the API
	// key (default "OPENAI_API_KEY"); GetAPIKey reads that variable.
	APIKeyEnv string `mapstructure:"api_key_env"`
	// BaseURL is the provider endpoint (default "https://api.openai.com/v1").
	BaseURL string `mapstructure:"base_url"`
	// DefaultModel defaults to "gpt-4o".
	DefaultModel string `mapstructure:"default_model"`
	// Enabled defaults to true.
	Enabled bool `mapstructure:"enabled"`
}
|
||||
|
||||
// GeminiConfig holds the settings decoded from "providers.gemini".
type GeminiConfig struct {
	// APIKeyEnv is the NAME of the environment variable that holds the API
	// key (default "GEMINI_API_KEY"); GetAPIKey reads that variable.
	APIKeyEnv string `mapstructure:"api_key_env"`
	// BaseURL is the provider endpoint (default
	// "https://generativelanguage.googleapis.com/v1").
	BaseURL string `mapstructure:"base_url"`
	// DefaultModel defaults to "gemini-2.0-flash".
	DefaultModel string `mapstructure:"default_model"`
	// Enabled defaults to true.
	Enabled bool `mapstructure:"enabled"`
}
|
||||
|
||||
// DeepSeekConfig holds the settings decoded from "providers.deepseek".
type DeepSeekConfig struct {
	// APIKeyEnv is the NAME of the environment variable that holds the API
	// key (default "DEEPSEEK_API_KEY"); GetAPIKey reads that variable.
	APIKeyEnv string `mapstructure:"api_key_env"`
	// BaseURL is the provider endpoint (default "https://api.deepseek.com").
	BaseURL string `mapstructure:"base_url"`
	// DefaultModel defaults to "deepseek-reasoner".
	DefaultModel string `mapstructure:"default_model"`
	// Enabled defaults to true.
	Enabled bool `mapstructure:"enabled"`
}
|
||||
|
||||
// MoonshotConfig holds the settings decoded from "providers.moonshot".
type MoonshotConfig struct {
	// APIKeyEnv is the NAME of the environment variable that holds the API
	// key (default "MOONSHOT_API_KEY"); GetAPIKey reads that variable.
	APIKeyEnv string `mapstructure:"api_key_env"`
	// BaseURL is the provider endpoint (default "https://api.moonshot.ai/v1").
	BaseURL string `mapstructure:"base_url"`
	// DefaultModel defaults to "kimi-k2.5".
	DefaultModel string `mapstructure:"default_model"`
	// Enabled defaults to true.
	Enabled bool `mapstructure:"enabled"`
}
|
||||
|
||||
// GrokConfig holds the settings decoded from "providers.grok".
type GrokConfig struct {
	// APIKeyEnv is the NAME of the environment variable that holds the API
	// key (default "GROK_API_KEY"); GetAPIKey reads that variable.
	APIKeyEnv string `mapstructure:"api_key_env"`
	// BaseURL is the provider endpoint (default "https://api.x.ai/v1").
	BaseURL string `mapstructure:"base_url"`
	// DefaultModel defaults to "grok-4-1-fast-non-reasoning".
	DefaultModel string `mapstructure:"default_model"`
	// Enabled defaults to true.
	Enabled bool `mapstructure:"enabled"`
}
|
||||
|
||||
// OllamaConfig holds the settings decoded from "providers.ollama". Unlike the
// other providers it carries no API key variable; GetAPIKey returns an empty
// key for "ollama".
type OllamaConfig struct {
	// BaseURL is the provider endpoint (default "http://localhost:11434/v1").
	BaseURL string `mapstructure:"base_url"`
	// Enabled defaults to false.
	Enabled bool `mapstructure:"enabled"`
	// DefaultModel has no default registered in Load.
	DefaultModel string `mapstructure:"default_model"`
	// Models defaults to an empty list; Load can also fill it from the
	// comma-separated LLM_PROXY__PROVIDERS__OLLAMA__MODELS variable.
	Models []string `mapstructure:"models"`
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
// Defaults
|
||||
v.SetDefault("server.port", 8080)
|
||||
v.SetDefault("server.host", "0.0.0.0")
|
||||
v.SetDefault("server.auth_tokens", []string{})
|
||||
v.SetDefault("database.path", "./data/llm_proxy.db")
|
||||
v.SetDefault("database.max_connections", 10)
|
||||
|
||||
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
|
||||
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
|
||||
v.SetDefault("providers.openai.default_model", "gpt-4o")
|
||||
v.SetDefault("providers.openai.enabled", true)
|
||||
|
||||
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
|
||||
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
|
||||
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
|
||||
v.SetDefault("providers.gemini.enabled", true)
|
||||
|
||||
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
|
||||
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
|
||||
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
|
||||
v.SetDefault("providers.deepseek.enabled", true)
|
||||
|
||||
v.SetDefault("providers.moonshot.api_key_env", "MOONSHOT_API_KEY")
|
||||
v.SetDefault("providers.moonshot.base_url", "https://api.moonshot.ai/v1")
|
||||
v.SetDefault("providers.moonshot.default_model", "kimi-k2.5")
|
||||
v.SetDefault("providers.moonshot.enabled", true)
|
||||
|
||||
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
|
||||
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
|
||||
v.SetDefault("providers.grok.default_model", "grok-4-1-fast-non-reasoning")
|
||||
v.SetDefault("providers.grok.enabled", true)
|
||||
|
||||
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
|
||||
v.SetDefault("providers.ollama.enabled", false)
|
||||
v.SetDefault("providers.ollama.models", []string{})
|
||||
|
||||
// Environment variables
|
||||
v.SetEnvPrefix("LLM_PROXY")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
// Explicitly bind keys that might use double underscores in .env
|
||||
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
||||
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
||||
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
||||
v.BindEnv("providers.ollama.enabled", "LLM_PROXY__PROVIDERS__OLLAMA__ENABLED")
|
||||
v.BindEnv("providers.ollama.base_url", "LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL")
|
||||
v.BindEnv("providers.ollama.models", "LLM_PROXY__PROVIDERS__OLLAMA__MODELS")
|
||||
|
||||
// Config file
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("toml")
|
||||
v.AddConfigPath(".")
|
||||
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
|
||||
v.SetConfigFile(envPath)
|
||||
}
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
|
||||
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
|
||||
fmt.Sscanf(port, "%d", &cfg.Server.Port)
|
||||
|
||||
}
|
||||
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
|
||||
cfg.Server.Host = host
|
||||
|
||||
}
|
||||
|
||||
// Ollama overrides
|
||||
if enabled := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__ENABLED"); enabled != "" {
|
||||
cfg.Providers.Ollama.Enabled = enabled == "true"
|
||||
}
|
||||
if baseURL := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL"); baseURL != "" {
|
||||
cfg.Providers.Ollama.BaseURL = baseURL
|
||||
}
|
||||
if models := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__MODELS"); models != "" {
|
||||
cfg.Providers.Ollama.Models = strings.Split(models, ",")
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if cfg.EncryptionKey == "" {
|
||||
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
||||
}
|
||||
|
||||
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
|
||||
}
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
|
||||
}
|
||||
cfg.KeyBytes = keyBytes
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) GetAPIKey(provider string) (string, error) {
|
||||
var envVar string
|
||||
switch provider {
|
||||
case "openai":
|
||||
envVar = c.Providers.OpenAI.APIKeyEnv
|
||||
case "gemini":
|
||||
envVar = c.Providers.Gemini.APIKeyEnv
|
||||
case "deepseek":
|
||||
envVar = c.Providers.DeepSeek.APIKeyEnv
|
||||
case "moonshot":
|
||||
envVar = c.Providers.Moonshot.APIKeyEnv
|
||||
case "grok":
|
||||
envVar = c.Providers.Grok.APIKeyEnv
|
||||
case "ollama":
|
||||
// Ollama doesn't require an API key
|
||||
return "", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
val := os.Getenv(envVar)
|
||||
if val == "" {
|
||||
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
|
||||
}
|
||||
return strings.TrimSpace(val), nil
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
*sqlx.DB
|
||||
}
|
||||
|
||||
func Init(path string) (*DB, error) {
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(path)
|
||||
if dir != "." && dir != "/" {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to SQLite
|
||||
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
|
||||
db, err := sqlx.Connect("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
instance := &DB{db}
|
||||
|
||||
// Run migrations
|
||||
if err := instance.RunMigrations(); err != nil {
|
||||
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (db *DB) RunMigrations() error {
|
||||
// Tables creation
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT UNIQUE NOT NULL,
|
||||
name TEXT,
|
||||
description TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
rate_limit_per_minute INTEGER DEFAULT 60,
|
||||
total_requests INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
total_cost REAL DEFAULT 0.0
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS llm_requests (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
client_id TEXT,
|
||||
provider TEXT,
|
||||
model TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER,
|
||||
cost REAL,
|
||||
has_images BOOLEAN DEFAULT FALSE,
|
||||
status TEXT DEFAULT 'success',
|
||||
error_message TEXT,
|
||||
duration_ms INTEGER,
|
||||
request_body TEXT,
|
||||
response_body TEXT,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS provider_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
display_name TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
base_url TEXT,
|
||||
api_key TEXT,
|
||||
credit_balance REAL DEFAULT 0.0,
|
||||
low_credit_threshold REAL DEFAULT 5.0,
|
||||
billing_mode TEXT,
|
||||
api_key_encrypted BOOLEAN DEFAULT FALSE,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS model_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
prompt_cost_per_m REAL,
|
||||
completion_cost_per_m REAL,
|
||||
cache_read_cost_per_m REAL,
|
||||
cache_write_cost_per_m REAL,
|
||||
mapping TEXT,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
role TEXT DEFAULT 'admin',
|
||||
must_change_password BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS client_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
name TEXT DEFAULT 'default',
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at DATETIME,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS model_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
strategy TEXT NOT NULL DEFAULT 'heuristic',
|
||||
selector_model TEXT,
|
||||
targets TEXT NOT NULL DEFAULT '[]',
|
||||
complexity_threshold INTEGER,
|
||||
heuristic_rules TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := db.Exec(q); err != nil {
|
||||
return fmt.Errorf("migration failed for query [%s]: %w", q, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add indices
|
||||
indices := []string{
|
||||
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
|
||||
}
|
||||
|
||||
for _, idx := range indices {
|
||||
if _, err := db.Exec(idx); err != nil {
|
||||
return fmt.Errorf("failed to create index [%s]: %w", idx, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Default admin user
|
||||
var count int
|
||||
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
|
||||
return fmt.Errorf("failed to count users: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash default password: %w", err)
|
||||
}
|
||||
_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert default admin: %w", err)
|
||||
}
|
||||
log.Println("Created default admin user with password 'admin123' (must change on first login)")
|
||||
}
|
||||
|
||||
// Default client
|
||||
_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
|
||||
VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert default client: %w", err)
|
||||
}
|
||||
|
||||
// Seed default model groups
|
||||
defaultGroups := []struct {
|
||||
id, strategy, targets string
|
||||
}{
|
||||
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`},
|
||||
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`},
|
||||
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`},
|
||||
}
|
||||
for _, g := range defaultGroups {
|
||||
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`,
|
||||
g.id, g.strategy, g.targets)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data models for DB tables
|
||||
|
||||
// Client is the row model for the clients table. Every field is bound to the
// column named in its `db` tag; Name and Description are pointers and may be
// nil.
type Client struct {
	ID                 int       `db:"id"`
	ClientID           string    `db:"client_id"`
	Name               *string   `db:"name"`
	Description        *string   `db:"description"`
	CreatedAt          time.Time `db:"created_at"`
	UpdatedAt          time.Time `db:"updated_at"`
	IsActive           bool      `db:"is_active"`
	RateLimitPerMinute int       `db:"rate_limit_per_minute"`
	TotalRequests      int       `db:"total_requests"`
	TotalTokens        int       `db:"total_tokens"`
	TotalCost          float64   `db:"total_cost"`
}
|
||||
|
||||
// LLMRequest is the row model for the llm_requests table. Pointer-typed
// fields may be nil; the reasoning and cache token counters are plain ints.
type LLMRequest struct {
	ID               int       `db:"id"`
	Timestamp        time.Time `db:"timestamp"`
	ClientID         *string   `db:"client_id"`
	Provider         *string   `db:"provider"`
	Model            *string   `db:"model"`
	PromptTokens     *int      `db:"prompt_tokens"`
	CompletionTokens *int      `db:"completion_tokens"`
	ReasoningTokens  int       `db:"reasoning_tokens"`
	TotalTokens      *int      `db:"total_tokens"`
	Cost             *float64  `db:"cost"`
	HasImages        bool      `db:"has_images"`
	Status           string    `db:"status"`
	ErrorMessage     *string   `db:"error_message"`
	DurationMS       *int      `db:"duration_ms"`
	RequestBody      *string   `db:"request_body"`
	ResponseBody     *string   `db:"response_body"`
	CacheReadTokens  int       `db:"cache_read_tokens"`
	CacheWriteTokens int       `db:"cache_write_tokens"`
}
|
||||
|
||||
// ProviderConfig is the row model for the provider_configs table. BaseURL,
// APIKey and BillingMode are pointers and may be nil. APIKeyEncrypted maps to
// the api_key_encrypted column.
type ProviderConfig struct {
	ID                 string    `db:"id"`
	DisplayName        string    `db:"display_name"`
	Enabled            bool      `db:"enabled"`
	BaseURL            *string   `db:"base_url"`
	APIKey             *string   `db:"api_key"`
	CreditBalance      float64   `db:"credit_balance"`
	LowCreditThreshold float64   `db:"low_credit_threshold"`
	BillingMode        *string   `db:"billing_mode"`
	APIKeyEncrypted    bool      `db:"api_key_encrypted"`
	UpdatedAt          time.Time `db:"updated_at"`
}
|
||||
|
||||
// ModelConfig is the row model for the model_configs table. ProviderID holds
// the provider_id column, which the schema declares as a foreign key to
// provider_configs(id). The per-million cost columns are pointers and may be
// nil.
type ModelConfig struct {
	ID                 string    `db:"id"`
	ProviderID         string    `db:"provider_id"`
	DisplayName        *string   `db:"display_name"`
	Enabled            bool      `db:"enabled"`
	PromptCostPerM     *float64  `db:"prompt_cost_per_m"`
	CompletionCostPerM *float64  `db:"completion_cost_per_m"`
	CacheReadCostPerM  *float64  `db:"cache_read_cost_per_m"`
	CacheWriteCostPerM *float64  `db:"cache_write_cost_per_m"`
	Mapping            *string   `db:"mapping"`
	UpdatedAt          time.Time `db:"updated_at"`
}
|
||||
|
||||
// User is the row model for the users table and is also JSON-serialisable.
// PasswordHash carries the `json:"-"` tag, so it is never included in JSON
// output.
type User struct {
	ID                 int       `db:"id" json:"id"`
	Username           string    `db:"username" json:"username"`
	PasswordHash       string    `db:"password_hash" json:"-"`
	DisplayName        *string   `db:"display_name" json:"display_name"`
	Role               string    `db:"role" json:"role"`
	MustChangePassword bool      `db:"must_change_password" json:"must_change_password"`
	CreatedAt          time.Time `db:"created_at" json:"created_at"`
}
|
||||
|
||||
// ClientToken is the row model for the client_tokens table. LastUsedAt is a
// pointer and stays nil when the last_used_at column is NULL.
type ClientToken struct {
	ID         int        `db:"id"`
	ClientID   string     `db:"client_id"`
	Token      string     `db:"token"`
	Name       string     `db:"name"`
	IsActive   bool       `db:"is_active"`
	CreatedAt  time.Time  `db:"created_at"`
	LastUsedAt *time.Time `db:"last_used_at"`
}
|
||||
|
||||
// ModelGroup is the row model for the model_groups table and is also
// JSON-serialisable. Targets holds a JSON array encoded as a string.
type ModelGroup struct {
	ID                  string    `db:"id" json:"id"`
	Strategy            string    `db:"strategy" json:"strategy"`
	SelectorModel       *string   `db:"selector_model" json:"selector_model"`
	Targets             string    `db:"targets" json:"targets"` // JSON array
	ComplexityThreshold *int      `db:"complexity_threshold" json:"complexity_threshold"`
	HeuristicRules      *string   `db:"heuristic_rules" json:"heuristic_rules"`
	CreatedAt           time.Time `db:"created_at" json:"created_at"`
	UpdatedAt           time.Time `db:"updated_at" json:"updated_at"`
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var level = slog.LevelInfo
|
||||
|
||||
func init() {
|
||||
env := os.Getenv("LLM_PROXY_LOG_LEVEL")
|
||||
switch strings.ToLower(env) {
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
case "warn":
|
||||
level = slog.LevelWarn
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
}
|
||||
|
||||
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
})
|
||||
slog.SetDefault(slog.New(h))
|
||||
}
|
||||
|
||||
// Warn logs msg at warning level on the default slog logger; args are passed
// through unchanged as slog attributes.
func Warn(msg string, args ...any) {
	slog.Default().Warn(msg, args...)
}
|
||||
|
||||
// Error logs msg at error level on the default slog logger; args are passed
// through unchanged as slog attributes.
func Error(msg string, args ...any) {
	slog.Default().Error(msg, args...)
}
|
||||
|
||||
// Debug logs msg at debug level on the default slog logger; args are passed
// through unchanged as slog attributes.
func Debug(msg string, args ...any) {
	slog.Default().Debug(msg, args...)
}
|
||||
|
||||
// Ctx wraps slog with context.
|
||||
func Ctx(ctx context.Context) *slog.Logger {
|
||||
return slog.Default()
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token == authHeader { // No "Bearer " prefix
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to resolve client from database
|
||||
var clientID string
|
||||
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
|
||||
|
||||
if err == nil {
|
||||
c.Set("auth", models.AuthInfo{
|
||||
Token: token,
|
||||
ClientID: clientID,
|
||||
})
|
||||
c.Next()
|
||||
} else {
|
||||
log.Printf("Token not found or inactive in DB: %s", token)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// OpenAI-compatible Request/Response Structs
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *uint32 `json:"top_k,omitempty"`
|
||||
N *uint32 `json:"n,omitempty"`
|
||||
Stop json.RawMessage `json:"stop,omitempty"` // Can be string or array of strings
|
||||
MaxTokens *uint32 `json:"max_tokens,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system", "user", "assistant", "tool"
|
||||
Content interface{} `json:"content"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageUrl *ImageUrl `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// ImageUrl references an image by URL, with an optional detail setting.
type ImageUrl struct {
	URL    string  `json:"url"`
	Detail *string `json:"detail,omitempty"`
}
|
||||
|
||||
// Tool-Calling Types
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function FunctionDef `json:"function"`
|
||||
}
|
||||
|
||||
type FunctionDef struct {
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall names the function being invoked and carries its arguments as
// a string.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
|
||||
|
||||
type ToolCallDelta struct {
|
||||
Index uint32 `json:"index"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Type *string `json:"type,omitempty"`
|
||||
Function *FunctionCallDelta `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionCallDelta is the streaming counterpart of FunctionCall; both
// fields are optional and omitted from JSON when nil.
type FunctionCallDelta struct {
	Name      *string `json:"name,omitempty"`
	Arguments *string `json:"arguments,omitempty"`
}
|
||||
|
||||
// OpenAI-compatible Response Structs
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type ChatChoice struct {
|
||||
Index uint32 `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// Usage reports token counts for a completion. The reasoning and cache
// counters are optional pointers and are omitted from JSON when nil.
type Usage struct {
	PromptTokens     uint32  `json:"prompt_tokens"`
	CompletionTokens uint32  `json:"completion_tokens"`
	TotalTokens      uint32  `json:"total_tokens"`
	ReasoningTokens  *uint32 `json:"reasoning_tokens,omitempty"`
	CacheReadTokens  *uint32 `json:"cache_read_tokens,omitempty"`
	CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"`
}
|
||||
|
||||
// Streaming Response Structs
|
||||
|
||||
type ChatCompletionStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatStreamChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ChatStreamChoice struct {
|
||||
Index uint32 `json:"index"`
|
||||
Delta ChatStreamDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatStreamDelta struct {
|
||||
Role *string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// StreamUsage holds token counters for a stream. Unlike Usage, every counter
// is a plain value and is always serialised.
type StreamUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	ReasoningTokens  uint32 `json:"reasoning_tokens"`
	CacheReadTokens  uint32 `json:"cache_read_tokens"`
	CacheWriteTokens uint32 `json:"cache_write_tokens"`
}
|
||||
|
||||
// Unified Request Format (for internal use)
|
||||
|
||||
type UnifiedRequest struct {
|
||||
ClientID string
|
||||
Model string
|
||||
Messages []UnifiedMessage
|
||||
Temperature *float64
|
||||
TopP *float64
|
||||
TopK *uint32
|
||||
N *uint32
|
||||
Stop []string
|
||||
MaxTokens *uint32
|
||||
PresencePenalty *float64
|
||||
FrequencyPenalty *float64
|
||||
Stream bool
|
||||
HasImages bool
|
||||
Tools []Tool
|
||||
ToolChoice json.RawMessage
|
||||
}
|
||||
|
||||
type UnifiedMessage struct {
|
||||
Role string
|
||||
Content []UnifiedContentPart
|
||||
ReasoningContent *string
|
||||
ToolCalls []ToolCall
|
||||
Name *string
|
||||
ToolCallID *string
|
||||
}
|
||||
|
||||
type UnifiedContentPart struct {
|
||||
Type string
|
||||
Text string
|
||||
Image *ImageInput
|
||||
}
|
||||
|
||||
// ImageInput identifies an image either by inline base64 data or by URL,
// together with an optional MIME type.
type ImageInput struct {
	Base64   string `json:"base64,omitempty"`
	URL      string `json:"url,omitempty"`
	MimeType string `json:"mime_type,omitempty"`
}
|
||||
|
||||
func (i *ImageInput) ToBase64() (string, string, error) {
|
||||
if i.Base64 != "" {
|
||||
return i.Base64, i.MimeType, nil
|
||||
}
|
||||
|
||||
if i.URL != "" {
|
||||
client := resty.New()
|
||||
resp, err := client.R().Get(i.URL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to fetch image: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
mimeType := resp.Header().Get("Content-Type")
|
||||
if mimeType == "" {
|
||||
mimeType = "image/jpeg"
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(resp.Body())
|
||||
return encoded, mimeType, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("empty image input")
|
||||
}
|
||||
|
||||
// Image Generation (DALL-E, Imagen)
|
||||
|
||||
// ImageGenerationRequest is the request body for image generation. Model and
// Prompt are always serialised; the remaining fields are optional pointers
// omitted from JSON when nil.
type ImageGenerationRequest struct {
	Model          string  `json:"model"`
	Prompt         string  `json:"prompt"`
	N              *uint32 `json:"n,omitempty"`
	Quality        *string `json:"quality,omitempty"`
	ResponseFormat *string `json:"response_format,omitempty"`
	Size           *string `json:"size,omitempty"`
	Style          *string `json:"style,omitempty"`
	User           *string `json:"user,omitempty"`
}
|
||||
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData is one generated image, delivered as a URL and/or inline base64
// JSON, with an optional revised prompt. Empty fields are omitted from JSON.
type ImageData struct {
	URL           string `json:"url,omitempty"`
	B64JSON       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}
|
||||
|
||||
// AuthInfo is the authentication result kept in the request context: the
// presented token and the client ID it resolved to.
type AuthInfo struct {
	Token    string
	ClientID string
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package models
|
||||
|
||||
import "strings"
|
||||
|
||||
type ModelRegistry struct {
|
||||
Providers map[string]ProviderInfo `json:"-"`
|
||||
}
|
||||
|
||||
type ProviderInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Models map[string]ModelMetadata `json:"models"`
|
||||
}
|
||||
|
||||
type ModelMetadata struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Cost *ModelCost `json:"cost,omitempty"`
|
||||
Limit *ModelLimit `json:"limit,omitempty"`
|
||||
Modalities *ModelModalities `json:"modalities,omitempty"`
|
||||
ToolCall *bool `json:"tool_call,omitempty"`
|
||||
Reasoning *bool `json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
// ModelCost holds a model's input and output prices plus optional cache
// read/write prices. NOTE(review): the pricing unit is not visible in this
// type — confirm against the registry data source.
type ModelCost struct {
	Input      float64  `json:"input"`
	Output     float64  `json:"output"`
	CacheRead  *float64 `json:"cache_read,omitempty"`
	CacheWrite *float64 `json:"cache_write,omitempty"`
}
|
||||
|
||||
// ModelLimit records a model's context and output limits.
type ModelLimit struct {
	Context uint32 `json:"context"`
	Output  uint32 `json:"output"`
}
|
||||
|
||||
// ModelModalities lists the input and output modalities of a model as
// strings.
type ModelModalities struct {
	Input  []string `json:"input"`
	Output []string `json:"output"`
}
|
||||
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// First try exact match in models map
|
||||
for _, provider := range r.Providers {
|
||||
if model, ok := provider.Models[modelID]; ok {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
|
||||
// Try searching by ID in metadata
|
||||
for _, provider := range r.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try reverse fuzzy matching (e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01')
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
if strings.HasPrefix(id, modelID) {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try fuzzy matching (e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o')
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelRegistry_FindModel_Exact(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("gpt-4o")
|
||||
if m == nil {
|
||||
t.Fatal("expected to find gpt-4o")
|
||||
}
|
||||
if m.Name != "GPT-4o" {
|
||||
t.Fatalf("expected GPT-4o, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_Fuzzy(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Fuzzy: "gpt-4o-2024-05-13" should match "gpt-4o"
|
||||
m := r.FindModel("gpt-4o-2024-05-13")
|
||||
if m == nil {
|
||||
t.Fatal("expected fuzzy match")
|
||||
}
|
||||
if m.Name != "GPT-4o" {
|
||||
t.Fatalf("expected GPT-4o, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_NotFound(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("nonexistent-model")
|
||||
if m != nil {
|
||||
t.Fatal("expected nil for nonexistent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-5.4-mini-2026-04-01": {ID: "gpt-5.4-mini-2026-04-01", Name: "GPT-5.4 Mini"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Reverse fuzzy: "gpt-5.4-mini" should match "gpt-5.4-mini-2026-04-01"
|
||||
m := r.FindModel("gpt-5.4-mini")
|
||||
if m == nil {
|
||||
t.Fatal("expected reverse fuzzy match")
|
||||
}
|
||||
if m.Name != "GPT-5.4 Mini" {
|
||||
t.Fatalf("expected GPT-5.4 Mini, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package models
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Responses API request types
|
||||
|
||||
// ResponsesRequest maps to POST /v1/responses body (OpenAI Responses API format).
|
||||
// The `input` field can be a string or an array of message objects.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponseInputMessage
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxOutputTokens *uint32 `json:"max_output_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseInputMessage represents a single message in the input array.
|
||||
type ResponseInputMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"` // string or []ContentPart
|
||||
}
|
||||
|
||||
// Responses API response types
|
||||
|
||||
// ResponsesResponse maps to OpenAI /v1/responses response.
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputItem represents an item in the output array.
|
||||
// For messages: type="message", role, content[].
|
||||
// For function calls: type="function_call", id, name, arguments, status.
|
||||
type ResponsesOutputItem struct {
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesOutputContent `json:"content,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputContent is one content part inside an output message: a
// Type tag, optional text, and optional annotations kept as raw JSON.
type ResponsesOutputContent struct {
	Type        string            `json:"type"`
	Text        string            `json:"text,omitempty"`
	Annotations []json.RawMessage `json:"annotations,omitempty"`
}
|
||||
|
||||
// ResponsesUsage maps to the usage block in Responses API.
|
||||
type ResponsesUsage struct {
|
||||
InputTokens uint32 `json:"input_tokens"`
|
||||
OutputTokens uint32 `json:"output_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input tokens; it currently exposes
// only the cached token count.
type ResponsesInputTokensDetails struct {
	CachedTokens uint32 `json:"cached_tokens"`
}
|
||||
|
||||
// ResponsesOutputTokensDetails breaks down output tokens; it currently
// exposes only the reasoning token count.
type ResponsesOutputTokensDetails struct {
	ReasoningTokens uint32 `json:"reasoning_tokens"`
}
|
||||
|
||||
// ToUsage converts ResponsesUsage to the unified Usage model.
|
||||
func (u *ResponsesUsage) ToUsage() *Usage {
|
||||
usage := &Usage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadTokens = &u.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if u.OutputTokensDetails != nil && u.OutputTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.OutputTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ResponsesStreamChunk represents an SSE chunk from the Responses streaming endpoint.
|
||||
type ResponsesStreamChunk struct {
|
||||
Type string `json:"type"`
|
||||
Response *ResponsesStreamPayload `json:"response,omitempty"`
|
||||
Item *ResponsesStreamPayloadItem `json:"item,omitempty"`
|
||||
Delta *ResponsesStreamDelta `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamPayload represents the "response" field in some SSE chunks.
|
||||
type ResponsesStreamPayload struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamPayloadItem represents the "item" field in SSE chunks.
|
||||
type ResponsesStreamPayloadItem struct {
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesOutputContent `json:"content,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamDelta is an incremental content update in a Responses
// stream: the index of the content part it belongs to, a Type tag and the
// text fragment.
type ResponsesStreamDelta struct {
	ContentIndex int    `json:"content_index"`
	Type         string `json:"type"`
	Text         string `json:"text,omitempty"`
}
|
||||
|
||||
// UnifiedResponsesRequest is the internal unified format for Responses API.
|
||||
type UnifiedResponsesRequest struct {
|
||||
ClientID string
|
||||
Model string
|
||||
Input string // normalized input text
|
||||
InputMessages []ResponseInputMessage // structured input messages (if provided as array)
|
||||
Instructions string
|
||||
Temperature *float64
|
||||
MaxOutputTokens *uint32
|
||||
TopP *float64
|
||||
Stream bool
|
||||
Tools json.RawMessage
|
||||
ToolChoice json.RawMessage
|
||||
Store bool
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sony/gobreaker"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type CircuitBreakerProvider struct {
|
||||
provider Provider
|
||||
cb *gobreaker.CircuitBreaker
|
||||
}
|
||||
|
||||
func NewCircuitBreakerProvider(p Provider) Provider {
|
||||
name := p.Name()
|
||||
var maxRequests uint32 = 5
|
||||
var interval = 60 * time.Second
|
||||
var timeout = 5 * time.Minute
|
||||
|
||||
settings := gobreaker.Settings{
|
||||
Name: name,
|
||||
MaxRequests: maxRequests,
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
ReadyToTrip: func(counts gobreaker.Counts) bool {
|
||||
// Trip after 3 consecutive failures
|
||||
return counts.ConsecutiveFailures > 3
|
||||
},
|
||||
}
|
||||
return &CircuitBreakerProvider{
|
||||
provider: p,
|
||||
cb: gobreaker.NewCircuitBreaker(settings),
|
||||
}
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) Name() string {
|
||||
return cbp.provider.Name()
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.ChatCompletion(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ChatCompletionResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Circuit breaker for streaming is tricky. We'll just call the provider directly.
|
||||
// Future: Implement a way to track stream failures in the circuit breaker.
|
||||
return cbp.provider.ChatCompletionStream(ctx, req)
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.ImageGeneration(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ImageGenerationResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.Responses(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ResponsesResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
// Circuit breaker passthrough for streaming (same pattern as ChatCompletionStream)
|
||||
return cbp.provider.ResponsesStream(ctx, req)
|
||||
}
|
||||
@@ -0,0 +1,233 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type DeepSeekProvider struct {
|
||||
client *resty.Client
|
||||
config config.DeepSeekConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
|
||||
return &DeepSeekProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) Name() string {
|
||||
return "deepseek"
|
||||
}
|
||||
|
||||
// deepSeekUsage decodes the usage block of a DeepSeek response, including
// the prompt cache hit/miss counters and the optional completion token
// details (reasoning tokens).
type deepSeekUsage struct {
	PromptTokens            uint32 `json:"prompt_tokens"`
	CompletionTokens        uint32 `json:"completion_tokens"`
	TotalTokens             uint32 `json:"total_tokens"`
	PromptCacheHitTokens    uint32 `json:"prompt_cache_hit_tokens"`
	PromptCacheMissTokens   uint32 `json:"prompt_cache_miss_tokens"`
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *deepSeekUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptCacheHitTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptCacheHitTokens
|
||||
}
|
||||
if u.PromptCacheMissTokens > 0 {
|
||||
usage.CacheWriteTokens = &u.PromptCacheMissTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
result, err := ParseOpenAIResponse(respJSON, req.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var dUsage deepSeekUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
|
||||
result.Usage = dUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// Custom scanner loop to handle DeepSeek specific usage in chunks
|
||||
err := StreamDeepSeek(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("DeepSeek Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Fix DeepSeek specific usage in stream
|
||||
var rawChunk struct {
|
||||
Usage *deepSeekUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
ch <- &chunk
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("deepseek does not support image generation")
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by deepseek")
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by deepseek")
|
||||
}
|
||||
@@ -0,0 +1,623 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type GeminiProvider struct {
|
||||
client *resty.Client
|
||||
config config.GeminiConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
||||
return &GeminiProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Name() string {
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
Tools []GeminiTool `json:"tools,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiTool struct {
|
||||
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type GeminiGenerationConfig struct {
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args json.RawMessage `json:"args"`
|
||||
}
|
||||
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response json.RawMessage `json:"response"`
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
// Gemini Imagen API: POST https://generativelanguage.googleapis.com/v1beta/models/{model}:predict
|
||||
// Map OpenAI-style params to Gemini Imagen params
|
||||
|
||||
n := uint32(1)
|
||||
if req.N != nil && *req.N > 0 {
|
||||
n = *req.N
|
||||
}
|
||||
|
||||
aspectRatio := "1:1"
|
||||
if req.Size != nil {
|
||||
aspectRatio = sizeToGeminiAspectRatio(*req.Size)
|
||||
}
|
||||
|
||||
// Build Imagen request
|
||||
imagenReq := map[string]interface{}{
|
||||
"instances": []map[string]interface{}{
|
||||
{"prompt": req.Prompt},
|
||||
},
|
||||
"parameters": map[string]interface{}{
|
||||
"sampleCount": n,
|
||||
"aspectRatio": aspectRatio,
|
||||
},
|
||||
}
|
||||
|
||||
// Model defaults to imagen-3.0-generate-001 if empty
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "imagen-3.0-generate-001"
|
||||
}
|
||||
|
||||
// Use v1beta for Imagen
|
||||
baseURL := p.config.BaseURL
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:predict?key=%s", baseURL, model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(imagenReq).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini imagen request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
// Parse Imagen response
|
||||
var imagenResp struct {
|
||||
Predictions []struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
} `json:"predictions"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Imagen response: %w", err)
|
||||
}
|
||||
|
||||
respFormat := "url"
|
||||
if req.ResponseFormat != nil && *req.ResponseFormat == "b64_json" {
|
||||
respFormat = "b64_json"
|
||||
}
|
||||
|
||||
var data []models.ImageData
|
||||
for _, pred := range imagenResp.Predictions {
|
||||
imgData := models.ImageData{}
|
||||
if respFormat == "b64_json" {
|
||||
imgData.B64JSON = pred.BytesBase64Encoded
|
||||
} else {
|
||||
// Build a data URI since Gemini returns base64, not a URL
|
||||
mime := pred.MimeType
|
||||
if mime == "" {
|
||||
mime = "image/png"
|
||||
}
|
||||
imgData.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded)
|
||||
}
|
||||
data = append(data, imgData)
|
||||
}
|
||||
|
||||
result := &models.ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sizeToGeminiAspectRatio maps an OpenAI size string (e.g. "1024x1792") to a
// Gemini aspect ratio (e.g. "9:16"). Only the portrait size "1024x1792" and
// the landscape size "1792x1024" have dedicated ratios; every other value,
// including the square sizes and unrecognized input, yields "1:1".
func sizeToGeminiAspectRatio(size string) string {
	if size == "1024x1792" {
		return "9:16"
	}
	if size == "1792x1024" {
		return "16:9"
	}
	return "1:1"
}
|
||||
|
||||
func (p *GeminiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by gemini")
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by gemini")
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Gemini mapping
|
||||
var contents []GeminiContent
|
||||
|
||||
for i := 0; i < len(req.Messages); i++ {
|
||||
msg := req.Messages[i]
|
||||
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
// 1. Add the assistant (model) message with tool calls
|
||||
parts := []GeminiPart{}
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
}
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||
|
||||
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
|
||||
// Look ahead for tool messages.
|
||||
var functionParts []GeminiPart
|
||||
toolCallIDs := make(map[string]bool)
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolCallIDs[tc.ID] = true
|
||||
}
|
||||
|
||||
// We need to find tool messages that correspond to these calls.
|
||||
// In many patterns, they follow immediately.
|
||||
j := i + 1
|
||||
foundAny := false
|
||||
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||
m := req.Messages[j]
|
||||
|
||||
// Try to match by ID or just take them in order if IDs are missing/mismatched
|
||||
// Gemini is strict: you must respond to EVERY call in the previous message.
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
}
|
||||
name := "unknown_function"
|
||||
if m.Name != nil {
|
||||
name = *m.Name
|
||||
}
|
||||
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
functionParts = append(functionParts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
foundAny = true
|
||||
j++
|
||||
}
|
||||
|
||||
if foundAny {
|
||||
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||
i = j - 1 // Advance outer loop past the tool messages we consumed
|
||||
} else {
|
||||
// If no tool results found but assistant made calls, Gemini WILL error.
|
||||
// We should probably skip the calls or provide dummy results,
|
||||
// but usually this means the conversation is incomplete.
|
||||
// For now, don't add a "function" message if none found.
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Standard message handling (System/User/Assistant without tools)
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "system" {
|
||||
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
|
||||
} else if msg.Role == "tool" {
|
||||
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
|
||||
// Skip or map to user? Skipping is safer for API stability.
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||
}
|
||||
}
|
||||
|
||||
genConfig := &GeminiGenerationConfig{}
|
||||
if req.Temperature != nil {
|
||||
t := float32(*req.Temperature)
|
||||
genConfig.Temperature = &t
|
||||
}
|
||||
if req.TopP != nil {
|
||||
tp := float32(*req.TopP)
|
||||
genConfig.TopP = &tp
|
||||
}
|
||||
if req.TopK != nil {
|
||||
tk := int(*req.TopK)
|
||||
genConfig.TopK = &tk
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
mt := int(*req.MaxTokens)
|
||||
genConfig.MaxOutputTokens = &mt
|
||||
}
|
||||
if len(req.Stop) > 0 {
|
||||
genConfig.StopSequences = req.Stop
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
// Map Tools
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||
fmt.Printf("[Gemini] POST %s\n", url)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
|
||||
// Also log the request body for debugging (careful with API keys if logged elsewhere)
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("no candidates in Gemini response")
|
||||
}
|
||||
|
||||
content := ""
|
||||
var toolCalls []models.ToolCall
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, models.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(part.FunctionCall.Args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
|
||||
if finishReason == "stop" {
|
||||
finishReason = "stop"
|
||||
} else if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0,
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
FinishReason: &finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(geminiResp.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
|
||||
return openAIResp, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Simplified Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for i := 0; i < len(req.Messages); i++ {
|
||||
msg := req.Messages[i]
|
||||
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
parts := []GeminiPart{}
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
}
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||
|
||||
var functionParts []GeminiPart
|
||||
j := i + 1
|
||||
foundAny := false
|
||||
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||
m := req.Messages[j]
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
}
|
||||
name := "unknown_function"
|
||||
if m.Name != nil {
|
||||
name = *m.Name
|
||||
}
|
||||
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
functionParts = append(functionParts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
foundAny = true
|
||||
j++
|
||||
}
|
||||
|
||||
if foundAny {
|
||||
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||
i = j - 1
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "system" {
|
||||
role = "user"
|
||||
} else if msg.Role == "tool" {
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||
}
|
||||
}
|
||||
|
||||
genConfig := &GeminiGenerationConfig{}
|
||||
if req.Temperature != nil {
|
||||
t := float32(*req.Temperature)
|
||||
genConfig.Temperature = &t
|
||||
}
|
||||
if req.TopP != nil {
|
||||
tp := float32(*req.TopP)
|
||||
genConfig.TopP = &tp
|
||||
}
|
||||
if req.TopK != nil {
|
||||
tk := int(*req.TopK)
|
||||
genConfig.TopK = &tk
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
mt := int(*req.MaxTokens)
|
||||
genConfig.MaxOutputTokens = &mt
|
||||
}
|
||||
if len(req.Stop) > 0 {
|
||||
genConfig.StopSequences = req.Stop
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for streaming
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||
fmt.Printf("[Gemini-Stream] POST %s\n", url)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// uint32Ptr returns a pointer to a copy of v, or nil when v is zero.
func uint32Ptr(v uint32) *uint32 {
	if v == 0 {
		return nil
	}
	return &v
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type GrokProvider struct {
|
||||
client *resty.Client
|
||||
config config.GrokConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
|
||||
return &GrokProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GrokProvider) Name() string {
|
||||
return "grok"
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("Grok Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("grok does not support image generation")
|
||||
}
|
||||
|
||||
func (p *GrokProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by grok")
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by grok")
|
||||
}
|
||||
@@ -0,0 +1,447 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
||||
var result []interface{}
|
||||
for _, m := range messages {
|
||||
if m.Role == "tool" {
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
}
|
||||
msg := map[string]interface{}{
|
||||
"role": "tool",
|
||||
"content": text,
|
||||
}
|
||||
if m.ToolCallID != nil {
|
||||
id := *m.ToolCallID
|
||||
if len(id) > 40 {
|
||||
id = id[:40]
|
||||
}
|
||||
msg["tool_call_id"] = id
|
||||
}
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
}
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []interface{}
|
||||
for _, p := range m.Content {
|
||||
if p.Type == "text" {
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": p.Text,
|
||||
})
|
||||
} else if p.Image != nil {
|
||||
base64Data, mimeType, err := p.Image.ToBase64()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
|
||||
}
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]interface{}{
|
||||
"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var finalContent interface{}
|
||||
if len(parts) == 1 {
|
||||
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
|
||||
finalContent = p["text"]
|
||||
} else {
|
||||
finalContent = parts
|
||||
}
|
||||
} else {
|
||||
finalContent = parts
|
||||
}
|
||||
|
||||
msg := map[string]interface{}{
|
||||
"role": m.Role,
|
||||
"content": finalContent,
|
||||
}
|
||||
|
||||
if m.ReasoningContent != nil {
|
||||
msg["reasoning_content"] = *m.ReasoningContent
|
||||
}
|
||||
|
||||
if len(m.ToolCalls) > 0 {
|
||||
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
|
||||
copy(sanitizedCalls, m.ToolCalls)
|
||||
for i := range sanitizedCalls {
|
||||
if len(sanitizedCalls[i].ID) > 40 {
|
||||
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
|
||||
}
|
||||
}
|
||||
msg["tool_calls"] = sanitizedCalls
|
||||
if len(parts) == 0 {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
}
|
||||
|
||||
result = append(result, msg)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": request.Model,
|
||||
"messages": messagesJSON,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
if request.Temperature != nil {
|
||||
body["temperature"] = *request.Temperature
|
||||
}
|
||||
if request.MaxTokens != nil {
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesBody builds the request body for the Responses API endpoint.
|
||||
func BuildOpenAIResponsesBody(req *models.ResponsesRequest, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
// The input field can be a string or a structured array.
|
||||
// Try to preserve the original format.
|
||||
if req.Input != nil {
|
||||
// Try as string first
|
||||
var inputStr string
|
||||
if err := json.Unmarshal(req.Input, &inputStr); err == nil {
|
||||
body["input"] = inputStr
|
||||
} else {
|
||||
// Try as array of messages
|
||||
var inputArr []interface{}
|
||||
if err := json.Unmarshal(req.Input, &inputArr); err == nil {
|
||||
body["input"] = inputArr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Instructions != "" {
|
||||
body["instructions"] = req.Instructions
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
body["temperature"] = *req.Temperature
|
||||
}
|
||||
if req.MaxOutputTokens != nil {
|
||||
body["max_output_tokens"] = *req.MaxOutputTokens
|
||||
}
|
||||
if req.TopP != nil {
|
||||
body["top_p"] = *req.TopP
|
||||
}
|
||||
if req.Tools != nil {
|
||||
var tools interface{}
|
||||
if err := json.Unmarshal(req.Tools, &tools); err == nil {
|
||||
body["tools"] = tools
|
||||
}
|
||||
}
|
||||
if req.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
if req.Store != nil {
|
||||
body["store"] = *req.Store
|
||||
}
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// ParseOpenAIResponsesResponse parses a raw JSON map into a ResponsesResponse.
|
||||
func ParseOpenAIResponsesResponse(respJSON map[string]interface{}, model string) (*models.ResponsesResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp models.ResponsesResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Re-parse usage with the detailed tokens
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var responsesUsage models.ResponsesUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &responsesUsage); err == nil {
|
||||
resp.Usage = &responsesUsage
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ParseOpenAIResponsesStreamChunk parses a single SSE line into a ResponsesStreamChunk.
|
||||
// Returns the chunk, whether this is the [DONE] signal, and any error.
|
||||
func ParseOpenAIResponsesStreamChunk(line string) (*models.ResponsesStreamChunk, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ResponsesStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal responses stream chunk: %w", err)
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
// StreamOpenAIResponses reads SSE chunks from the body and sends them to the channel.
|
||||
func StreamOpenAIResponses(ctx io.ReadCloser, ch chan<- *models.ResponsesStreamChunk) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOpenAIResponsesStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// openAIUsage mirrors the "usage" object of an OpenAI-style response,
// including the optional nested detail objects that carry cached-prompt and
// reasoning token counts.
type openAIUsage struct {
	// Top-level token counters.
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`

	// PromptTokensDetails is nil when the response carries no
	// "prompt_tokens_details" object.
	PromptTokensDetails *struct {
		CachedTokens uint32 `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`

	// CompletionTokensDetails is nil when the response carries no
	// "completion_tokens_details" object.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *openAIUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp models.ChatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
|
||||
// but the provider might have returned more details.
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var oUsage openAIUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
|
||||
resp.Usage = oUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Streaming support
|
||||
|
||||
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||
}
|
||||
|
||||
// Handle specialized usage in stream chunks
|
||||
var rawChunk struct {
|
||||
Usage *openAIUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOpenAIStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
|
||||
dec := json.NewDecoder(ctx)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
||||
for dec.More() {
|
||||
var geminiChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := dec.Decode(&geminiChunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var finishReason *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type MoonshotProvider struct {
|
||||
client *resty.Client
|
||||
config config.MoonshotConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
|
||||
return &MoonshotProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) Name() string {
|
||||
return "moonshot"
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "application/json").
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "text/event-stream").
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
|
||||
fmt.Printf("Moonshot Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("moonshot does not support image generation")
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by moonshot")
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by moonshot")
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type OllamaProvider struct {
|
||||
client *resty.Client
|
||||
config config.OllamaConfig
|
||||
}
|
||||
|
||||
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
|
||||
client := resty.New()
|
||||
// Set reasonable timeouts for local Ollama server (longer for larger models)
|
||||
// For streaming, we want a very long timeout or none at all to handle generation time
|
||||
client.SetTimeout(15 * time.Minute)
|
||||
client.SetRetryCount(2)
|
||||
client.SetRetryWaitTime(1 * time.Second)
|
||||
|
||||
return &OllamaProvider{
|
||||
client: client,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Name() string {
|
||||
return "ollama"
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOllamaBody(req, messagesJSON, false)
|
||||
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOllamaResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOllamaBody(req, messagesJSON, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOllama(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": request.Model,
|
||||
"messages": messagesJSON,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
options := make(map[string]interface{})
|
||||
modelLower := strings.ToLower(request.Model)
|
||||
|
||||
// Context window size (default 8k for all, 32k+ for modern large-context models)
|
||||
ctxSize := 8192
|
||||
if strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "mixtral") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "deepseek") ||
|
||||
strings.Contains(modelLower, "command-r") ||
|
||||
strings.Contains(modelLower, "phi") {
|
||||
ctxSize = 32768
|
||||
}
|
||||
options["num_ctx"] = ctxSize
|
||||
|
||||
if request.Temperature != nil {
|
||||
body["temperature"] = *request.Temperature
|
||||
options["temperature"] = *request.Temperature
|
||||
}
|
||||
|
||||
if request.MaxTokens != nil {
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
options["num_predict"] = *request.MaxTokens
|
||||
} else {
|
||||
// Default to 8192 for all Ollama models if not specified,
|
||||
// as Ollama's compatibility layer defaults to 128 if neither
|
||||
// max_tokens nor num_predict are provided.
|
||||
body["max_tokens"] = 8192
|
||||
options["num_predict"] = 8192
|
||||
}
|
||||
|
||||
if request.TopP != nil {
|
||||
body["top_p"] = *request.TopP
|
||||
options["top_p"] = *request.TopP
|
||||
}
|
||||
if request.TopK != nil {
|
||||
body["top_k"] = *request.TopK
|
||||
options["top_k"] = *request.TopK
|
||||
}
|
||||
|
||||
if len(request.Stop) > 0 {
|
||||
body["stop"] = request.Stop
|
||||
options["stop"] = request.Stop
|
||||
}
|
||||
|
||||
if len(options) > 0 {
|
||||
body["options"] = options
|
||||
}
|
||||
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
// Explicitly set tool_choice to auto if tools are present but choice is not specified
|
||||
if request.ToolChoice == nil {
|
||||
body["tool_choice"] = "auto"
|
||||
}
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp models.ChatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var usage models.Usage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &usage); err == nil {
|
||||
resp.Usage = &usage
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||
}
|
||||
|
||||
var rawChunk struct {
|
||||
Usage *models.Usage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
// Set a larger buffer for scanning to handle large chunks if they occur
|
||||
const maxCapacity = 10 * 1024 * 1024 // 10MB
|
||||
buf := make([]byte, 64*1024)
|
||||
scanner.Buffer(buf, maxCapacity)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOllamaStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("ollama does not support image generation")
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by ollama")
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by ollama")
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
client *resty.Client
|
||||
config config.OpenAIConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
|
||||
return &OpenAIProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string {
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"prompt": req.Prompt,
|
||||
"model": req.Model,
|
||||
}
|
||||
|
||||
if req.N != nil {
|
||||
body["n"] = *req.N
|
||||
}
|
||||
if req.Quality != nil {
|
||||
body["quality"] = *req.Quality
|
||||
}
|
||||
if req.ResponseFormat != nil {
|
||||
body["response_format"] = *req.ResponseFormat
|
||||
}
|
||||
if req.Size != nil {
|
||||
body["size"] = *req.Size
|
||||
}
|
||||
if req.Style != nil {
|
||||
body["style"] = *req.Style
|
||||
}
|
||||
if req.User != nil {
|
||||
body["user"] = *req.User
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/images/generations", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var result models.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body(), &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
// In a real app, you might want to send an error chunk or log it
|
||||
fmt.Printf("Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// Responses sends a non-streaming request to OpenAI's /v1/responses endpoint.
|
||||
func (p *OpenAIProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
// Determine if streaming was requested
|
||||
stream := req.Stream != nil && *req.Stream
|
||||
|
||||
body := BuildOpenAIResponsesBody(req, stream)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("responses request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse responses response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponsesResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
// ResponsesStream sends a streaming request to OpenAI's /v1/responses endpoint.
|
||||
func (p *OpenAIProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
body := BuildOpenAIResponsesBody(req, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("responses stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ResponsesStreamChunk)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAIResponses(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("Responses stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Name() string
|
||||
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
|
||||
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
|
||||
ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error)
|
||||
Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error)
|
||||
ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error)
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// classifierSystemPrompt is the system prompt template sent to the selector
// model. It contains two %d verbs, both filled with the maximum rating.
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)

Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
maxRating = 2
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
if err != nil || rating < 1 {
|
||||
rating = 1
|
||||
}
|
||||
if rating > maxRating {
|
||||
rating = maxRating
|
||||
}
|
||||
|
||||
idx := rating - 1 // 0-based index into targets
|
||||
return &Decision{
|
||||
SelectedModel: targets[idx],
|
||||
Strategy: "classifier",
|
||||
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
||||
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
||||
return *group.SelectorModel
|
||||
}
|
||||
// Default: use the first (cheapest) target model as the selector
|
||||
return targets[0]
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule is one entry of a group's JSON-encoded heuristic rule list.
// A rule matches when Pattern occurs as a substring of the user message.
type HeuristicRule struct {
	// Pattern is the substring to look for.
	Pattern string `json:"pattern"`
	// TargetIdx is the 0-based index into the group's targets to select on a match.
	TargetIdx int `json:"target"`
	// CaseSensitive disables the default lower-casing of pattern and message.
	CaseSensitive bool `json:"case_sensitive,omitempty"`
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
"implement", "refactor", "architecture",
|
||||
}
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(msgLower, indicator) {
|
||||
selected = targets[len(targets)-1]
|
||||
reason = "complex task indicator: " + indicator
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: selected,
|
||||
Strategy: "heuristic",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// Decision is the outcome of routing a group to a concrete model.
type Decision struct {
	// SelectedModel is the chosen target model.
	SelectedModel string `json:"selected_model"`
	// Strategy is "heuristic" or "classifier".
	Strategy string `json:"strategy"`
	// Reason is a human-readable explanation of the choice.
	Reason string `json:"reason"`
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
// Router resolves model groups to concrete models.
|
||||
type Router struct {
|
||||
groups map[string]db.ModelGroup
|
||||
classify ClassifierFunc
|
||||
}
|
||||
|
||||
// New creates a Router. classify may be nil if no classifier groups exist.
|
||||
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
||||
r := &Router{
|
||||
groups: make(map[string]db.ModelGroup),
|
||||
classify: classify,
|
||||
}
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// IsGroup returns true if the model name is a group ID.
|
||||
func (r *Router) IsGroup(modelID string) bool {
|
||||
_, ok := r.groups[modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
}
|
||||
|
||||
var targets []string
|
||||
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
||||
}
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
// Reload replaces the group definitions without recreating the router.
|
||||
func (r *Router) Reload(groups []db.ModelGroup) {
|
||||
r.groups = make(map[string]db.ModelGroup)
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,372 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UsagePeriodFilter carries the time-window query parameters accepted by the
// usage endpoints.
type UsagePeriodFilter struct {
	// Period names the window; ToSQL recognizes "today", "24h", "7d", "30d"
	// and "custom".
	Period string `form:"period"`
	// From is the lower timestamp bound, used only when Period is "custom".
	From string `form:"from"`
	// To is the upper timestamp bound, used only when Period is "custom".
	To string `form:"to"`
}
|
||||
|
||||
func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
|
||||
period := f.Period
|
||||
if period == "" {
|
||||
period = "all"
|
||||
}
|
||||
|
||||
if period == "custom" {
|
||||
var clauses []string
|
||||
var binds []interface{}
|
||||
if f.From != "" {
|
||||
clauses = append(clauses, "timestamp >= ?")
|
||||
binds = append(binds, f.From)
|
||||
}
|
||||
if f.To != "" {
|
||||
clauses = append(clauses, "timestamp <= ?")
|
||||
binds = append(binds, f.To)
|
||||
}
|
||||
if len(clauses) > 0 {
|
||||
return " AND " + strings.Join(clauses, " AND "), binds
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
var cutoff time.Time
|
||||
switch period {
|
||||
case "today":
|
||||
cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
||||
case "24h":
|
||||
cutoff = now.Add(-24 * time.Hour)
|
||||
case "7d":
|
||||
cutoff = now.Add(-7 * 24 * time.Hour)
|
||||
case "30d":
|
||||
cutoff = now.Add(-30 * 24 * time.Hour)
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
|
||||
}
|
||||
|
||||
func (s *Server) handleUsageSummary(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
// Total stats
|
||||
var totalStats struct {
|
||||
TotalRequests int `db:"total_requests"`
|
||||
TotalTokens int `db:"total_tokens"`
|
||||
CacheReadTokens int `db:"total_cache_read_tokens"`
|
||||
CacheWriteTokens int `db:"total_cache_write_tokens"`
|
||||
TotalCost float64 `db:"total_cost"`
|
||||
ActiveClients int `db:"active_clients"`
|
||||
}
|
||||
err := s.database.Get(&totalStats, fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||
COUNT(DISTINCT client_id) as active_clients
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Today stats
|
||||
var todayStats struct {
|
||||
TodayRequests int `db:"today_requests"`
|
||||
TodayCost float64 `db:"today_cost"`
|
||||
}
|
||||
today := time.Now().UTC().Format("2006-01-02")
|
||||
err = s.database.Get(&todayStats, `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(cost), 0.0) as today_cost
|
||||
FROM llm_requests
|
||||
WHERE timestamp LIKE ?
|
||||
`, today+"%")
|
||||
if err != nil {
|
||||
todayStats.TodayRequests = 0
|
||||
todayStats.TodayCost = 0.0
|
||||
}
|
||||
|
||||
// Error rate & Avg response time
|
||||
var miscStats struct {
|
||||
ErrorRate float64 `db:"error_rate"`
|
||||
AvgResponseTime float64 `db:"avg_response_time"`
|
||||
}
|
||||
err = s.database.Get(&miscStats, `
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) = 0 THEN 0.0 ELSE (CAST(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)) * 100.0 END as error_rate,
|
||||
COALESCE(AVG(duration_ms), 0.0) as avg_response_time
|
||||
FROM llm_requests
|
||||
`)
|
||||
if err != nil {
|
||||
miscStats.ErrorRate = 0.0
|
||||
miscStats.AvgResponseTime = 0.0
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"total_requests": totalStats.TotalRequests,
|
||||
"total_tokens": totalStats.TotalTokens,
|
||||
"total_cache_read_tokens": totalStats.CacheReadTokens,
|
||||
"total_cache_write_tokens": totalStats.CacheWriteTokens,
|
||||
"total_cost": totalStats.TotalCost,
|
||||
"active_clients": totalStats.ActiveClients,
|
||||
"today_requests": todayStats.TodayRequests,
|
||||
"today_cost": todayStats.TodayCost,
|
||||
"error_rate": miscStats.ErrorRate,
|
||||
"avg_response_time": miscStats.AvgResponseTime,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleTimeSeries(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
if clause == "" {
|
||||
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
|
||||
clause = " AND timestamp >= ?"
|
||||
binds = []interface{}{cutoff.Format(time.RFC3339)}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as bucket,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket
|
||||
`, clause)
|
||||
|
||||
rows, err := s.database.Queryx(query, binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var series []gin.H
|
||||
for rows.Next() {
|
||||
var bucket string
|
||||
var requests int
|
||||
var tokens int
|
||||
var cost float64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens, &cost); err != nil {
|
||||
continue
|
||||
}
|
||||
series = append(series, gin.H{
|
||||
"time": bucket,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
})
|
||||
}
|
||||
|
||||
granularity := "day"
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"series": series,
|
||||
"granularity": granularity,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleProvidersUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
rows, err := s.database.Queryx(fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(provider, 'unknown') as provider,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY provider
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var provider string
|
||||
var requests int
|
||||
var cost float64
|
||||
if err := rows.Scan(&provider, &requests, &cost); err == nil {
|
||||
results = append(results, gin.H{"provider": provider, "requests": requests, "cost": cost})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
|
||||
func (s *Server) handleClientsUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
rows, err := s.database.Queryx(fmt.Sprintf(`
|
||||
SELECT COALESCE(client_id, 'unknown') as client_id, COUNT(*) as requests
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY client_id
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var clientID string
|
||||
var requests int
|
||||
if err := rows.Scan(&clientID, &requests); err == nil {
|
||||
results = append(results, gin.H{"client_id": clientID, "requests": requests})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
|
||||
func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
// Models breakdown
|
||||
var models []struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
mRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(model, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY model ORDER BY value DESC", clause), binds...)
|
||||
if err == nil {
|
||||
for mRows.Next() {
|
||||
var label string
|
||||
var value int
|
||||
if err := mRows.Scan(&label, &value); err == nil {
|
||||
models = append(models, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
mRows.Close()
|
||||
}
|
||||
|
||||
// Clients breakdown
|
||||
var clients []struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
cRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(client_id, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY client_id ORDER BY value DESC", clause), binds...)
|
||||
if err == nil {
|
||||
for cRows.Next() {
|
||||
var label string
|
||||
var value int
|
||||
if err := cRows.Scan(&label, &value); err == nil {
|
||||
clients = append(clients, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
cRows.Close()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"models": models,
|
||||
"clients": clients,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDetailedUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as date,
|
||||
COALESCE(client_id, 'unknown') as client,
|
||||
COALESCE(provider, 'unknown') as provider,
|
||||
COALESCE(model, 'unknown') as model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as cache_write_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY date, client, provider, model
|
||||
ORDER BY date DESC, cost DESC
|
||||
`, clause)
|
||||
|
||||
rows, err := s.database.Queryx(query, binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var date, client, provider, model string
|
||||
var requests, tokens, cacheRead, cacheWrite int
|
||||
var cost float64
|
||||
if err := rows.Scan(&date, &client, &provider, &model, &requests, &tokens, &cacheRead, &cacheWrite, &cost); err == nil {
|
||||
results = append(results, gin.H{
|
||||
"date": date,
|
||||
"client": client,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cache_read_tokens": cacheRead,
|
||||
"cache_write_tokens": cacheWrite,
|
||||
"cost": cost,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetClients(c *gin.Context) {
|
||||
var clients []db.Client
|
||||
err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
type UIClient struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsed *time.Time `json:"last_used"`
|
||||
RequestsCount int `json:"requests_count"`
|
||||
TokensCount int `json:"tokens_count"`
|
||||
Status string `json:"status"`
|
||||
RateLimitPerMinute int `json:"rate_limit_per_minute"`
|
||||
}
|
||||
|
||||
uiClients := make([]UIClient, len(clients))
|
||||
for i, cl := range clients {
|
||||
status := "active"
|
||||
if !cl.IsActive {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
name := ""
|
||||
if cl.Name != nil {
|
||||
name = *cl.Name
|
||||
}
|
||||
desc := ""
|
||||
if cl.Description != nil {
|
||||
desc = *cl.Description
|
||||
}
|
||||
|
||||
var lastUsedStr string
|
||||
_ = s.database.Get(&lastUsedStr, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", cl.ClientID)
|
||||
|
||||
var lastUsed *time.Time
|
||||
if lastUsedStr != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
|
||||
lastUsed = &t
|
||||
}
|
||||
}
|
||||
|
||||
uiClients[i] = UIClient{
|
||||
ID: cl.ClientID,
|
||||
Name: name,
|
||||
Description: desc,
|
||||
CreatedAt: cl.CreatedAt,
|
||||
LastUsed: lastUsed,
|
||||
RequestsCount: cl.TotalRequests,
|
||||
TokensCount: cl.TotalTokens,
|
||||
Status: status,
|
||||
RateLimitPerMinute: cl.RateLimitPerMinute,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(uiClients))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var cl db.Client
|
||||
err := s.database.Get(&cl, "SELECT * FROM clients WHERE client_id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, ErrorResponse("Client not found"))
|
||||
return
|
||||
}
|
||||
|
||||
name := ""
|
||||
if cl.Name != nil {
|
||||
name = *cl.Name
|
||||
}
|
||||
desc := ""
|
||||
if cl.Description != nil {
|
||||
desc = *cl.Description
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"id": cl.ClientID,
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"is_active": cl.IsActive,
|
||||
"rate_limit_per_minute": cl.RateLimitPerMinute,
|
||||
"created_at": cl.CreatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateClientRequest is the JSON body accepted by handleUpdateClient.
type UpdateClientRequest struct {
	// Name replaces the client's name.
	Name string `json:"name"`
	// Description replaces the client's description; nil stores NULL.
	Description *string `json:"description"`
	// IsActive replaces the client's active flag.
	IsActive bool `json:"is_active"`
	// RateLimitPerMinute replaces the rate limit; nil keeps the stored value.
	RateLimitPerMinute *int `json:"rate_limit_per_minute"`
}
|
||||
|
||||
func (s *Server) handleUpdateClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req UpdateClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE clients SET
|
||||
name = ?,
|
||||
description = ?,
|
||||
is_active = ?,
|
||||
rate_limit_per_minute = COALESCE(?, rate_limit_per_minute),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
`, req.Name, req.Description, req.IsActive, req.RateLimitPerMinute, id)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client updated"}))
|
||||
}
|
||||
|
||||
// CreateClientRequest is the JSON body accepted by handleCreateClient.
type CreateClientRequest struct {
	// Name is the client's name; it is required.
	Name string `json:"name" binding:"required"`
	// ClientID optionally supplies the client identifier.
	ClientID *string `json:"client_id"`
}
|
||||
|
||||
func (s *Server) handleCreateClient(c *gin.Context) {
|
||||
var req CreateClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
if req.ClientID != nil {
|
||||
clientID = *req.ClientID
|
||||
} else {
|
||||
clientID = "client-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
token := "sk-" + uuid.New().String() + uuid.New().String()
|
||||
token = token[:51]
|
||||
|
||||
_, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
|
||||
if err != nil {
|
||||
// Log error
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"id": clientID,
|
||||
"name": req.Name,
|
||||
"status": "active",
|
||||
"token": token,
|
||||
"created_at": time.Now(),
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "default" {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetClientTokens(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var tokens []db.ClientToken
|
||||
err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
type MaskedToken struct {
|
||||
ID int `json:"id"`
|
||||
TokenMasked string `json:"token_masked"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
}
|
||||
|
||||
masked := make([]MaskedToken, len(tokens))
|
||||
for i, t := range tokens {
|
||||
maskedToken := "••••"
|
||||
if len(t.Token) > 8 {
|
||||
maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:]
|
||||
}
|
||||
masked[i] = MaskedToken{
|
||||
ID: t.ID,
|
||||
TokenMasked: maskedToken,
|
||||
Name: t.Name,
|
||||
IsActive: t.IsActive,
|
||||
CreatedAt: t.CreatedAt,
|
||||
LastUsedAt: t.LastUsedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(masked))
|
||||
}
|
||||
|
||||
// CreateTokenRequest is the optional JSON body accepted by
// handleCreateClientToken.
type CreateTokenRequest struct {
	// Name labels the new token; when empty the handler uses "default".
	Name string `json:"name"`
}
|
||||
|
||||
func (s *Server) handleCreateClientToken(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
var req CreateTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// optional name
|
||||
}
|
||||
|
||||
name := "default"
|
||||
if req.Name != "" {
|
||||
name = req.Name
|
||||
}
|
||||
|
||||
token := "sk-" + uuid.New().String() + uuid.New().String()
|
||||
token = token[:51]
|
||||
|
||||
_, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"token": token,
|
||||
"name": name,
|
||||
"created_at": time.Now(),
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteClientToken(c *gin.Context) {
|
||||
tokenID := c.Param("token_id")
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// ApiResponse is the JSON envelope used by the handlers in this package.
type ApiResponse struct {
	// Success tells whether the request succeeded.
	Success bool `json:"success"`
	// Data is the payload of a successful response; omitted when nil.
	Data interface{} `json:"data,omitempty"`
	// Error is the message of a failed response; omitted when empty.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
func SuccessResponse(data interface{}) ApiResponse {
|
||||
return ApiResponse{Success: true, Data: data}
|
||||
}
|
||||
|
||||
func ErrorResponse(err string) ApiResponse {
|
||||
return ApiResponse{Success: false, Error: err}
|
||||
}
|
||||
|
||||
func (s *Server) adminAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
if token == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
|
||||
return
|
||||
}
|
||||
|
||||
session, _, err := s.sessions.ValidateSession(token)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Session expired or invalid"))
|
||||
return
|
||||
}
|
||||
|
||||
if session.Role != "admin" {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse("Admin access required"))
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("session", session)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest is the JSON body accepted by handleLogin; both fields are
// required.
type LoginRequest struct {
	// Username identifies the account.
	Username string `json:"username" binding:"required"`
	// Password is the plaintext password to verify.
	Password string `json:"password" binding:"required"`
}
|
||||
|
||||
func (s *Server) handleLogin(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
var user db.User
|
||||
err := s.database.Get(&user, "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?", req.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
|
||||
return
|
||||
}
|
||||
|
||||
displayName := user.Username
|
||||
if user.DisplayName != nil {
|
||||
displayName = *user.DisplayName
|
||||
}
|
||||
|
||||
token, err := s.sessions.CreateSession(user.Username, displayName, user.Role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to create session"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"token": token,
|
||||
"must_change_password": user.MustChangePassword,
|
||||
"user": user,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleAuthStatus(c *gin.Context) {
|
||||
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
session, _, err := s.sessions.ValidateSession(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"authenticated": true,
|
||||
"user": gin.H{
|
||||
"username": session.Username,
|
||||
"role": session.Role,
|
||||
"display_name": session.DisplayName,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// ChangePasswordRequest is the JSON body accepted by handleChangePassword;
// both fields are required.
type ChangePasswordRequest struct {
	// CurrentPassword is checked against the stored hash.
	CurrentPassword string `json:"current_password" binding:"required"`
	// NewPassword is the password to store.
	NewPassword string `json:"new_password" binding:"required"`
}
|
||||
|
||||
func (s *Server) handleChangePassword(c *gin.Context) {
|
||||
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
session, _, err := s.sessions.ValidateSession(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
|
||||
return
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
var user db.User
|
||||
err = s.database.Get(&user, "SELECT password_hash FROM users WHERE username = ?", session.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("User not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse("Current password incorrect"))
|
||||
return
|
||||
}
|
||||
|
||||
newHash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), 12)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash new password"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err = s.database.Exec("UPDATE users SET password_hash = ?, must_change_password = 0 WHERE username = ?", string(newHash), session.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to update password"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Password updated successfully"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleLogout(c *gin.Context) {
|
||||
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
if err := s.sessions.RevokeSession(token); err != nil {
|
||||
fmt.Printf("Error revoking session: %v\n", err)
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// RequestLog is one proxied LLM request as recorded by RequestLogger: it is
// broadcast to the dashboard hub and inserted into the llm_requests table.
type RequestLog struct {
	// Identification of the request.
	Timestamp time.Time `json:"timestamp"`
	ClientID  string    `json:"client_id"`
	Provider  string    `json:"provider"`
	Model     string    `json:"model"`

	// Token accounting.
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	ReasoningTokens  uint32 `json:"reasoning_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	CacheReadTokens  uint32 `json:"cache_read_tokens"`
	CacheWriteTokens uint32 `json:"cache_write_tokens"`

	// Cost is added to the client's total_cost and, when positive, subtracted
	// from the provider's credit balance by processLog.
	Cost      float64 `json:"cost"`
	HasImages bool    `json:"has_images"`

	// Outcome of the request.
	Status       string `json:"status"`
	ErrorMessage string `json:"error_message,omitempty"`
	DurationMS   int64  `json:"duration_ms"`
}
|
||||
|
||||
type RequestLogger struct {
|
||||
database *db.DB
|
||||
hub *Hub
|
||||
logChan chan RequestLog
|
||||
}
|
||||
|
||||
func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
|
||||
return &RequestLogger{
|
||||
database: database,
|
||||
hub: hub,
|
||||
logChan: make(chan RequestLog, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *RequestLogger) Start() {
|
||||
go func() {
|
||||
for entry := range l.logChan {
|
||||
l.processLog(entry)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *RequestLogger) LogRequest(entry RequestLog) {
|
||||
select {
|
||||
case l.logChan <- entry:
|
||||
default:
|
||||
log.Println("Request log channel full, dropping log entry")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *RequestLogger) processLog(entry RequestLog) {
|
||||
// Broadcast to dashboard
|
||||
l.hub.broadcast <- map[string]interface{}{
|
||||
"type": "request",
|
||||
"channel": "requests",
|
||||
"payload": entry,
|
||||
}
|
||||
|
||||
// Insert into DB
|
||||
tx, err := l.database.Begin()
|
||||
if err != nil {
|
||||
log.Printf("Failed to begin transaction for logging: %v", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Ensure client exists
|
||||
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
|
||||
entry.ClientID, entry.ClientID)
|
||||
|
||||
// Insert log
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO llm_requests
|
||||
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
|
||||
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
|
||||
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
|
||||
entry.Status, entry.ErrorMessage, entry.DurationMS)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to insert request log: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update client stats
|
||||
_, _ = tx.Exec(`
|
||||
UPDATE clients SET
|
||||
total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
`, entry.TotalTokens, entry.Cost, entry.ClientID)
|
||||
|
||||
// Update provider balance
|
||||
if entry.Cost > 0 {
|
||||
_, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
|
||||
entry.Cost, entry.Provider)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Printf("Failed to commit logging transaction: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModelGroups(c *gin.Context) {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if groups == nil {
|
||||
groups = []db.ModelGroup{}
|
||||
}
|
||||
c.JSON(http.StatusOK, groups)
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateModelGroup(c *gin.Context) {
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
group.ID, group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusCreated, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?`,
|
||||
group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModels(c *gin.Context) {
|
||||
usedOnly := c.Query("used_only") == "true"
|
||||
|
||||
// Registry provider normalized name -> Proxy-internal provider ID
|
||||
allowedRegistryProviders := map[string]string{
|
||||
"openai": "openai",
|
||||
"google": "gemini",
|
||||
"deepseek": "deepseek",
|
||||
"xai": "grok",
|
||||
"ollama": "ollama",
|
||||
}
|
||||
|
||||
// Merge registry models with DB overrides
|
||||
var dbModels []db.ModelConfig
|
||||
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
|
||||
|
||||
dbMap := make(map[string]db.ModelConfig)
|
||||
for _, m := range dbModels {
|
||||
dbMap[m.ID] = m
|
||||
}
|
||||
|
||||
// Fetch specific (model, provider) combinations that have been used
|
||||
type modelProvider struct {
|
||||
Model string `db:"model"`
|
||||
Provider string `db:"provider"`
|
||||
}
|
||||
usedPairs := make(map[string]bool)
|
||||
if usedOnly {
|
||||
var pairs []modelProvider
|
||||
err := s.database.Select(&pairs, "SELECT DISTINCT model, provider FROM llm_requests WHERE status = 'success'")
|
||||
if err == nil {
|
||||
for _, p := range pairs {
|
||||
usedPairs[fmt.Sprintf("%s:%s", p.Model, p.Provider)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result []gin.H
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
proxyProvider, allowed := allowedRegistryProviders[pID]
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
|
||||
for mID, mMeta := range pInfo.Models {
|
||||
if usedOnly && !usedPairs[fmt.Sprintf("%s:%s", mID, proxyProvider)] {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
promptCost := 0.0
|
||||
completionCost := 0.0
|
||||
var cacheReadCost *float64
|
||||
var cacheWriteCost *float64
|
||||
var mapping *string
|
||||
contextLimit := uint32(0)
|
||||
|
||||
if mMeta.Cost != nil {
|
||||
promptCost = mMeta.Cost.Input
|
||||
completionCost = mMeta.Cost.Output
|
||||
cacheReadCost = mMeta.Cost.CacheRead
|
||||
cacheWriteCost = mMeta.Cost.CacheWrite
|
||||
}
|
||||
if mMeta.Limit != nil {
|
||||
contextLimit = mMeta.Limit.Context
|
||||
}
|
||||
|
||||
// Override from DB
|
||||
if dbCfg, ok := dbMap[mID]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.PromptCostPerM != nil {
|
||||
promptCost = *dbCfg.PromptCostPerM
|
||||
}
|
||||
if dbCfg.CompletionCostPerM != nil {
|
||||
completionCost = *dbCfg.CompletionCostPerM
|
||||
}
|
||||
if dbCfg.CacheReadCostPerM != nil {
|
||||
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||
}
|
||||
if dbCfg.CacheWriteCostPerM != nil {
|
||||
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||
}
|
||||
mapping = dbCfg.Mapping
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": mID,
|
||||
"name": mMeta.Name,
|
||||
"provider": proxyProvider,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": promptCost,
|
||||
"completion_cost": completionCost,
|
||||
"cache_read_cost": cacheReadCost,
|
||||
"cache_write_cost": cacheWriteCost,
|
||||
"context_limit": contextLimit,
|
||||
"mapping": mapping,
|
||||
"tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall,
|
||||
"reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning,
|
||||
"modalities": mMeta.Modalities,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add configured Ollama models if they aren't in registry
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||
// Check if already added from registry
|
||||
exists := false
|
||||
for _, r := range result {
|
||||
if r["id"] == mID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
promptCost := 0.0
|
||||
completionCost := 0.0
|
||||
var cacheReadCost *float64
|
||||
var cacheWriteCost *float64
|
||||
var mapping *string
|
||||
contextLimit := uint32(0)
|
||||
|
||||
// Override from DB
|
||||
if dbCfg, ok := dbMap[mID]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.PromptCostPerM != nil {
|
||||
promptCost = *dbCfg.PromptCostPerM
|
||||
}
|
||||
if dbCfg.CompletionCostPerM != nil {
|
||||
completionCost = *dbCfg.CompletionCostPerM
|
||||
}
|
||||
if dbCfg.CacheReadCostPerM != nil {
|
||||
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||
}
|
||||
if dbCfg.CacheWriteCostPerM != nil {
|
||||
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||
}
|
||||
mapping = dbCfg.Mapping
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": mID,
|
||||
"name": mID,
|
||||
"provider": "ollama",
|
||||
"enabled": enabled,
|
||||
"prompt_cost": promptCost,
|
||||
"completion_cost": completionCost,
|
||||
"cache_read_cost": cacheReadCost,
|
||||
"cache_write_cost": cacheWriteCost,
|
||||
"context_limit": contextLimit,
|
||||
"modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}},
|
||||
"tool_call": false,
|
||||
"reasoning": false,
|
||||
"mapping": mapping,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
PromptCost float64 `json:"prompt_cost"`
|
||||
CompletionCost float64 `json:"completion_cost"`
|
||||
CacheReadCost *float64 `json:"cache_read_cost"`
|
||||
CacheWriteCost *float64 `json:"cache_write_cost"`
|
||||
Mapping *string `json:"mapping"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
// Find provider for this model
|
||||
providerID := "unknown"
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if _, ok := pInfo.Models[id]; ok {
|
||||
providerID = pID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
cache_read_cost_per_m = excluded.cache_read_cost_per_m,
|
||||
cache_write_cost_per_m = excluded.cache_write_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"}))
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
var dbConfigs []db.ProviderConfig
|
||||
err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
|
||||
if err != nil {
|
||||
// Log error
|
||||
}
|
||||
|
||||
dbMap := make(map[string]db.ProviderConfig)
|
||||
for _, cfg := range dbConfigs {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
||||
var result []gin.H
|
||||
|
||||
for _, id := range providerIDs {
|
||||
var name string
|
||||
var enabled bool
|
||||
var baseURL string
|
||||
|
||||
switch id {
|
||||
case "openai":
|
||||
name = "OpenAI"
|
||||
enabled = s.cfg.Providers.OpenAI.Enabled
|
||||
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
||||
case "gemini":
|
||||
name = "Google Gemini"
|
||||
enabled = s.cfg.Providers.Gemini.Enabled
|
||||
baseURL = s.cfg.Providers.Gemini.BaseURL
|
||||
case "deepseek":
|
||||
name = "DeepSeek"
|
||||
enabled = s.cfg.Providers.DeepSeek.Enabled
|
||||
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
||||
case "moonshot":
|
||||
name = "Moonshot"
|
||||
enabled = s.cfg.Providers.Moonshot.Enabled
|
||||
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
||||
case "grok":
|
||||
name = "xAI Grok"
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
case "ollama":
|
||||
name = "Ollama"
|
||||
enabled = s.cfg.Providers.Ollama.Enabled
|
||||
baseURL = s.cfg.Providers.Ollama.BaseURL
|
||||
}
|
||||
|
||||
var balance float64
|
||||
var threshold float64 = 5.0
|
||||
var billingMode string
|
||||
|
||||
if dbCfg, ok := dbMap[id]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.BaseURL != nil {
|
||||
baseURL = *dbCfg.BaseURL
|
||||
}
|
||||
balance = dbCfg.CreditBalance
|
||||
threshold = dbCfg.LowCreditThreshold
|
||||
if dbCfg.BillingMode != nil {
|
||||
billingMode = *dbCfg.BillingMode
|
||||
}
|
||||
}
|
||||
|
||||
status := "disabled"
|
||||
if enabled {
|
||||
if _, ok := s.providers[id]; ok {
|
||||
status = "online"
|
||||
} else {
|
||||
status = "error"
|
||||
}
|
||||
}
|
||||
|
||||
// Get last used for this provider
|
||||
var lastUsedStr string
|
||||
_ = s.database.Get(&lastUsedStr, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id)
|
||||
var lastUsed interface{}
|
||||
if lastUsedStr != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
|
||||
lastUsed = t
|
||||
}
|
||||
}
|
||||
|
||||
// Get models for this provider from registry
|
||||
var models []string
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
registryID := id
|
||||
if id == "gemini" {
|
||||
registryID = "google"
|
||||
}
|
||||
if id == "moonshot" {
|
||||
registryID = "moonshot"
|
||||
}
|
||||
if id == "grok" {
|
||||
registryID = "xai"
|
||||
}
|
||||
|
||||
if pInfo, ok := s.registry.Providers[registryID]; ok {
|
||||
for mID := range pInfo.Models {
|
||||
models = append(models, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// If it's ollama, also include models from config
|
||||
if id == "ollama" {
|
||||
models = append(models, s.cfg.Providers.Ollama.Models...)
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"base_url": baseURL,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"billing_mode": billingMode,
|
||||
"last_used": lastUsed,
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||
}
|
||||
|
||||
// UpdateProviderRequest is the JSON body accepted by handleUpdateProvider.
// Pointer fields that are omitted (nil) leave the stored column unchanged on
// update, because the upsert wraps them in COALESCE; Enabled is always written.
type UpdateProviderRequest struct {
	Enabled bool `json:"enabled"`

	// Connection settings.
	BaseURL *string `json:"base_url"`
	APIKey  *string `json:"api_key"`

	// Billing settings.
	CreditBalance      *float64 `json:"credit_balance"`
	LowCreditThreshold *float64 `json:"low_credit_threshold"`
	BillingMode        *string  `json:"billing_mode"`
}
|
||||
|
||||
func (s *Server) handleUpdateProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
var req UpdateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyEncrypted := false
|
||||
var apiKey *string = req.APIKey
|
||||
if req.APIKey != nil && *req.APIKey != "" {
|
||||
encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes)
|
||||
if err == nil {
|
||||
apiKey = &encrypted
|
||||
apiKeyEncrypted = true
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
api_key_encrypted = excluded.api_key_encrypted,
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh in-memory providers
|
||||
if err := s.RefreshProviders(); err != nil {
|
||||
fmt.Printf("Error refreshing providers: %v\n", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
provider, ok := s.providers[name]
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name)))
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Prepare a simple test request
|
||||
testReq := &models.UnifiedRequest{
|
||||
Model: "gpt-4o-mini", // Default cheap test model
|
||||
Messages: []models.UnifiedMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}},
|
||||
},
|
||||
},
|
||||
MaxTokens: new(uint32),
|
||||
}
|
||||
*testReq.MaxTokens = 5
|
||||
|
||||
// Adjust model for non-openai providers
|
||||
if name == "gemini" {
|
||||
testReq.Model = "gemini-2.0-flash"
|
||||
} else if name == "deepseek" {
|
||||
testReq.Model = "deepseek-chat"
|
||||
} else if name == "moonshot" {
|
||||
testReq.Model = "kimi-k2.5"
|
||||
} else if name == "grok" {
|
||||
testReq.Model = "grok-4-1-fast-non-reasoning"
|
||||
}
|
||||
|
||||
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"message": "Connection test successful",
|
||||
"latency": latency,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,882 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/middleware"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/providers"
|
||||
"gophergate/internal/router"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
modelRouter *router.Router
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
router := gin.Default()
|
||||
hub := NewHub()
|
||||
|
||||
s := &Server{
|
||||
router: router,
|
||||
cfg: cfg,
|
||||
database: database,
|
||||
providers: make(map[string]providers.Provider),
|
||||
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
||||
hub: hub,
|
||||
logger: NewRequestLogger(database, hub),
|
||||
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
||||
}
|
||||
|
||||
s.sessions.StartCleanup()
|
||||
// Fetch registry in background
|
||||
go func() {
|
||||
registry, err := utils.FetchRegistry()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
||||
} else {
|
||||
s.registry = registry
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize providers from DB and Config
|
||||
if err := s.RefreshProviders(); err != nil {
|
||||
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) RefreshProviders() error {
|
||||
var dbConfigs []db.ProviderConfig
|
||||
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
||||
}
|
||||
|
||||
dbMap := make(map[string]db.ProviderConfig)
|
||||
for _, cfg := range dbConfigs {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
||||
for _, id := range providerIDs {
|
||||
// Default values from config
|
||||
enabled := false
|
||||
baseURL := ""
|
||||
apiKey := ""
|
||||
|
||||
switch id {
|
||||
case "openai":
|
||||
enabled = s.cfg.Providers.OpenAI.Enabled
|
||||
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("openai")
|
||||
case "gemini":
|
||||
enabled = s.cfg.Providers.Gemini.Enabled
|
||||
baseURL = s.cfg.Providers.Gemini.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
||||
case "deepseek":
|
||||
enabled = s.cfg.Providers.DeepSeek.Enabled
|
||||
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
||||
case "moonshot":
|
||||
enabled = s.cfg.Providers.Moonshot.Enabled
|
||||
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("moonshot")
|
||||
case "grok":
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("grok")
|
||||
}
|
||||
|
||||
// Overrides from DB
|
||||
if dbCfg, ok := dbMap[id]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
||||
baseURL = *dbCfg.BaseURL
|
||||
}
|
||||
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
||||
key := *dbCfg.APIKey
|
||||
if dbCfg.APIKeyEncrypted {
|
||||
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
||||
if err == nil {
|
||||
key = decrypted
|
||||
} else {
|
||||
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
||||
}
|
||||
}
|
||||
apiKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
delete(s.providers, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Initialize provider
|
||||
var p providers.Provider
|
||||
switch id {
|
||||
case "openai":
|
||||
cfg := s.cfg.Providers.OpenAI
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewOpenAIProvider(cfg, apiKey)
|
||||
case "gemini":
|
||||
cfg := s.cfg.Providers.Gemini
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewGeminiProvider(cfg, apiKey)
|
||||
case "deepseek":
|
||||
cfg := s.cfg.Providers.DeepSeek
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
||||
case "moonshot":
|
||||
cfg := s.cfg.Providers.Moonshot
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewMoonshotProvider(cfg, apiKey)
|
||||
case "grok":
|
||||
cfg := s.cfg.Providers.Grok
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewGrokProvider(cfg, apiKey)
|
||||
case "ollama":
|
||||
cfg := s.cfg.Providers.Ollama
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewOllamaProvider(cfg)
|
||||
}
|
||||
|
||||
if p != nil {
|
||||
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
||||
}
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
var classifyFn router.ClassifierFunc
|
||||
if openaiProvider, ok := s.providers["openai"]; ok {
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := openaiProvider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
content, ok := resp.Choices[0].Message.Content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("classifier response content is not a string")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.modelRouter == nil {
|
||||
s.modelRouter = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.modelRouter.Reload(groups)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
// Static files
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
||||
s.router.Static("/css", "./static/css")
|
||||
s.router.Static("/js", "./static/js")
|
||||
s.router.Static("/img", "./static/img")
|
||||
|
||||
// WebSocket
|
||||
s.router.GET("/ws", s.handleWebSocket)
|
||||
|
||||
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||
v1 := s.router.Group("/v1")
|
||||
v1.Use(middleware.AuthMiddleware(s.database, true))
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
v1.POST("/images/generations", s.handleImageGenerations)
|
||||
v1.GET("/models", s.handleListModels)
|
||||
v1.POST("/responses", s.handleResponses)
|
||||
}
|
||||
|
||||
// Dashboard API Group
|
||||
api := s.router.Group("/api")
|
||||
{
|
||||
api.POST("/auth/login", s.handleLogin)
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
api.POST("/auth/change-password", s.handleChangePassword)
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
{
|
||||
admin.GET("/usage/summary", s.handleUsageSummary)
|
||||
admin.GET("/usage/time-series", s.handleTimeSeries)
|
||||
admin.GET("/usage/providers", s.handleProvidersUsage)
|
||||
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.GET("/clients/:id", s.handleGetClient)
|
||||
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
admin.GET("/model-groups", s.handleGetModelGroups)
|
||||
admin.POST("/model-groups", s.handleCreateModelGroup)
|
||||
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
||||
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||
admin.DELETE("/users/:id", s.handleDeleteUser)
|
||||
|
||||
admin.GET("/system/health", s.handleSystemHealth)
|
||||
admin.GET("/system/metrics", s.handleSystemMetrics)
|
||||
admin.GET("/system/settings", s.handleGetSettings)
|
||||
admin.POST("/system/backup", s.handleCreateBackup)
|
||||
admin.GET("/system/logs", s.handleGetLogs)
|
||||
}
|
||||
}
|
||||
|
||||
s.router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleResponses(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
providerName := "openai" // default for Responses API
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
providerName = "deepseek"
|
||||
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes from model name
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Use the stripped model name for the actual API call
|
||||
req.Model = modelID
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
}
|
||||
|
||||
stream := req.Stream != nil && *req.Stream
|
||||
|
||||
if stream {
|
||||
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var lastUsage *models.ResponsesUsage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if lastUsage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
||||
}
|
||||
return false
|
||||
}
|
||||
// Capture usage from the response payload in streaming chunks
|
||||
if chunk.Response != nil && chunk.Response.Usage != nil {
|
||||
lastUsage = chunk.Response.Usage
|
||||
}
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := provider.Responses(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Usage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleListModels(c *gin.Context) {
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
modelMap := make(map[string]OpenAIModel)
|
||||
allowedProviders := map[string]bool{
|
||||
"openai": true,
|
||||
"google": true, // Models from models.dev use 'google' ID for Gemini
|
||||
"deepseek": true,
|
||||
"moonshot": true,
|
||||
"moonshotai": true, // Official moonshotai ID in models.dev
|
||||
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
||||
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||
"llmgateway": true, // Catch-all for newer models
|
||||
"ollama": true,
|
||||
}
|
||||
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if !allowedProviders[pID] {
|
||||
continue
|
||||
}
|
||||
for mID := range pInfo.Models {
|
||||
if _, exists := modelMap[mID]; !exists {
|
||||
modelMap[mID] = OpenAIModel{
|
||||
ID: mID,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: pID,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// Add configured Ollama models
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||
if _, exists := modelMap[mID]; !exists {
|
||||
modelMap[mID] = OpenAIModel{
|
||||
ID: mID,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "ollama",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var data []OpenAIModel
|
||||
for _, m := range modelMap {
|
||||
data = append(data, m)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
providerName := "openai" // default
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
// Only use deepseek provider if it's not explicitly tagged for ollama
|
||||
providerName = "deepseek"
|
||||
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a group and route to a concrete model
|
||||
if s.modelRouter != nil && s.modelRouter.IsGroup(modelID) {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.modelRouter.Route(c.Request.Context(), modelID, userMessage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s -> %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
|
||||
// Convert ChatCompletionRequest to UnifiedRequest
|
||||
unifiedReq := &models.UnifiedRequest{
|
||||
Model: modelID,
|
||||
Messages: []models.UnifiedMessage{},
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
N: req.N,
|
||||
MaxTokens: req.MaxTokens,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
Stream: req.Stream != nil && *req.Stream,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
}
|
||||
|
||||
// Inject max_tokens from model registry when client doesn't specify one.
|
||||
// Prevents providers from applying a low default output cap.
|
||||
// DEBUG: Trace max_tokens through the proxy
|
||||
clientMaxTokens := "nil"
|
||||
if unifiedReq.MaxTokens != nil {
|
||||
clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
|
||||
}
|
||||
log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
s.registryMu.RLock()
|
||||
meta := s.registry.FindModel(modelID)
|
||||
s.registryMu.RUnlock()
|
||||
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
||||
unifiedReq.MaxTokens = &meta.Limit.Output
|
||||
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
|
||||
}
|
||||
|
||||
// Handle Stop sequences
|
||||
if req.Stop != nil {
|
||||
var stop []string
|
||||
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
||||
unifiedReq.Stop = stop
|
||||
} else {
|
||||
var singleStop string
|
||||
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
||||
unifiedReq.Stop = []string{singleStop}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for _, msg := range req.Messages {
|
||||
unifiedMsg := models.UnifiedMessage{
|
||||
Role: msg.Role,
|
||||
Content: []models.UnifiedContentPart{},
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
}
|
||||
|
||||
// Handle multimodal content
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "text",
|
||||
Text: strContent,
|
||||
})
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
||||
url, _ := imgURLMap["url"].(string)
|
||||
imageInput := &models.ImageInput{}
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
mime, data, err := utils.ParseDataURL(url)
|
||||
if err == nil {
|
||||
imageInput.Base64 = data
|
||||
imageInput.MimeType = mime
|
||||
}
|
||||
} else {
|
||||
imageInput.URL = url
|
||||
}
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "image",
|
||||
Image: imageInput,
|
||||
})
|
||||
unifiedReq.HasImages = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
||||
}
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
unifiedReq.ClientID = authInfo.ClientID
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
} else {
|
||||
unifiedReq.ClientID = clientID
|
||||
}
|
||||
|
||||
if unifiedReq.Stream {
|
||||
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var lastUsage *models.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
||||
return false
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
lastUsage = chunk.Usage
|
||||
}
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func extractUserMessage(messages []models.ChatMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
switch c := messages[i].Content.(type) {
|
||||
case string:
|
||||
return c
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ImageGenerationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Determine provider based on model name
|
||||
providerName := "openai"
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
switch {
|
||||
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
||||
providerName = "gemini"
|
||||
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
||||
providerName = "openai"
|
||||
}
|
||||
|
||||
// Default model for each provider if not specified
|
||||
if req.Model == "" {
|
||||
if providerName == "openai" {
|
||||
req.Model = "dall-e-3"
|
||||
} else {
|
||||
req.Model = "imagen-3.0-generate-001"
|
||||
}
|
||||
}
|
||||
|
||||
// Strip common prefixes
|
||||
prefixes := []string{"openai/", "gemini/", "google/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(req.Model, p) {
|
||||
req.Model = strings.TrimPrefix(req.Model, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Estimate tokens from prompt text (~4 chars per token)
|
||||
promptTokens := uint32(len(req.Prompt) / 4)
|
||||
if promptTokens < 1 {
|
||||
promptTokens = 1
|
||||
}
|
||||
|
||||
// Calculate per-image cost (not per-token like chat)
|
||||
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: uint32(len(resp.Data)),
|
||||
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
||||
}, nil, false)
|
||||
|
||||
// Update cost in DB — image gen is per-image, not per-token
|
||||
if cost > 0 {
|
||||
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// imageGenCost prices an image generation request per image.
//
// The model name is matched case-insensitively by substring:
//   - "dall-e-3": 0.040 per image, or 0.080 when size is "1024x1792" or "1792x1024"
//   - "dall-e-2": 0.020 per image
//   - "imagen":   0.040 per image (approximate)
//
// It returns 0 when n is 0 or the model is not recognised. The provider
// argument is not consulted.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	name := strings.ToLower(model)
	count := float64(n)

	if strings.Contains(name, "dall-e-3") {
		// The two non-square sizes are billed at double the standard rate.
		if size != nil && (*size == "1024x1792" || *size == "1792x1024") {
			return 0.080 * count
		}
		return 0.040 * count // standard 1024x1024
	}
	if strings.Contains(name, "dall-e-2") {
		return 0.020 * count
	}
	if strings.Contains(name, "imagen") {
		return 0.040 * count // approximate
	}
	return 0
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||
entry := RequestLog{
|
||||
Timestamp: start,
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: "success",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
HasImages: hasImages,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.Status = "error"
|
||||
entry.ErrorMessage = err.Error()
|
||||
}
|
||||
|
||||
if usage != nil {
|
||||
entry.PromptTokens = usage.PromptTokens
|
||||
entry.CompletionTokens = usage.CompletionTokens
|
||||
entry.TotalTokens = usage.TotalTokens
|
||||
if usage.ReasoningTokens != nil {
|
||||
entry.ReasoningTokens = *usage.ReasoningTokens
|
||||
}
|
||||
if usage.CacheReadTokens != nil {
|
||||
entry.CacheReadTokens = *usage.CacheReadTokens
|
||||
}
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
// Calculate cost using registry
|
||||
s.registryMu.RLock()
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
|
||||
// Start registry refresher
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
for range ticker.C {
|
||||
newRegistry, err := utils.FetchRegistry()
|
||||
if err == nil {
|
||||
s.registryMu.Lock()
|
||||
s.registry = newRegistry
|
||||
s.registryMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
|
||||
// uint32Ptr returns a pointer to a copy of v, for filling optional uint32 fields.
func uint32Ptr(v uint32) *uint32 {
	value := v
	return &value
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session is the server-side record of one logged-in user, kept in the
// SessionManager's in-memory map under its SessionID.
type Session struct {
	// Identity copied from the login that created the session.
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
	Role        string `json:"role"`

	// Lifetime: ExpiresAt is CreatedAt plus the manager's TTL.
	CreatedAt time.Time `json:"created_at"`
	ExpiresAt time.Time `json:"expires_at"`

	// SessionID is the map key under which this session is stored.
	SessionID string `json:"session_id"`
}
|
||||
|
||||
type SessionManager struct {
|
||||
sessions map[string]Session
|
||||
mu sync.RWMutex
|
||||
secret []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// sessionPayload is the JSON document embedded in a session token. Its
// marshalled bytes are what the token signature covers.
type sessionPayload struct {
	SessionID   string `json:"session_id"`
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
	Role        string `json:"role"`
	Exp         int64  `json:"exp"` // expiry as Unix seconds
}
|
||||
|
||||
func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[string]Session),
|
||||
secret: secret,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) CreateSession(username, displayName, role string) (string, error) {
|
||||
sessionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(m.ttl)
|
||||
|
||||
m.mu.Lock()
|
||||
m.sessions[sessionID] = Session{
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Role: role,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: expiresAt,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return m.createSignedToken(sessionID, username, displayName, role, expiresAt.Unix())
|
||||
}
|
||||
|
||||
func (m *SessionManager) createSignedToken(sessionID, username, displayName, role string, exp int64) (string, error) {
|
||||
payload := sessionPayload{
|
||||
SessionID: sessionID,
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Role: role,
|
||||
Exp: exp,
|
||||
}
|
||||
|
||||
payloadJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
h := hmac.New(sha256.New, m.secret)
|
||||
h.Write(payloadJSON)
|
||||
signature := h.Sum(nil)
|
||||
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) ValidateSession(token string) (*Session, string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return nil, "", fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payloadB64 := parts[0]
|
||||
signatureB64 := parts[1]
|
||||
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
signature, err := base64.RawURLEncoding.DecodeString(signatureB64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
h := hmac.New(sha256.New, m.secret)
|
||||
h.Write(payloadJSON)
|
||||
if !hmac.Equal(signature, h.Sum(nil)) {
|
||||
return nil, "", fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
var payload sessionPayload
|
||||
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if time.Now().Unix() > payload.Exp {
|
||||
return nil, "", fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
session, ok := m.sessions[payload.SessionID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
return &session, "", nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) RevokeSession(token string) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode payload: %w", err)
|
||||
}
|
||||
|
||||
var payload sessionPayload
|
||||
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||
return fmt.Errorf("failed to parse payload: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, payload.SessionID)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
|
||||
func (m *SessionManager) StartCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
for range ticker.C {
|
||||
m.mu.Lock()
|
||||
now := time.Now()
|
||||
for id, s := range m.sessions {
|
||||
if now.After(s.ExpiresAt) {
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
func (s *Server) handleSystemHealth(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"status": "ok",
|
||||
"components": gin.H{
|
||||
"database": "online",
|
||||
"proxy": "online",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleSystemMetrics(c *gin.Context) {
|
||||
v, _ := mem.VirtualMemory()
|
||||
c_usage, _ := cpu.Percent(time.Second, false)
|
||||
d, _ := disk.Usage("/")
|
||||
l, _ := load.Avg()
|
||||
p, _ := process.NewProcess(int32(os.Getpid()))
|
||||
rss, _ := p.MemoryInfo()
|
||||
|
||||
cpuPercent := 0.0
|
||||
if len(c_usage) > 0 {
|
||||
cpuPercent = c_usage[0]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"cpu": gin.H{
|
||||
"usage_percent": fmt.Sprintf("%.1f", cpuPercent),
|
||||
"load_average": []float64{l.Load1, l.Load5, l.Load15},
|
||||
},
|
||||
"memory": gin.H{
|
||||
"used_mb": v.Used / 1024 / 1024,
|
||||
"total_mb": v.Total / 1024 / 1024,
|
||||
"usage_percent": fmt.Sprintf("%.1f", v.UsedPercent),
|
||||
"process_rss_mb": rss.RSS / 1024 / 1024,
|
||||
},
|
||||
"disk": gin.H{
|
||||
"used_gb": float64(d.Used) / 1024 / 1024 / 1024,
|
||||
"total_gb": float64(d.Total) / 1024 / 1024 / 1024,
|
||||
"usage_percent": fmt.Sprintf("%.1f", d.UsedPercent),
|
||||
},
|
||||
"connections": gin.H{
|
||||
"db_active": s.database.Stats().OpenConnections,
|
||||
"websocket_listeners": s.hub.GetClientCount(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetSettings(c *gin.Context) {
|
||||
providerCount := 0
|
||||
modelCount := 0
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
providerCount = len(s.registry.Providers)
|
||||
for _, p := range s.registry.Providers {
|
||||
modelCount += len(p.Models)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"server": gin.H{
|
||||
"version": "1.0.0-go",
|
||||
"auth_tokens": s.cfg.Server.AuthTokens,
|
||||
},
|
||||
"database": gin.H{
|
||||
"type": "sqlite",
|
||||
"path": s.cfg.Database.Path,
|
||||
},
|
||||
"registry": gin.H{
|
||||
"provider_count": providerCount,
|
||||
"model_count": modelCount,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateBackup(c *gin.Context) {
|
||||
// Simplified backup response
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"backup_id": fmt.Sprintf("backup-%d.db", time.Now().Unix()),
|
||||
"status": "created",
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetLogs(c *gin.Context) {
|
||||
var logs []db.LLMRequest
|
||||
err := s.database.Select(&logs, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Format for UI
|
||||
type UILog struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
ClientID string `json:"client_id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Tokens int `json:"tokens"`
|
||||
Status string `json:"status"`
|
||||
Duration int `json:"duration"`
|
||||
}
|
||||
|
||||
uiLogs := make([]UILog, len(logs))
|
||||
for i, l := range logs {
|
||||
clientID := "unknown"
|
||||
if l.ClientID != nil {
|
||||
clientID = *l.ClientID
|
||||
}
|
||||
provider := "unknown"
|
||||
if l.Provider != nil {
|
||||
provider = *l.Provider
|
||||
}
|
||||
model := "unknown"
|
||||
if l.Model != nil {
|
||||
model = *l.Model
|
||||
}
|
||||
tokens := 0
|
||||
if l.TotalTokens != nil {
|
||||
tokens = *l.TotalTokens
|
||||
}
|
||||
duration := 0
|
||||
if l.DurationMS != nil {
|
||||
duration = *l.DurationMS
|
||||
}
|
||||
|
||||
uiLogs[i] = UILog{
|
||||
Timestamp: l.Timestamp.Format(time.RFC3339),
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Tokens: tokens,
|
||||
Status: l.Status,
|
||||
Duration: duration,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(uiLogs))
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetUsers(c *gin.Context) {
|
||||
var users []db.User
|
||||
err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(users))
|
||||
}
|
||||
|
||||
// CreateUserRequest is the JSON body accepted by handleCreateUser.
type CreateUserRequest struct {
	// Required credentials.
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"`

	// Optional attributes; nil means the field was not supplied.
	DisplayName *string `json:"display_name"`
	Role        *string `json:"role"`
}
|
||||
|
||||
func (s *Server) handleCreateUser(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
|
||||
return
|
||||
}
|
||||
|
||||
role := "viewer"
|
||||
if req.Role != nil {
|
||||
role = *req.Role
|
||||
}
|
||||
|
||||
_, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
|
||||
req.Username, string(hash), req.DisplayName, role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
|
||||
}
|
||||
|
||||
// UpdateUserRequest is the JSON body accepted by handleUpdateUser. Every
// field is optional; only fields that are present (non-nil) are written.
type UpdateUserRequest struct {
	DisplayName        *string `json:"display_name"`
	Role               *string `json:"role"`
	Password           *string `json:"password"` // plain text; hashed before storage
	MustChangePassword *bool   `json:"must_change_password"`
}
|
||||
|
||||
func (s *Server) handleUpdateUser(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.DisplayName != nil {
|
||||
s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
|
||||
}
|
||||
if req.Role != nil {
|
||||
s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
|
||||
}
|
||||
if req.MustChangePassword != nil {
|
||||
s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
|
||||
}
|
||||
if req.Password != nil {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
|
||||
s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteUser(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
session, _ := c.Get("session")
|
||||
if sess, ok := session.(*Session); ok {
|
||||
var username string
|
||||
s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id)
|
||||
if username == sess.Username {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM users WHERE id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func newUpgrader(allowedOrigin string) websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
return origin == "" || origin == allowedOrigin
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
clients map[*websocket.Conn]bool
|
||||
broadcast chan interface{}
|
||||
register chan *websocket.Conn
|
||||
unregister chan *websocket.Conn
|
||||
mu sync.Mutex
|
||||
clientCount int32
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcast: make(chan interface{}),
|
||||
register: make(chan *websocket.Conn),
|
||||
unregister: make(chan *websocket.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Run() {
|
||||
for {
|
||||
select {
|
||||
case client := <-h.register:
|
||||
h.mu.Lock()
|
||||
h.clients[client] = true
|
||||
atomic.AddInt32(&h.clientCount, 1)
|
||||
h.mu.Unlock()
|
||||
log.Println("WebSocket client registered")
|
||||
case client := <-h.unregister:
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[client]; ok {
|
||||
delete(h.clients, client)
|
||||
client.Close()
|
||||
atomic.AddInt32(&h.clientCount, -1)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
log.Println("WebSocket client unregistered")
|
||||
case message := <-h.broadcast:
|
||||
h.mu.Lock()
|
||||
for client := range h.clients {
|
||||
err := client.WriteJSON(message)
|
||||
if err != nil {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
client.Close()
|
||||
delete(h.clients, client)
|
||||
atomic.AddInt32(&h.clientCount, -1)
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) GetClientCount() int {
|
||||
return int(atomic.LoadInt32(&h.clientCount))
|
||||
}
|
||||
|
||||
func (s *Server) handleWebSocket(c *gin.Context) {
|
||||
allowedOrigin := s.cfg.Server.WSAllowedOrigin
|
||||
if allowedOrigin == "" {
|
||||
allowedOrigin = "*"
|
||||
}
|
||||
upgrader := newUpgrader(allowedOrigin)
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to set websocket upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.hub.register <- conn
|
||||
|
||||
defer func() {
|
||||
s.hub.unregister <- conn
|
||||
}()
|
||||
|
||||
// Initial message
|
||||
conn.WriteJSON(gin.H{
|
||||
"type": "connected",
|
||||
"message": "Connected to GopherGate Dashboard",
|
||||
})
|
||||
|
||||
for {
|
||||
var msg map[string]interface{}
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if msg["type"] == "ping" {
|
||||
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Encrypt encrypts plain text using AES-GCM with the given 32-byte key.
|
||||
func Encrypt(plainText string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// The nonce should be prepended to the ciphertext
|
||||
cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key.
|
||||
func Decrypt(encodedCipherText string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||
}
|
||||
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(cipherText) < nonceSize {
|
||||
return "", fmt.Errorf("cipher text too short")
|
||||
}
|
||||
|
||||
nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||
plainText, err := gcm.Open(nil, nonce, actualCipherText, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plainText), nil
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
const ModelsDevURL = "https://models.dev/api.json"
|
||||
|
||||
func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
client := resty.New().SetTimeout(10 * time.Second)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
resp, err := client.R().Get(ModelsDevURL)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
if !resp.IsSuccess() {
|
||||
lastErr = fmt.Errorf("attempt %d: HTTP %d", attempt+1, resp.StatusCode())
|
||||
continue
|
||||
}
|
||||
|
||||
var providers map[string]models.ProviderInfo
|
||||
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||
lastErr = fmt.Errorf("attempt %d: unmarshal: %w", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Println("Successfully loaded model registry")
|
||||
return &models.ModelRegistry{Providers: providers}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// promptTokens is usually the TOTAL prompt size.
|
||||
// We subtract cacheRead from it to get the uncached part.
|
||||
uncachedTokens := promptTokens
|
||||
if cacheRead > 0 {
|
||||
if cacheRead > promptTokens {
|
||||
uncachedTokens = 0
|
||||
} else {
|
||||
uncachedTokens = promptTokens - cacheRead
|
||||
}
|
||||
}
|
||||
|
||||
cost := (float64(uncachedTokens) * meta.Cost.Input / 1000000.0) +
|
||||
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
|
||||
|
||||
if meta.Cost.CacheRead != nil {
|
||||
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
|
||||
}
|
||||
if meta.Cost.CacheWrite != nil {
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestCalculateCost_NotFound(t *testing.T) {
|
||||
r := &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
|
||||
cost := CalculateCost(r, "unknown-model", 100, 50, 0, 0, 0)
|
||||
if cost != 0.0 {
|
||||
t.Fatalf("expected 0 cost for unknown model, got %f", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCost_KnownModel(t *testing.T) {
|
||||
inputCost := 2.5 // $2.50 per 1M tokens
|
||||
outputCost := 10.0 // $10.00 per 1M tokens
|
||||
r := &models.ModelRegistry{
|
||||
Providers: map[string]models.ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]models.ModelMetadata{
|
||||
"gpt-4o": {
|
||||
Cost: &models.ModelCost{
|
||||
Input: inputCost,
|
||||
Output: outputCost,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cost := CalculateCost(r, "gpt-4o", 1000, 500, 0, 0, 0)
|
||||
expected := (1000 * inputCost / 1000000.0) + (500 * outputCost / 1000000.0)
|
||||
if cost != expected {
|
||||
t.Fatalf("expected %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseDataURL splits a base64 data URL of the form
// "data:<media-type>;base64,<payload>" into its media type and payload.
//
// The payload is returned as-is (still base64-encoded). An error is returned
// when the input lacks the "data:" prefix or does not contain the
// ";base64," marker exactly once.
func ParseDataURL(dataURL string) (string, string, error) {
	const (
		scheme = "data:"
		marker = ";base64,"
	)

	if !strings.HasPrefix(dataURL, scheme) {
		return "", "", fmt.Errorf("not a data URL")
	}

	// A limit of 3 is enough to tell "exactly one marker" (2 pieces) apart
	// from "none" (1 piece) and "more than one" (3 pieces).
	pieces := strings.SplitN(dataURL[len(scheme):], marker, 3)
	if len(pieces) != 2 {
		return "", "", fmt.Errorf("invalid data URL format")
	}

	mediaType, payload := pieces[0], pieces[1]
	return mediaType, payload, nil
}
|
||||
@@ -1,13 +0,0 @@
|
||||
-- Migration: add billing_mode to provider_configs.
--
-- Adds a billing_mode TEXT column whose default is 'prepaid'.
-- This script does not change any existing provider; to mark Gemini as
-- postpaid afterwards, run:
--   UPDATE provider_configs SET billing_mode = 'postpaid' WHERE id = 'gemini';

BEGIN TRANSACTION;

ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT DEFAULT 'prepaid';

COMMIT;

-- NOTE: against a production SQLite file the same follow-up can be run from the shell:
--   sqlite3 /path/to/db.sqlite "UPDATE provider_configs SET billing_mode='postpaid' WHERE id='gemini';"
|
||||
@@ -1,13 +0,0 @@
|
||||
-- Migration: add indexes for query performance.
--
-- Creates three indexes, each only if it does not already exist:
--   1. idx_llm_requests_client_timestamp   ON llm_requests(client_id, timestamp)
--   2. idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)
--   3. idx_model_configs_provider_id       ON model_configs(provider_id)

BEGIN TRANSACTION;

CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);

COMMIT;
|
||||
@@ -1,2 +0,0 @@
|
||||
max_width = 120
|
||||
use_field_init_shorthand = true
|
||||
-11
@@ -1,11 +0,0 @@
|
||||
[2m2026-03-06T20:07:36.737914Z[0m [32m INFO[0m Starting LLM Proxy Gateway v0.1.0
|
||||
[2m2026-03-06T20:07:36.738903Z[0m [32m INFO[0m Configuration loaded from Some("/home/newkirk/Documents/projects/web_projects/llm-proxy/config.toml")
|
||||
[2m2026-03-06T20:07:36.738945Z[0m [32m INFO[0m Encryption initialized
|
||||
[2m2026-03-06T20:07:36.739124Z[0m [32m INFO[0m Connecting to database at ./data/llm_proxy.db
|
||||
[2m2026-03-06T20:07:36.753254Z[0m [32m INFO[0m Database migrations completed
|
||||
[2m2026-03-06T20:07:36.753294Z[0m [32m INFO[0m Database initialized at "./data/llm_proxy.db"
|
||||
[2m2026-03-06T20:07:36.755187Z[0m [32m INFO[0m Fetching model registry from https://models.dev/api.json
|
||||
[2m2026-03-06T20:07:37.000853Z[0m [32m INFO[0m Successfully loaded model registry
|
||||
[2m2026-03-06T20:07:37.001382Z[0m [32m INFO[0m Model config cache initialized
|
||||
[2m2026-03-06T20:07:37.001702Z[0m [33m WARN[0m SESSION_SECRET environment variable not set. Using a randomly generated secret. This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value.
|
||||
[2m2026-03-06T20:07:37.002898Z[0m [32m INFO[0m Server listening on http://0.0.0.0:8082
|
||||
@@ -1 +0,0 @@
|
||||
945904
|
||||
@@ -1,45 +0,0 @@
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// Authentication data attached to a request.
#[derive(Debug, Clone)]
pub struct AuthInfo {
    /// The token presented by the caller.
    pub token: String,
    /// Identifier of the client the request is attributed to.
    pub client_id: String,
}
|
||||
|
||||
pub struct AuthenticatedClient {
|
||||
pub info: AuthInfo,
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthenticatedClient
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
|
||||
let info = parts
|
||||
.extensions
|
||||
.get::<AuthInfo>()
|
||||
.cloned()
|
||||
.ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
|
||||
|
||||
Ok(AuthenticatedClient { info })
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for AuthenticatedClient {
|
||||
type Target = AuthInfo;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.info
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when `token` is exactly equal to one of `valid_tokens`.
///
/// This is a simple allow-list check; in production, use proper token
/// validation (JWT, database lookup, etc.).
///
/// The candidate is compared in place (no temporary `String` is allocated).
/// Every entry is examined with a byte comparison that does not stop at the
/// first mismatch, so the running time does not depend on how many leading
/// bytes of a secret the caller guessed. This is best effort only: token
/// lengths are not hidden and the optimizer gives no hard guarantee.
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
    /// Compares two byte strings without short-circuiting on content.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let presented = token.as_bytes();
    // Non-short-circuiting `|` keeps the scan going after a match.
    valid_tokens
        .iter()
        .fold(false, |matched, known| matched | constant_time_eq(known.as_bytes(), presented))
}
|
||||
@@ -1,304 +0,0 @@
|
||||
//! Client management for LLM proxy
|
||||
//!
|
||||
//! This module handles:
|
||||
//! 1. Client registration and management
|
||||
//! 2. Client usage tracking
|
||||
//! 3. Client rate limit configuration
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Row, SqlitePool};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Client information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Client {
|
||||
pub id: i64,
|
||||
pub client_id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_active: bool,
|
||||
pub rate_limit_per_minute: i64,
|
||||
pub total_requests: i64,
|
||||
pub total_tokens: i64,
|
||||
pub total_cost: f64,
|
||||
}
|
||||
|
||||
/// Client creation request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateClientRequest {
|
||||
pub client_id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub rate_limit_per_minute: Option<i64>,
|
||||
}
|
||||
|
||||
/// Client update request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpdateClientRequest {
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub is_active: Option<bool>,
|
||||
pub rate_limit_per_minute: Option<i64>,
|
||||
}
|
||||
|
||||
/// Client manager for database operations
|
||||
pub struct ClientManager {
|
||||
db_pool: SqlitePool,
|
||||
}
|
||||
|
||||
impl ClientManager {
|
||||
pub fn new(db_pool: SqlitePool) -> Self {
|
||||
Self { db_pool }
|
||||
}
|
||||
|
||||
/// Create a new client
|
||||
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
|
||||
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
|
||||
|
||||
// First insert the client
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
|
||||
VALUES (?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&request.client_id)
|
||||
.bind(&request.name)
|
||||
.bind(&request.description)
|
||||
.bind(rate_limit)
|
||||
.execute(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
// Then fetch the created client
|
||||
let client = self
|
||||
.get_client(&request.client_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
|
||||
|
||||
info!("Created client: {} ({})", client.name, client.client_id);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Get a client by ID
|
||||
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
id, client_id, name, description,
|
||||
created_at, updated_at, is_active,
|
||||
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
||||
FROM clients
|
||||
WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(client_id)
|
||||
.fetch_optional(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
if let Some(row) = row {
|
||||
let client = Client {
|
||||
id: row.get("id"),
|
||||
client_id: row.get("client_id"),
|
||||
name: row.get("name"),
|
||||
description: row.get("description"),
|
||||
created_at: row.get("created_at"),
|
||||
updated_at: row.get("updated_at"),
|
||||
is_active: row.get("is_active"),
|
||||
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
||||
total_requests: row.get("total_requests"),
|
||||
total_tokens: row.get("total_tokens"),
|
||||
total_cost: row.get("total_cost"),
|
||||
};
|
||||
Ok(Some(client))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a client
|
||||
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
|
||||
// First, get the current client to check if it exists
|
||||
let current_client = self.get_client(client_id).await?;
|
||||
if current_client.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build update query dynamically based on provided fields
|
||||
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(name) = &request.name {
|
||||
query_builder.push("name = ");
|
||||
query_builder.push_bind(name);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(description) = &request.description {
|
||||
if has_updates {
|
||||
query_builder.push(", ");
|
||||
}
|
||||
query_builder.push("description = ");
|
||||
query_builder.push_bind(description);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(is_active) = request.is_active {
|
||||
if has_updates {
|
||||
query_builder.push(", ");
|
||||
}
|
||||
query_builder.push("is_active = ");
|
||||
query_builder.push_bind(is_active);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(rate_limit) = request.rate_limit_per_minute {
|
||||
if has_updates {
|
||||
query_builder.push(", ");
|
||||
}
|
||||
query_builder.push("rate_limit_per_minute = ");
|
||||
query_builder.push_bind(rate_limit);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
// Always update the updated_at timestamp
|
||||
if has_updates {
|
||||
query_builder.push(", ");
|
||||
}
|
||||
query_builder.push("updated_at = CURRENT_TIMESTAMP");
|
||||
|
||||
if !has_updates {
|
||||
// No updates to make
|
||||
return self.get_client(client_id).await;
|
||||
}
|
||||
|
||||
query_builder.push(" WHERE client_id = ");
|
||||
query_builder.push_bind(client_id);
|
||||
|
||||
let query = query_builder.build();
|
||||
query.execute(&self.db_pool).await?;
|
||||
|
||||
// Fetch the updated client
|
||||
let updated_client = self.get_client(client_id).await?;
|
||||
|
||||
if updated_client.is_some() {
|
||||
info!("Updated client: {}", client_id);
|
||||
}
|
||||
|
||||
Ok(updated_client)
|
||||
}
|
||||
|
||||
/// List all clients
|
||||
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
|
||||
let limit = limit.unwrap_or(100);
|
||||
let offset = offset.unwrap_or(0);
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
id, client_id, name, description,
|
||||
created_at, updated_at, is_active,
|
||||
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
||||
FROM clients
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
let mut clients = Vec::new();
|
||||
for row in rows {
|
||||
let client = Client {
|
||||
id: row.get("id"),
|
||||
client_id: row.get("client_id"),
|
||||
name: row.get("name"),
|
||||
description: row.get("description"),
|
||||
created_at: row.get("created_at"),
|
||||
updated_at: row.get("updated_at"),
|
||||
is_active: row.get("is_active"),
|
||||
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
||||
total_requests: row.get("total_requests"),
|
||||
total_tokens: row.get("total_tokens"),
|
||||
total_cost: row.get("total_cost"),
|
||||
};
|
||||
clients.push(client);
|
||||
}
|
||||
|
||||
Ok(clients)
|
||||
}
|
||||
|
||||
/// Delete a client
|
||||
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
|
||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||
.bind(client_id)
|
||||
.execute(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
let deleted = result.rows_affected() > 0;
|
||||
|
||||
if deleted {
|
||||
info!("Deleted client: {}", client_id);
|
||||
} else {
|
||||
warn!("Client not found for deletion: {}", client_id);
|
||||
}
|
||||
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
/// Update client usage statistics after a request
|
||||
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE clients
|
||||
SET
|
||||
total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(tokens)
|
||||
.bind(cost)
|
||||
.bind(client_id)
|
||||
.execute(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get client usage statistics
|
||||
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT total_requests, total_tokens, total_cost
|
||||
FROM clients
|
||||
WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(client_id)
|
||||
.fetch_optional(&self.db_pool)
|
||||
.await?;
|
||||
|
||||
if let Some(row) = row {
|
||||
let total_requests: i64 = row.get("total_requests");
|
||||
let total_tokens: i64 = row.get("total_tokens");
|
||||
let total_cost: f64 = row.get("total_cost");
|
||||
Ok(Some((total_requests, total_tokens, total_cost)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a client exists and is active
|
||||
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
|
||||
let client = self.get_client(client_id).await?;
|
||||
Ok(client.map(|c| c.is_active).unwrap_or(false))
|
||||
}
|
||||
}
|
||||
@@ -1,260 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use base64::{Engine as _};
|
||||
use config::{Config, File, FileFormat};
|
||||
use hex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub port: u16,
|
||||
pub host: String,
|
||||
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
||||
pub auth_tokens: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DatabaseConfig {
|
||||
pub path: PathBuf,
|
||||
pub max_connections: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderConfig {
|
||||
pub openai: OpenAIConfig,
|
||||
pub gemini: GeminiConfig,
|
||||
pub deepseek: DeepSeekConfig,
|
||||
pub grok: GrokConfig,
|
||||
pub ollama: OllamaConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIConfig {
|
||||
pub api_key_env: String,
|
||||
pub base_url: String,
|
||||
pub default_model: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GeminiConfig {
|
||||
pub api_key_env: String,
|
||||
pub base_url: String,
|
||||
pub default_model: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeepSeekConfig {
|
||||
pub api_key_env: String,
|
||||
pub base_url: String,
|
||||
pub default_model: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GrokConfig {
|
||||
pub api_key_env: String,
|
||||
pub base_url: String,
|
||||
pub default_model: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OllamaConfig {
|
||||
pub base_url: String,
|
||||
pub enabled: bool,
|
||||
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
||||
pub models: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelMappingConfig {
|
||||
pub patterns: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PricingConfig {
|
||||
pub openai: Vec<ModelPricing>,
|
||||
pub gemini: Vec<ModelPricing>,
|
||||
pub deepseek: Vec<ModelPricing>,
|
||||
pub grok: Vec<ModelPricing>,
|
||||
pub ollama: Vec<ModelPricing>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelPricing {
|
||||
pub model: String,
|
||||
pub prompt_tokens_per_million: f64,
|
||||
pub completion_tokens_per_million: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub providers: ProviderConfig,
|
||||
pub model_mapping: ModelMappingConfig,
|
||||
pub pricing: PricingConfig,
|
||||
pub config_path: Option<PathBuf>,
|
||||
pub encryption_key: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub async fn load() -> Result<Arc<Self>> {
|
||||
Self::load_from_path(None).await
|
||||
}
|
||||
|
||||
/// Load configuration from a specific path (for testing)
|
||||
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
|
||||
// Load configuration from multiple sources
|
||||
let mut config_builder = Config::builder();
|
||||
|
||||
// Default configuration
|
||||
config_builder = config_builder
|
||||
.set_default("server.port", 8080)?
|
||||
.set_default("server.host", "0.0.0.0")?
|
||||
.set_default("server.auth_tokens", Vec::<String>::new())?
|
||||
.set_default("database.path", "./data/llm_proxy.db")?
|
||||
.set_default("database.max_connections", 10)?
|
||||
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
|
||||
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
|
||||
.set_default("providers.openai.default_model", "gpt-4o")?
|
||||
.set_default("providers.openai.enabled", true)?
|
||||
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
|
||||
.set_default(
|
||||
"providers.gemini.base_url",
|
||||
"https://generativelanguage.googleapis.com/v1",
|
||||
)?
|
||||
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
|
||||
.set_default("providers.gemini.enabled", true)?
|
||||
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
|
||||
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
|
||||
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
|
||||
.set_default("providers.deepseek.enabled", true)?
|
||||
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
|
||||
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
|
||||
.set_default("providers.grok.default_model", "grok-beta")?
|
||||
.set_default("providers.grok.enabled", true)?
|
||||
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
|
||||
.set_default("providers.ollama.enabled", false)?
|
||||
.set_default("providers.ollama.models", Vec::<String>::new())?
|
||||
.set_default("encryption_key", "")?;
|
||||
|
||||
// Load from config file if exists
|
||||
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
|
||||
let config_path = config_path
|
||||
.or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
|
||||
.unwrap_or_else(|| {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| PathBuf::from("."))
|
||||
.join("config.toml")
|
||||
});
|
||||
if config_path.exists() {
|
||||
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
|
||||
}
|
||||
|
||||
// Load from .env file
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// Load from environment variables (with prefix "LLM_PROXY_")
|
||||
config_builder = config_builder.add_source(
|
||||
config::Environment::with_prefix("LLM_PROXY")
|
||||
.separator("__")
|
||||
.try_parsing(true),
|
||||
);
|
||||
|
||||
let config = config_builder.build()?;
|
||||
|
||||
// Deserialize configuration
|
||||
let server: ServerConfig = config.get("server")?;
|
||||
let database: DatabaseConfig = config.get("database")?;
|
||||
let providers: ProviderConfig = config.get("providers")?;
|
||||
let encryption_key: String = config.get("encryption_key")?;
|
||||
|
||||
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
|
||||
if encryption_key.is_empty() {
|
||||
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
|
||||
}
|
||||
// Try hex decode first, then base64
|
||||
let key_bytes = hex::decode(&encryption_key)
|
||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
|
||||
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
|
||||
if key_bytes.len() != 32 {
|
||||
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
|
||||
}
|
||||
|
||||
// For now, use empty model mapping and pricing (will be populated later)
|
||||
let model_mapping = ModelMappingConfig { patterns: vec![] };
|
||||
let pricing = PricingConfig {
|
||||
openai: vec![],
|
||||
gemini: vec![],
|
||||
deepseek: vec![],
|
||||
grok: vec![],
|
||||
ollama: vec![],
|
||||
};
|
||||
|
||||
Ok(Arc::new(AppConfig {
|
||||
server,
|
||||
database,
|
||||
providers,
|
||||
model_mapping,
|
||||
pricing,
|
||||
config_path: Some(config_path),
|
||||
encryption_key,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn get_api_key(&self, provider: &str) -> Result<String> {
|
||||
let env_var = match provider {
|
||||
"openai" => &self.providers.openai.api_key_env,
|
||||
"gemini" => &self.providers.gemini.api_key_env,
|
||||
"deepseek" => &self.providers.deepseek.api_key_env,
|
||||
"grok" => &self.providers.grok.api_key_env,
|
||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
|
||||
};
|
||||
|
||||
std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
|
||||
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct VecOrString;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for VecOrString {
|
||||
type Value = Vec<String>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a sequence or a comma-separated string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(value
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
|
||||
where
|
||||
S: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
while let Some(element) = seq.next_element()? {
|
||||
vec.push(element);
|
||||
}
|
||||
Ok(vec)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(VecOrString)
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
|
||||
use bcrypt;
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
// Authentication handlers
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct LoginRequest {
|
||||
pub(super) username: String,
|
||||
pub(super) password: String,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_login(
|
||||
State(state): State<DashboardState>,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let user_result = sqlx::query(
|
||||
"SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?",
|
||||
)
|
||||
.bind(&payload.username)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
match user_result {
|
||||
Ok(Some(row)) => {
|
||||
let hash = row.get::<String, _>("password_hash");
|
||||
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
|
||||
let username = row.get::<String, _>("username");
|
||||
let role = row.get::<String, _>("role");
|
||||
let display_name = row
|
||||
.get::<Option<String>, _>("display_name")
|
||||
.unwrap_or_else(|| username.clone());
|
||||
let must_change_password = row.get::<bool, _>("must_change_password");
|
||||
let token = state
|
||||
.session_manager
|
||||
.create_session(username.clone(), role.clone())
|
||||
.await;
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"token": token,
|
||||
"must_change_password": must_change_password,
|
||||
"user": {
|
||||
"username": username,
|
||||
"name": display_name,
|
||||
"role": role
|
||||
}
|
||||
})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Invalid username or password".to_string()))
|
||||
}
|
||||
}
|
||||
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
|
||||
Err(e) => {
|
||||
warn!("Database error during login: {}", e);
|
||||
Json(ApiResponse::error("Login failed due to system error".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token
|
||||
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
|
||||
{
|
||||
// Look up display_name from DB
|
||||
let display_name = sqlx::query_scalar::<_, Option<String>>(
|
||||
"SELECT display_name FROM users WHERE username = ?",
|
||||
)
|
||||
.bind(&session.username)
|
||||
.fetch_optional(&state.app_state.db_pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.flatten()
|
||||
.unwrap_or_else(|| session.username.clone());
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
return (headers, Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": display_name,
|
||||
"role": session.role
|
||||
}
|
||||
}))));
|
||||
}
|
||||
|
||||
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct ChangePasswordRequest {
|
||||
pub(super) current_password: String,
|
||||
pub(super) new_password: String,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_change_password(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<ChangePasswordRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Extract the authenticated user from the session token
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let (session, new_token) = match token {
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => (Some(session), new_token),
|
||||
None => (None, None),
|
||||
},
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let mut response_headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
response_headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
|
||||
let username = match session {
|
||||
Some(s) => s.username,
|
||||
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
};
|
||||
|
||||
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
|
||||
.bind(&username)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match user_result {
|
||||
Ok(row) => {
|
||||
let hash = row.get::<String, _>("password_hash");
|
||||
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
|
||||
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
|
||||
};
|
||||
|
||||
let update_result = sqlx::query(
|
||||
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?",
|
||||
)
|
||||
.bind(new_hash)
|
||||
.bind(&username)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match update_result {
|
||||
Ok(_) => (response_headers, Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Password updated successfully" }),
|
||||
))),
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
|
||||
}
|
||||
} else {
|
||||
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
|
||||
}
|
||||
}
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_logout(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token {
|
||||
state.session_manager.revoke_session(token).await;
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" })))
|
||||
}
|
||||
|
||||
/// Helper: Extract and validate a session from the Authorization header.
|
||||
/// Returns the Session and optional new token if refreshed, or an error response.
|
||||
pub(super) async fn extract_session(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
match token {
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => Ok((session, new_token)),
|
||||
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
|
||||
},
|
||||
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: Extract session and require admin role.
|
||||
/// Returns session and optional new token if refreshed.
|
||||
pub(super) async fn require_admin(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let (session, new_token) = extract_session(state, headers).await?;
|
||||
if session.role != "admin" {
|
||||
return Err(Json(ApiResponse::error("Admin access required".to_string())));
|
||||
}
|
||||
Ok((session, new_token))
|
||||
}
|
||||
@@ -1,538 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use chrono;
|
||||
use rand::Rng;
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
use uuid;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Generate a random API token: sk-{48 hex chars}
|
||||
fn generate_token() -> String {
|
||||
let mut rng = rand::rng();
|
||||
let bytes: Vec<u8> = (0..24).map(|_| rng.random::<u8>()).collect();
|
||||
format!("sk-{}", hex::encode(bytes))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct CreateClientRequest {
|
||||
pub(super) name: String,
|
||||
pub(super) client_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateClientPayload {
|
||||
pub(super) name: Option<String>,
|
||||
pub(super) description: Option<String>,
|
||||
pub(super) is_active: Option<bool>,
|
||||
pub(super) rate_limit_per_minute: Option<i64>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_clients(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
client_id as id,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
total_requests,
|
||||
total_tokens,
|
||||
total_cost,
|
||||
is_active,
|
||||
rate_limit_per_minute
|
||||
FROM clients
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let clients: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"id": row.get::<String, _>("id"),
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||
"description": row.get::<Option<String>, _>("description"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"requests_count": row.get::<i64, _>("total_requests"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(clients)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch clients: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch clients".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_create_client(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateClientRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let client_id = payload
|
||||
.client_id
|
||||
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO clients (client_id, name, is_active)
|
||||
VALUES (?, ?, TRUE)
|
||||
RETURNING *
|
||||
"#,
|
||||
)
|
||||
.bind(&client_id)
|
||||
.bind(&payload.name)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(row) => {
|
||||
// Auto-generate a token for the new client
|
||||
let token = generate_token();
|
||||
let token_result = sqlx::query(
|
||||
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
|
||||
)
|
||||
.bind(&client_id)
|
||||
.bind(&token)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
if let Err(e) = token_result {
|
||||
warn!("Client created but failed to generate token: {}", e);
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<String, _>("client_id"),
|
||||
"name": row.get::<Option<String>, _>("name"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"status": "active",
|
||||
"token": token,
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create client: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_client(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
c.client_id as id,
|
||||
c.name,
|
||||
c.description,
|
||||
c.is_active,
|
||||
c.rate_limit_per_minute,
|
||||
c.created_at,
|
||||
COALESCE(c.total_tokens, 0) as total_tokens,
|
||||
COALESCE(c.total_cost, 0.0) as total_cost,
|
||||
COUNT(r.id) as total_requests,
|
||||
MAX(r.timestamp) as last_request
|
||||
FROM clients c
|
||||
LEFT JOIN llm_requests r ON c.client_id = r.client_id
|
||||
WHERE c.client_id = ?
|
||||
GROUP BY c.client_id
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<String, _>("id"),
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||
"description": row.get::<Option<String>, _>("description"),
|
||||
"is_active": row.get::<bool, _>("is_active"),
|
||||
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"total_requests": row.get::<i64, _>("total_requests"),
|
||||
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
|
||||
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||
}))),
|
||||
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_client(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateClientPayload>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Build dynamic UPDATE query from provided fields
|
||||
let mut sets = Vec::new();
|
||||
let mut binds: Vec<String> = Vec::new();
|
||||
|
||||
if let Some(ref name) = payload.name {
|
||||
sets.push("name = ?");
|
||||
binds.push(name.clone());
|
||||
}
|
||||
if let Some(ref desc) = payload.description {
|
||||
sets.push("description = ?");
|
||||
binds.push(desc.clone());
|
||||
}
|
||||
if payload.is_active.is_some() {
|
||||
sets.push("is_active = ?");
|
||||
}
|
||||
if payload.rate_limit_per_minute.is_some() {
|
||||
sets.push("rate_limit_per_minute = ?");
|
||||
}
|
||||
|
||||
if sets.is_empty() {
|
||||
return Json(ApiResponse::error("No fields to update".to_string()));
|
||||
}
|
||||
|
||||
// Always update updated_at
|
||||
sets.push("updated_at = CURRENT_TIMESTAMP");
|
||||
|
||||
let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
|
||||
let mut query = sqlx::query(&sql);
|
||||
|
||||
// Bind in the same order as sets
|
||||
for b in &binds {
|
||||
query = query.bind(b);
|
||||
}
|
||||
if let Some(active) = payload.is_active {
|
||||
query = query.bind(active);
|
||||
}
|
||||
if let Some(rate) = payload.rate_limit_per_minute {
|
||||
query = query.bind(rate);
|
||||
}
|
||||
query = query.bind(&id);
|
||||
|
||||
match query.execute(pool).await {
|
||||
Ok(result) => {
|
||||
if result.rows_affected() == 0 {
|
||||
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
|
||||
}
|
||||
// Return the updated client
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
|
||||
created_at, total_requests, total_tokens, total_cost
|
||||
FROM clients WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<String, _>("id"),
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||
"description": row.get::<Option<String>, _>("description"),
|
||||
"is_active": row.get::<bool, _>("is_active"),
|
||||
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"total_requests": row.get::<i64, _>("total_requests"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||
}))),
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch updated client: {}", e);
|
||||
// Update succeeded but fetch failed — still report success
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update client: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update client: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_delete_client(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Don't allow deleting the default client
|
||||
if id == "default" {
|
||||
return Json(ApiResponse::error("Cannot delete default client".to_string()));
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_client_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Get per-model breakdown for this client
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
model,
|
||||
provider,
|
||||
COUNT(*) as request_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(total_tokens) as total_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(duration_ms) as avg_duration_ms
|
||||
FROM llm_requests
|
||||
WHERE client_id = ?
|
||||
GROUP BY model, provider
|
||||
ORDER BY total_cost DESC
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let breakdown: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"model": row.get::<String, _>("model"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"request_count": row.get::<i64, _>("request_count"),
|
||||
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
|
||||
"completion_tokens": row.get::<i64, _>("completion_tokens"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"client_id": id,
|
||||
"breakdown": breakdown,
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client usage: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Token management endpoints ──────────────────────────────────────
|
||||
|
||||
pub(super) async fn handle_get_client_tokens(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT id, token, name, is_active, created_at, last_used_at
|
||||
FROM client_tokens
|
||||
WHERE client_id = ?
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let tokens: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let token: String = row.get("token");
|
||||
// Mask all but last 8 chars: sk-••••abcd1234
|
||||
let masked = if token.len() > 8 {
|
||||
format!("{}••••{}", &token[..3], &token[token.len() - 8..])
|
||||
} else {
|
||||
"••••".to_string()
|
||||
};
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"token_masked": masked,
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
|
||||
"is_active": row.get::<bool, _>("is_active"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"last_used_at": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_used_at"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(tokens)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client tokens: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct CreateTokenRequest {
|
||||
pub(super) name: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_create_client_token(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<CreateTokenRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Verify client exists
|
||||
let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
if exists.is_none() {
|
||||
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
|
||||
}
|
||||
|
||||
let token = generate_token();
|
||||
let token_name = payload.name.unwrap_or_else(|| "default".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at",
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(&token)
|
||||
.bind(&token_name)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(row) => Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"token": token,
|
||||
"name": token_name,
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
}))),
|
||||
Err(e) => {
|
||||
warn!("Failed to create client token: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to create token: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_delete_client_token(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path((client_id, token_id)): Path<(String, i64)>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
|
||||
.bind(token_id)
|
||||
.bind(&client_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() == 0 {
|
||||
Json(ApiResponse::error("Token not found".to_string()))
|
||||
} else {
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to delete client token: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
// Dashboard module for LLM Proxy Gateway
|
||||
|
||||
mod auth;
|
||||
mod clients;
|
||||
mod models;
|
||||
mod providers;
|
||||
pub mod sessions;
|
||||
mod system;
|
||||
mod usage;
|
||||
mod users;
|
||||
mod websocket;
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
Router,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use axum::http::{header, HeaderValue};
|
||||
use serde::Serialize;
|
||||
use tower_http::{
|
||||
limit::RequestBodyLimitLayer,
|
||||
set_header::SetResponseHeaderLayer,
|
||||
};
|
||||
|
||||
use crate::state::AppState;
|
||||
use sessions::SessionManager;
|
||||
|
||||
// Dashboard state
|
||||
#[derive(Clone)]
|
||||
struct DashboardState {
|
||||
app_state: AppState,
|
||||
session_manager: SessionManager,
|
||||
}
|
||||
|
||||
// API Response types
|
||||
#[derive(Serialize)]
|
||||
struct ApiResponse<T> {
|
||||
success: bool,
|
||||
data: Option<T>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
impl<T> ApiResponse<T> {
|
||||
fn success(data: T) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
data: Some(data),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn error(error: String) -> Self {
|
||||
Self {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiting middleware for dashboard routes
|
||||
async fn dashboard_rate_limit_middleware(
|
||||
State(_dashboard_state): State<DashboardState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, crate::errors::AppError> {
|
||||
// Bypass rate limiting for dashboard routes to prevent "Failed to load statistics"
|
||||
// when the UI makes many concurrent requests on load.
|
||||
// Dashboard endpoints are already secured via auth::require_admin.
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
// Dashboard routes
|
||||
pub fn router(state: AppState) -> Router {
|
||||
let session_manager = SessionManager::new(24); // 24-hour session TTL
|
||||
let dashboard_state = DashboardState {
|
||||
app_state: state,
|
||||
session_manager,
|
||||
};
|
||||
|
||||
// Security headers
|
||||
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
"default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://fonts.googleapis.com; font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' ws:;"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_FRAME_OPTIONS,
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::STRICT_TRANSPORT_SECURITY,
|
||||
"max-age=31536000; includeSubDomains".parse().unwrap(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
// Static file serving
|
||||
.fallback_service(tower_http::services::ServeDir::new("static"))
|
||||
// WebSocket endpoint
|
||||
.route("/ws", get(websocket::handle_websocket))
|
||||
// API endpoints
|
||||
.route("/api/auth/login", post(auth::handle_login))
|
||||
.route("/api/auth/status", get(auth::handle_auth_status))
|
||||
.route("/api/auth/logout", post(auth::handle_logout))
|
||||
.route("/api/auth/change-password", post(auth::handle_change_password))
|
||||
.route(
|
||||
"/api/users",
|
||||
get(users::handle_get_users).post(users::handle_create_user),
|
||||
)
|
||||
.route(
|
||||
"/api/users/{id}",
|
||||
put(users::handle_update_user).delete(users::handle_delete_user),
|
||||
)
|
||||
.route("/api/usage/summary", get(usage::handle_usage_summary))
|
||||
.route("/api/usage/time-series", get(usage::handle_time_series))
|
||||
.route("/api/usage/clients", get(usage::handle_clients_usage))
|
||||
.route("/api/usage/providers", get(usage::handle_providers_usage))
|
||||
.route("/api/usage/detailed", get(usage::handle_detailed_usage))
|
||||
.route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
|
||||
.route("/api/models", get(models::handle_get_models))
|
||||
.route("/api/models/{id}", put(models::handle_update_model))
|
||||
.route(
|
||||
"/api/clients",
|
||||
get(clients::handle_get_clients).post(clients::handle_create_client),
|
||||
)
|
||||
.route(
|
||||
"/api/clients/{id}",
|
||||
get(clients::handle_get_client)
|
||||
.put(clients::handle_update_client)
|
||||
.delete(clients::handle_delete_client),
|
||||
)
|
||||
.route("/api/clients/{id}/usage", get(clients::handle_client_usage))
|
||||
.route(
|
||||
"/api/clients/{id}/tokens",
|
||||
get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
|
||||
)
|
||||
.route(
|
||||
"/api/clients/{id}/tokens/{token_id}",
|
||||
delete(clients::handle_delete_client_token),
|
||||
)
|
||||
.route("/api/providers", get(providers::handle_get_providers))
|
||||
.route(
|
||||
"/api/providers/{name}",
|
||||
get(providers::handle_get_provider).put(providers::handle_update_provider),
|
||||
)
|
||||
.route("/api/providers/{name}/test", post(providers::handle_test_provider))
|
||||
.route("/api/system/health", get(system::handle_system_health))
|
||||
.route("/api/system/metrics", get(system::handle_system_metrics))
|
||||
.route("/api/system/logs", get(system::handle_system_logs))
|
||||
.route("/api/system/backup", post(system::handle_system_backup))
|
||||
.route(
|
||||
"/api/system/settings",
|
||||
get(system::handle_get_settings).post(system::handle_update_settings),
|
||||
)
|
||||
// Security layers
|
||||
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
|
||||
.layer(csp_header)
|
||||
.layer(x_frame_options)
|
||||
.layer(x_content_type_options)
|
||||
.layer(strict_transport_security)
|
||||
// Rate limiting middleware
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
dashboard_state.clone(),
|
||||
dashboard_rate_limit_middleware,
|
||||
))
|
||||
.with_state(dashboard_state)
|
||||
}
|
||||
@@ -1,254 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateModelRequest {
|
||||
pub(super) enabled: bool,
|
||||
pub(super) prompt_cost: Option<f64>,
|
||||
pub(super) completion_cost: Option<f64>,
|
||||
pub(super) mapping: Option<String>,
|
||||
}
|
||||
|
||||
/// Query parameters for `GET /api/models`.
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
pub(super) struct ModelListParams {
|
||||
/// Filter by provider ID.
|
||||
pub provider: Option<String>,
|
||||
/// Text search on model ID or name.
|
||||
pub search: Option<String>,
|
||||
/// Filter by input modality (e.g. "image").
|
||||
pub modality: Option<String>,
|
||||
/// Only models that support tool calling.
|
||||
pub tool_call: Option<bool>,
|
||||
/// Only models that support reasoning.
|
||||
pub reasoning: Option<bool>,
|
||||
/// Only models that have pricing data.
|
||||
pub has_cost: Option<bool>,
|
||||
/// Only models that have been used in requests.
|
||||
pub used_only: Option<bool>,
|
||||
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
|
||||
pub sort_by: Option<ModelSortBy>,
|
||||
/// Sort direction (asc, desc).
|
||||
pub sort_order: Option<SortOrder>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_models(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<ModelListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load overrides from database
|
||||
let db_models_result =
|
||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_models = HashMap::new();
|
||||
if let Ok(rows) = db_models_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
db_models.insert(id, row);
|
||||
}
|
||||
}
|
||||
|
||||
let mut models_json = Vec::new();
|
||||
|
||||
if params.used_only.unwrap_or(false) {
|
||||
// EXACT USED MODELS LOGIC
|
||||
let used_pairs_result = sqlx::query(
|
||||
"SELECT DISTINCT provider, model FROM llm_requests",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
if let Ok(rows) = used_pairs_result {
|
||||
for row in rows {
|
||||
let provider: String = row.get("provider");
|
||||
let m_key: String = row.get("model");
|
||||
|
||||
let provider_name = match provider.as_str() {
|
||||
"openai" => "OpenAI",
|
||||
"gemini" => "Google Gemini",
|
||||
"deepseek" => "DeepSeek",
|
||||
"grok" => "xAI Grok",
|
||||
"ollama" => "Ollama",
|
||||
_ => provider.as_str(),
|
||||
}.to_string();
|
||||
|
||||
let m_meta = registry.find_model(&m_key);
|
||||
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
|
||||
let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
|
||||
let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(db_row) = db_models.get(&m_key) {
|
||||
enabled = db_row.get("enabled");
|
||||
if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = db_row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_key,
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_write_cost": cache_write_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
|
||||
"output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
|
||||
"modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
|
||||
"input": mo.input,
|
||||
"output": mo.output,
|
||||
}))),
|
||||
"tool_call": m_meta.and_then(|m| m.tool_call),
|
||||
"reasoning": m_meta.and_then(|m| m.reasoning),
|
||||
}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// REGISTRY LISTING LOGIC
|
||||
// Build filter from query params
|
||||
let filter = ModelFilter {
|
||||
provider: params.provider,
|
||||
search: params.search,
|
||||
modality: params.modality,
|
||||
tool_call: params.tool_call,
|
||||
reasoning: params.reasoning,
|
||||
has_cost: params.has_cost,
|
||||
};
|
||||
let sort_by = params.sort_by.unwrap_or_default();
|
||||
let sort_order = params.sort_order.unwrap_or_default();
|
||||
|
||||
// Get filtered and sorted model entries
|
||||
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
||||
|
||||
for entry in &entries {
|
||||
let m_key = entry.model_key;
|
||||
let m_meta = entry.metadata;
|
||||
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
|
||||
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(row) = db_models.get(m_key) {
|
||||
enabled = row.get("enabled");
|
||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_key,
|
||||
"provider": entry.provider_id,
|
||||
"provider_name": entry.provider_name,
|
||||
"name": m_meta.name,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_write_cost": cache_write_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
|
||||
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
|
||||
"input": m.input,
|
||||
"output": m.output,
|
||||
})),
|
||||
"tool_call": m_meta.tool_call,
|
||||
"reasoning": m_meta.reasoning,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_model(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateModelRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Find provider_id for this model in registry
|
||||
let provider_id = state
|
||||
.app_state
|
||||
.model_registry
|
||||
.providers
|
||||
.iter()
|
||||
.find(|(_, p)| p.models.contains_key(&id))
|
||||
.map(|(id, _)| id.clone())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(provider_id)
|
||||
.bind(payload.enabled)
|
||||
.bind(payload.prompt_cost)
|
||||
.bind(payload.completion_cost)
|
||||
.bind(payload.mapping)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Invalidate the in-memory cache so the proxy picks up the change immediately
|
||||
state.app_state.model_config_cache.invalidate().await;
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||
}
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::utils::crypto;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateProviderRequest {
|
||||
pub(super) enabled: bool,
|
||||
pub(super) base_url: Option<String>,
|
||||
pub(super) api_key: Option<String>,
|
||||
pub(super) credit_balance: Option<f64>,
|
||||
pub(super) low_credit_threshold: Option<f64>,
|
||||
pub(super) billing_mode: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_providers(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load all overrides from database (including billing_mode)
|
||||
let db_configs_result = sqlx::query(
|
||||
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_configs = HashMap::new();
|
||||
if let Ok(rows) = db_configs_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
let enabled: bool = row.get("enabled");
|
||||
let base_url: Option<String> = row.get("base_url");
|
||||
let balance: f64 = row.get("credit_balance");
|
||||
let threshold: f64 = row.get("low_credit_threshold");
|
||||
let billing_mode: Option<String> = row.get("billing_mode");
|
||||
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
|
||||
}
|
||||
}
|
||||
|
||||
let mut providers_json = Vec::new();
|
||||
|
||||
// Define the list of providers we support
|
||||
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||
|
||||
for id in provider_ids {
|
||||
// Get base config
|
||||
let (mut enabled, mut base_url, display_name) = match id {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => (false, "".to_string(), "Unknown"),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
let mut billing_mode: Option<String> = None;
|
||||
|
||||
// Apply database overrides
|
||||
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
|
||||
enabled = *db_enabled;
|
||||
if let Some(url) = db_url {
|
||||
base_url = url.clone();
|
||||
}
|
||||
balance = *db_balance;
|
||||
threshold = *db_threshold;
|
||||
billing_mode = db_billing.clone();
|
||||
}
|
||||
|
||||
// Find models for this provider in registry
|
||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
||||
let registry_key = match id {
|
||||
"gemini" => "google",
|
||||
"grok" => "xai",
|
||||
_ => id,
|
||||
};
|
||||
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(registry_key) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if id == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else {
|
||||
// Check if it's actually initialized in the provider manager
|
||||
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
||||
// Check circuit breaker
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error" // Enabled but failed to initialize (e.g. missing API key)
|
||||
}
|
||||
};
|
||||
|
||||
providers_json.push(serde_json::json!({
|
||||
"id": id,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"billing_mode": billing_mode,
|
||||
"last_used": None::<String>,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_provider(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate provider name
|
||||
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
let mut billing_mode: Option<String> = None;
|
||||
|
||||
// Apply database overrides
|
||||
let db_config = sqlx::query(
|
||||
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
|
||||
)
|
||||
.bind(&name)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
if let Ok(Some(row)) = db_config {
|
||||
enabled = row.get::<bool, _>("enabled");
|
||||
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
||||
base_url = url;
|
||||
}
|
||||
balance = row.get::<f64, _>("credit_balance");
|
||||
threshold = row.get::<f64, _>("low_credit_threshold");
|
||||
billing_mode = row.get::<Option<String>, _>("billing_mode");
|
||||
}
|
||||
|
||||
// Find models for this provider
|
||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
||||
let registry_key = match name.as_str() {
|
||||
"gemini" => "google",
|
||||
"grok" => "xai",
|
||||
_ => name.as_str(),
|
||||
};
|
||||
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(registry_key) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if name == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&name)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error"
|
||||
};
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": name,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"billing_mode": billing_mode,
|
||||
"last_used": None::<String>,
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_provider(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(name): Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Prepare API key encryption if provided
|
||||
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
|
||||
Some(key) if !key.is_empty() => {
|
||||
match crypto::encrypt(key) {
|
||||
Ok(encrypted) => (Some(encrypted), Some(true)),
|
||||
Err(e) => {
|
||||
warn!("Failed to encrypt API key for provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
// Empty string means clear the key
|
||||
(None, Some(false))
|
||||
}
|
||||
None => {
|
||||
// Keep existing key, we'll rely on COALESCE in SQL
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Update or insert into database (include billing_mode and api_key_encrypted)
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&api_key_to_store)
|
||||
.bind(api_key_encrypted_flag)
|
||||
.bind(payload.credit_balance)
|
||||
.bind(payload.low_credit_threshold)
|
||||
.bind(payload.billing_mode)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Re-initialize provider in manager
|
||||
if let Err(e) = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to re-initialize provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider settings saved but initialization failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update provider config: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider '{}' not found or not enabled",
|
||||
name
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Pick a real model for this provider from the registry
|
||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
||||
let registry_key = match name.as_str() {
|
||||
"gemini" => "google",
|
||||
"grok" => "xai",
|
||||
_ => name.as_str(),
|
||||
};
|
||||
|
||||
let test_model = state
|
||||
.app_state
|
||||
.model_registry
|
||||
.providers
|
||||
.get(registry_key)
|
||||
.and_then(|p| p.models.keys().next().cloned())
|
||||
.unwrap_or_else(|| name.clone());
|
||||
|
||||
let test_request = crate::models::UnifiedRequest {
|
||||
client_id: "system-test".to_string(),
|
||||
model: test_model,
|
||||
messages: vec![crate::models::UnifiedMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
||||
reasoning_content: None,
|
||||
tool_calls: None,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
n: None,
|
||||
stop: None,
|
||||
max_tokens: Some(5),
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
stream: false,
|
||||
has_images: false,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
match provider.chat_completion(test_request).await {
|
||||
Ok(_) => {
|
||||
let latency = start.elapsed().as_millis();
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"success": true,
|
||||
"latency": latency,
|
||||
"message": "Connection test successful"
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
||||
}
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, digest::generic_array::GenericArray};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||
|
||||
/// Version tag written into the `version` field of every signed session payload.
const TOKEN_VERSION: &str = "v2";

/// A token that expires within this many minutes is refreshed on validation.
const REFRESH_WINDOW_MINUTES: i64 = 15;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Session {
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub session_id: String, // unique identifier for the session (UUID)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
|
||||
ttl_hours: i64,
|
||||
secret: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionPayload {
|
||||
session_id: String,
|
||||
username: String,
|
||||
role: String,
|
||||
iat: i64, // issued at (Unix timestamp)
|
||||
exp: i64, // expiry (Unix timestamp)
|
||||
version: String,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
let secret = load_session_secret();
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
secret,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session and return a signed session token.
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::hours(self.ttl_hours);
|
||||
let session = Session {
|
||||
username: username.clone(),
|
||||
role: role.clone(),
|
||||
created_at: now,
|
||||
expires_at,
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
// Store session by session_id
|
||||
self.sessions.write().await.insert(session_id.clone(), session);
|
||||
// Create signed token
|
||||
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
|
||||
}
|
||||
|
||||
/// Validate a session token and return the session if valid and not expired.
|
||||
/// If the token is within the refresh window, returns a new token as well.
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
|
||||
}
|
||||
|
||||
/// Validate a session token and return (session, optional new token if refreshed).
|
||||
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
|
||||
// Legacy token format (UUID)
|
||||
if token.starts_with("session-") {
|
||||
let sessions = self.sessions.read().await;
|
||||
return sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some((s.clone(), None))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Signed token format
|
||||
let payload = match verify_signed_token(token, &self.secret) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// Check expiry
|
||||
let now = Utc::now().timestamp();
|
||||
if payload.exp <= now {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Look up session by session_id
|
||||
let sessions = self.sessions.read().await;
|
||||
let session = match sessions.get(&payload.session_id) {
|
||||
Some(s) => s.clone(),
|
||||
None => return None, // session revoked or not found
|
||||
};
|
||||
|
||||
// Ensure session username/role matches (should always match)
|
||||
if session.username != payload.username || session.role != payload.role {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
|
||||
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
|
||||
let new_token = if now >= refresh_threshold {
|
||||
// Generate a new token with same session data but updated iat/exp?
|
||||
// According to activity-based refresh, we should extend the session expiry.
|
||||
// We'll extend from now by ttl_hours (or keep original expiry?).
|
||||
// Let's extend from now by ttl_hours (sliding window).
|
||||
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
|
||||
// Update session expiry in store
|
||||
drop(sessions); // release read lock before acquiring write lock
|
||||
self.update_session_expiry(&payload.session_id, new_exp).await;
|
||||
// Create new token with updated iat/exp
|
||||
let new_token = self.create_signed_token(
|
||||
&payload.session_id,
|
||||
&payload.username,
|
||||
&payload.role,
|
||||
now,
|
||||
new_exp.timestamp(),
|
||||
);
|
||||
Some(new_token)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some((session, new_token))
|
||||
}
|
||||
|
||||
/// Revoke (delete) a session by token.
|
||||
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
if token.starts_with("session-") {
|
||||
self.sessions.write().await.remove(token);
|
||||
return;
|
||||
}
|
||||
// For signed token, try to extract session_id
|
||||
if let Ok(payload) = verify_signed_token(token, &self.secret) {
|
||||
self.sessions.write().await.remove(&payload.session_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all expired sessions from the store.
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
|
||||
let payload = SessionPayload {
|
||||
session_id: session_id.to_string(),
|
||||
username: username.to_string(),
|
||||
role: role.to_string(),
|
||||
iat,
|
||||
exp,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
sign_token(&payload, &self.secret)
|
||||
}
|
||||
|
||||
/// Overwrite the stored `expires_at` of the session keyed by `session_id`.
///
/// Does nothing when the store holds no session under that key.
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
    let mut store = self.sessions.write().await;
    if let Some(entry) = store.get_mut(session_id) {
        entry.expires_at = new_expires_at;
    }
}
|
||||
}
|
||||
|
||||
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
|
||||
/// If not set, generates a random 32-byte secret and logs a warning.
|
||||
fn load_session_secret() -> Vec<u8> {
|
||||
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
|
||||
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Generate a random secret (32 bytes) and encode as hex
|
||||
use rand::RngCore;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill_bytes(&mut bytes);
|
||||
let hex_secret = hex::encode(bytes);
|
||||
tracing::warn!(
|
||||
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
|
||||
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
|
||||
);
|
||||
hex_secret
|
||||
})
|
||||
});
|
||||
|
||||
// Decode hex or base64
|
||||
hex::decode(&secret_str)
|
||||
.or_else(|_| URL_SAFE.decode(&secret_str))
|
||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
|
||||
})
|
||||
}
|
||||
|
||||
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
|
||||
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
|
||||
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
|
||||
let payload_b64 = URL_SAFE.encode(&json);
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
let signature = mac.finalize().into_bytes();
|
||||
let signature_b64 = URL_SAFE.encode(signature);
|
||||
format!("{}.{}", payload_b64, signature_b64)
|
||||
}
|
||||
|
||||
/// Verify a signed token and return the decoded payload if valid.
|
||||
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(TokenError::InvalidFormat);
|
||||
}
|
||||
let payload_b64 = parts[0];
|
||||
let signature_b64 = parts[1];
|
||||
|
||||
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
|
||||
// Verify HMAC
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
// Convert signature slice to GenericArray
|
||||
let tag = GenericArray::from_slice(&signature);
|
||||
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
|
||||
|
||||
// Deserialize payload
|
||||
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
/// Reasons a signed session token can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TokenError {
    /// The token is not two base64url parts separated by a single `.`.
    InvalidFormat,
    /// The signature does not authenticate the payload bytes.
    InvalidSignature,
    /// The authenticated bytes could not be deserialized into a session payload.
    InvalidPayload,
}

impl std::fmt::Display for TokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            TokenError::InvalidFormat => "invalid token format",
            TokenError::InvalidSignature => "invalid token signature",
            TokenError::InvalidPayload => "invalid token payload",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TokenError {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Key shared by the signing tests.
    const TEST_SECRET: &[u8] = b"test-secret-must-be-32-bytes-long!";

    /// Fixed payload used as the signing fixture.
    fn sample_payload() -> SessionPayload {
        SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        }
    }

    /// A freshly signed token verifies and round-trips every payload field.
    #[test]
    fn test_sign_and_verify_token() {
        let expected = sample_payload();
        let token = sign_token(&expected, TEST_SECRET);

        let actual = verify_signed_token(&token, TEST_SECRET).unwrap();

        assert_eq!(actual.session_id, expected.session_id);
        assert_eq!(actual.username, expected.username);
        assert_eq!(actual.role, expected.role);
        assert_eq!(actual.iat, expected.iat);
        assert_eq!(actual.exp, expected.exp);
        assert_eq!(actual.version, expected.version);
    }

    /// Changing the payload bytes while keeping the old signature must be rejected.
    #[test]
    fn test_tampered_token() {
        let token = sign_token(&sample_payload(), TEST_SECRET);
        let (payload_b64, signature_b64) = token.split_once('.').unwrap();

        // Tamper with payload part
        let mut payload_bytes = URL_SAFE.decode(payload_b64).unwrap();
        payload_bytes[0] ^= 0xFF; // flip some bits
        let forged = format!("{}.{}", URL_SAFE.encode(payload_bytes), signature_b64);

        assert!(verify_signed_token(&forged, TEST_SECRET).is_err());
    }

    /// A hex-encoded `SESSION_SECRET` is decoded to its raw bytes.
    #[test]
    fn test_load_session_secret_from_env() {
        // SAFETY: mutating the process environment is only sound while no other thread
        // reads or writes it; this test assumes no concurrent environment access.
        unsafe {
            env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
        }

        let secret = load_session_secret();
        assert_eq!(secret, vec![0xAA; 32]);

        // SAFETY: same assumption as above; restores the environment for other tests.
        unsafe {
            env::remove_var("SESSION_SECRET");
        }
    }
}
|
||||
@@ -1,405 +0,0 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use chrono;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Return the whole contents of `path` as a `String`, or `None` when the file cannot be
/// read for any reason (missing, unreadable, not valid UTF-8).
fn read_proc_file(path: &str) -> Option<String> {
    match std::fs::read_to_string(path) {
        Ok(contents) => Some(contents),
        Err(_) => None,
    }
}
|
||||
|
||||
pub(super) async fn handle_system_health(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let mut components = HashMap::new();
|
||||
components.insert("api_server".to_string(), "online".to_string());
|
||||
components.insert("database".to_string(), "online".to_string());
|
||||
|
||||
// Check provider health via circuit breakers
|
||||
let provider_ids: Vec<String> = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.get_all_providers()
|
||||
.await
|
||||
.iter()
|
||||
.map(|p| p.name().to_string())
|
||||
.collect();
|
||||
|
||||
for p_id in provider_ids {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&p_id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
components.insert(p_id, "online".to_string());
|
||||
} else {
|
||||
components.insert(p_id, "degraded".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Read real memory usage from /proc/self/status
|
||||
let memory_mb = read_proc_file("/proc/self/status")
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.map(|kb| kb / 1024.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Get real database pool stats
|
||||
let db_pool_size = state.app_state.db_pool.size();
|
||||
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
"components": components,
|
||||
"metrics": {
|
||||
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
|
||||
"db_connections_active": db_pool_size - db_pool_idle as u32,
|
||||
"db_connections_idle": db_pool_idle,
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
/// Real system metrics from /proc (Linux only).
|
||||
pub(super) async fn handle_system_metrics(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
// --- CPU usage (aggregate across all cores) ---
|
||||
// /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
|
||||
let cpu_percent = read_proc_file("/proc/stat")
|
||||
.and_then(|s| {
|
||||
let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string();
|
||||
let fields: Vec<u64> = line
|
||||
.split_whitespace()
|
||||
.skip(1)
|
||||
.filter_map(|v| v.parse().ok())
|
||||
.collect();
|
||||
if fields.len() >= 4 {
|
||||
let idle = fields[3];
|
||||
let total: u64 = fields.iter().sum();
|
||||
if total > 0 {
|
||||
Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// --- Memory (system-wide from /proc/meminfo) ---
|
||||
let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default();
|
||||
let parse_meminfo = |key: &str| -> u64 {
|
||||
meminfo
|
||||
.lines()
|
||||
.find(|l| l.starts_with(key))
|
||||
.and_then(|l| l.split_whitespace().nth(1))
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.unwrap_or(0)
|
||||
};
|
||||
let mem_total_kb = parse_meminfo("MemTotal:");
|
||||
let mem_available_kb = parse_meminfo("MemAvailable:");
|
||||
let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb);
|
||||
let mem_total_mb = mem_total_kb as f64 / 1024.0;
|
||||
let mem_used_mb = mem_used_kb as f64 / 1024.0;
|
||||
let mem_percent = if mem_total_kb > 0 {
|
||||
(mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// --- Process-specific memory (VmRSS) ---
|
||||
let process_rss_mb = read_proc_file("/proc/self/status")
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.map(|kb| (kb / 1024.0 * 10.0).round() / 10.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// --- Disk usage of the data directory ---
|
||||
let (disk_total_gb, disk_used_gb, disk_percent) = {
|
||||
// statvfs via libc would be ideal; use df as a simple fallback
|
||||
std::process::Command::new("df")
|
||||
.args(["-BM", "--output=size,used,pcent", "."])
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|o| {
|
||||
let out = String::from_utf8_lossy(&o.stdout);
|
||||
let line = out.lines().nth(1)?.to_string();
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() >= 3 {
|
||||
let total = parts[0].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
|
||||
let used = parts[1].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
|
||||
let pct = parts[2].trim_end_matches('%').parse::<f64>().unwrap_or(0.0);
|
||||
Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or((0.0, 0.0, 0.0))
|
||||
};
|
||||
|
||||
// --- Uptime ---
|
||||
let uptime_seconds = read_proc_file("/proc/uptime")
|
||||
.and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::<f64>().ok()))
|
||||
.unwrap_or(0.0) as u64;
|
||||
|
||||
// --- Load average ---
|
||||
let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg")
|
||||
.and_then(|s| {
|
||||
let parts: Vec<&str> = s.split_whitespace().collect();
|
||||
if parts.len() >= 3 {
|
||||
Some((
|
||||
parts[0].parse::<f64>().unwrap_or(0.0),
|
||||
parts[1].parse::<f64>().unwrap_or(0.0),
|
||||
parts[2].parse::<f64>().unwrap_or(0.0),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or((0.0, 0.0, 0.0));
|
||||
|
||||
// --- Network (from /proc/net/dev, aggregate non-lo interfaces) ---
|
||||
let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev")
|
||||
.map(|s| {
|
||||
s.lines()
|
||||
.skip(2) // skip header lines
|
||||
.filter(|l| !l.trim().starts_with("lo:"))
|
||||
.fold((0u64, 0u64), |(rx, tx), line| {
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() >= 10 {
|
||||
let r = parts[1].parse::<u64>().unwrap_or(0);
|
||||
let t = parts[9].parse::<u64>().unwrap_or(0);
|
||||
(rx + r, tx + t)
|
||||
} else {
|
||||
(rx, tx)
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or((0, 0));
|
||||
|
||||
// --- Database pool ---
|
||||
let db_pool_size = state.app_state.db_pool.size();
|
||||
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||
|
||||
// --- Active WebSocket listeners ---
|
||||
let ws_listeners = state.app_state.dashboard_tx.receiver_count();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"cpu": {
|
||||
"usage_percent": cpu_percent,
|
||||
"load_average": [load_1, load_5, load_15],
|
||||
},
|
||||
"memory": {
|
||||
"total_mb": (mem_total_mb * 10.0).round() / 10.0,
|
||||
"used_mb": (mem_used_mb * 10.0).round() / 10.0,
|
||||
"usage_percent": mem_percent,
|
||||
"process_rss_mb": process_rss_mb,
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": disk_total_gb,
|
||||
"used_gb": disk_used_gb,
|
||||
"usage_percent": disk_percent,
|
||||
},
|
||||
"network": {
|
||||
"rx_bytes": net_rx_bytes,
|
||||
"tx_bytes": net_tx_bytes,
|
||||
},
|
||||
"uptime_seconds": uptime_seconds,
|
||||
"connections": {
|
||||
"db_active": db_pool_size - db_pool_idle as u32,
|
||||
"db_idle": db_pool_idle,
|
||||
"websocket_listeners": ws_listeners,
|
||||
},
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_system_logs(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
timestamp,
|
||||
client_id,
|
||||
provider,
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
cost,
|
||||
status,
|
||||
error_message,
|
||||
duration_ms
|
||||
FROM llm_requests
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let logs: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
|
||||
"client_id": row.get::<String, _>("client_id"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"model": row.get::<String, _>("model"),
|
||||
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
|
||||
"completion_tokens": row.get::<i64, _>("completion_tokens"),
|
||||
"reasoning_tokens": row.get::<i64, _>("reasoning_tokens"),
|
||||
"cache_read_tokens": row.get::<i64, _>("cache_read_tokens"),
|
||||
"cache_write_tokens": row.get::<i64, _>("cache_write_tokens"),
|
||||
"tokens": row.get::<i64, _>("total_tokens"),
|
||||
"cost": row.get::<f64, _>("cost"),
|
||||
"status": row.get::<String, _>("status"),
|
||||
"error": row.get::<Option<String>, _>("error_message"),
|
||||
"duration": row.get::<i64, _>("duration_ms"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(logs)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch system logs: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_system_backup(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
|
||||
let backup_path = format!("data/{}.db", backup_id);
|
||||
|
||||
// Ensure the data directory exists
|
||||
if let Err(e) = std::fs::create_dir_all("data") {
|
||||
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
|
||||
}
|
||||
|
||||
// Use SQLite VACUUM INTO for a consistent backup
|
||||
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Get backup file size
|
||||
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"success": true,
|
||||
"message": "Backup completed successfully",
|
||||
"backup_id": backup_id,
|
||||
"backup_path": backup_path,
|
||||
"size_bytes": size_bytes,
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Database backup failed: {}", e);
|
||||
Json(ApiResponse::error(format!("Backup failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let provider_count = registry.providers.len();
|
||||
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"server": {
|
||||
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
},
|
||||
"registry": {
|
||||
"provider_count": provider_count,
|
||||
"model_count": model_count,
|
||||
},
|
||||
"database": {
|
||||
"type": "SQLite",
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
Json(ApiResponse::error(
|
||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
/// Mask a secret token for display, revealing at most its last four characters.
///
/// * Tokens of 8 characters or fewer are replaced entirely by `"*****"`.
/// * Longer tokens become `N` asterisks followed by the last 4 characters, where
///   `N = min(len, 12) - 4`, so the output never reveals the real length beyond 12.
///
/// Lengths are counted in characters and the visible tail is cut on a character
/// boundary, so non-ASCII tokens cannot trigger a slicing panic. For ASCII tokens the
/// result is identical to byte-based masking.
fn mask_token(token: &str) -> String {
    const VISIBLE_CHARS: usize = 4;
    const MAX_MASKED_LEN: usize = 12;
    const FULLY_MASKED_UP_TO: usize = 8;

    let char_count = token.chars().count();
    if char_count <= FULLY_MASKED_UP_TO {
        return "*****".to_string();
    }

    let mask_len = char_count.min(MAX_MASKED_LEN) - VISIBLE_CHARS;

    // Byte offset at which the last VISIBLE_CHARS characters start.
    let tail_start = token
        .char_indices()
        .rev()
        .nth(VISIBLE_CHARS - 1)
        .map_or(0, |(offset, _)| offset);

    format!("{}{}", "*".repeat(mask_len), &token[tail_start..])
}
|
||||
@@ -1,526 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Json,
|
||||
};
|
||||
use chrono;
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Query parameters for time-based filtering on usage endpoints.
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
pub(super) struct UsagePeriodFilter {
|
||||
/// Preset period: "today", "24h", "7d", "30d", "all" (default: "all")
|
||||
pub period: Option<String>,
|
||||
/// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z")
|
||||
pub from: Option<String>,
|
||||
/// Custom range end (ISO 8601)
|
||||
pub to: Option<String>,
|
||||
}
|
||||
|
||||
impl UsagePeriodFilter {
|
||||
/// Returns `(sql_fragment, bind_values)` for a WHERE clause.
|
||||
/// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
|
||||
fn to_sql(&self) -> (String, Vec<String>) {
|
||||
let period = self.period.as_deref().unwrap_or("all");
|
||||
|
||||
if period == "custom" {
|
||||
let mut clause = String::new();
|
||||
let mut binds = Vec::new();
|
||||
if let Some(ref from) = self.from {
|
||||
clause.push_str(" AND timestamp >= ?");
|
||||
binds.push(from.clone());
|
||||
}
|
||||
if let Some(ref to) = self.to {
|
||||
clause.push_str(" AND timestamp <= ?");
|
||||
binds.push(to.clone());
|
||||
}
|
||||
return (clause, binds);
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let cutoff = match period {
|
||||
"today" => {
|
||||
// Start of today (UTC)
|
||||
let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
|
||||
Some(today)
|
||||
}
|
||||
"24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
|
||||
"7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
|
||||
"30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
|
||||
_ => None, // "all" or unrecognized
|
||||
};
|
||||
|
||||
match cutoff {
|
||||
Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
|
||||
None => (String::new(), vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the time-series granularity label for grouping.
|
||||
fn granularity(&self) -> &'static str {
|
||||
match self.period.as_deref().unwrap_or("all") {
|
||||
"today" | "24h" => "hour",
|
||||
_ => "day",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_usage_summary(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
// Total stats (filtered by period)
|
||||
let period_sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||
COUNT(DISTINCT client_id) as active_clients,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
let mut q = sqlx::query(&period_sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
let total_stats = q.fetch_one(pool);
|
||||
|
||||
// Today's stats
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as today_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as today_cost
|
||||
FROM llm_requests
|
||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||
"#,
|
||||
)
|
||||
.bind(today)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Error stats
|
||||
let error_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
|
||||
FROM llm_requests
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Average response time
|
||||
let avg_response = sqlx::query(
|
||||
r#"
|
||||
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
|
||||
FROM llm_requests
|
||||
WHERE status = 'success'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
let results = tokio::join!(total_stats, today_stats, error_stats, avg_response);
|
||||
|
||||
match results {
|
||||
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
|
||||
let total_requests: i64 = t.get("total_requests");
|
||||
let total_tokens: i64 = t.get("total_tokens");
|
||||
let total_cost: f64 = t.get("total_cost");
|
||||
let active_clients: i64 = t.get("active_clients");
|
||||
let total_cache_read: i64 = t.get("total_cache_read");
|
||||
let total_cache_write: i64 = t.get("total_cache_write");
|
||||
|
||||
let today_requests: i64 = d.get("today_requests");
|
||||
let today_cost: f64 = d.get("today_cost");
|
||||
|
||||
let total_count: i64 = e.get("total");
|
||||
let error_count: i64 = e.get("errors");
|
||||
let error_rate = if total_count > 0 {
|
||||
(error_count as f64 / total_count as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let avg_response_time: f64 = a.get("avg_duration");
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"total_requests": total_requests,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
"active_clients": active_clients,
|
||||
"today_requests": today_requests,
|
||||
"today_cost": today_cost,
|
||||
"error_rate": error_rate,
|
||||
"avg_response_time": avg_response_time,
|
||||
"total_cache_read_tokens": total_cache_read,
|
||||
"total_cache_write_tokens": total_cache_write,
|
||||
})))
|
||||
}
|
||||
(t_res, d_res, e_res, a_res) => {
|
||||
if let Err(e) = t_res { warn!("Total stats query failed: {}", e); }
|
||||
if let Err(e) = d_res { warn!("Today stats query failed: {}", e); }
|
||||
if let Err(e) = e_res { warn!("Error stats query failed: {}", e); }
|
||||
if let Err(e) = a_res { warn!("Avg response query failed: {}", e); }
|
||||
Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_time_series(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
let granularity = filter.granularity();
|
||||
|
||||
// Determine the strftime format and default lookback
|
||||
let (strftime_fmt, _label_key, default_lookback) = match granularity {
|
||||
"hour" => ("%H:00", "hour", chrono::Duration::hours(24)),
|
||||
_ => ("%Y-%m-%d", "day", chrono::Duration::days(30)),
|
||||
};
|
||||
|
||||
// If no period filter, apply a sensible default lookback
|
||||
let (clause, binds) = if period_clause.is_empty() {
|
||||
let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
|
||||
(" AND timestamp >= ?".to_string(), vec![cutoff])
|
||||
} else {
|
||||
(period_clause, period_binds)
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
strftime('{strftime_fmt}', timestamp) as bucket,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {clause}
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket
|
||||
"#,
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut series = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let bucket: String = row.get("bucket");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
|
||||
series.push(serde_json::json!({
|
||||
"time": bucket,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"series": series,
|
||||
"period": filter.period.as_deref().unwrap_or("all"),
|
||||
"granularity": granularity,
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch time series data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_clients_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
client_id,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost,
|
||||
MAX(timestamp) as last_request
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY client_id
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut client_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let client_id: String = row.get("client_id");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
|
||||
|
||||
client_usage.push(serde_json::json!({
|
||||
"client_id": client_id,
|
||||
"client_name": client_id,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
"last_request": last_request,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(client_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_providers_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
provider,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as cache_write
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY provider
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut provider_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let provider: String = row.get("provider");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
let cache_read: i64 = row.get("cache_read");
|
||||
let cache_write: i64 = row.get("cache_write");
|
||||
|
||||
provider_usage.push(serde_json::json!({
|
||||
"provider": provider,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
"cache_read_tokens": cache_read,
|
||||
"cache_write_tokens": cache_write,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(provider_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch provider usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_detailed_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
strftime('%Y-%m-%d', timestamp) as date,
|
||||
client_id,
|
||||
provider,
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as cache_write
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY date, client_id, provider, model
|
||||
ORDER BY date DESC
|
||||
LIMIT 200
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let usage: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"date": row.get::<String, _>("date"),
|
||||
"client": row.get::<String, _>("client_id"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"model": row.get::<String, _>("model"),
|
||||
"requests": row.get::<i64, _>("requests"),
|
||||
"tokens": row.get::<i64, _>("tokens"),
|
||||
"cost": row.get::<f64, _>("cost"),
|
||||
"cache_read_tokens": row.get::<i64, _>("cache_read"),
|
||||
"cache_write_tokens": row.get::<i64, _>("cache_write"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch detailed usage: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_analytics_breakdown(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
// Model breakdown
|
||||
let model_sql = format!(
|
||||
"SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
|
||||
period_clause
|
||||
);
|
||||
let mut mq = sqlx::query(&model_sql);
|
||||
for b in &period_binds {
|
||||
mq = mq.bind(b);
|
||||
}
|
||||
let models = mq.fetch_all(pool);
|
||||
|
||||
// Client breakdown
|
||||
let client_sql = format!(
|
||||
"SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
|
||||
period_clause
|
||||
);
|
||||
let mut cq = sqlx::query(&client_sql);
|
||||
for b in &period_binds {
|
||||
cq = cq.bind(b);
|
||||
}
|
||||
let clients = cq.fetch_all(pool);
|
||||
|
||||
match tokio::join!(models, clients) {
|
||||
(Ok(m_rows), Ok(c_rows)) => {
|
||||
let model_breakdown: Vec<serde_json::Value> = m_rows
|
||||
.into_iter()
|
||||
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||
.collect();
|
||||
|
||||
let client_breakdown: Vec<serde_json::Value> = c_rows
|
||||
.into_iter()
|
||||
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"models": model_breakdown,
|
||||
"clients": client_breakdown
|
||||
})))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
|
||||
}
|
||||
}
|
||||
@@ -1,290 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState, auth};
|
||||
|
||||
// ── User management endpoints (admin-only) ──────────────────────────
|
||||
|
||||
pub(super) async fn handle_get_users(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let users: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let username: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &username,
|
||||
"display_name": display_name.as_deref().unwrap_or(&username),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(users)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch users: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch users".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct CreateUserRequest {
|
||||
pub(super) username: String,
|
||||
pub(super) password: String,
|
||||
pub(super) display_name: Option<String>,
|
||||
pub(super) role: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_create_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate role
|
||||
let role = payload.role.as_deref().unwrap_or("viewer");
|
||||
if role != "admin" && role != "viewer" {
|
||||
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
|
||||
}
|
||||
|
||||
// Validate username
|
||||
let username = payload.username.trim();
|
||||
if username.is_empty() || username.len() > 64 {
|
||||
return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
|
||||
}
|
||||
|
||||
// Validate password
|
||||
if payload.password.len() < 4 {
|
||||
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
|
||||
}
|
||||
|
||||
let password_hash = match bcrypt::hash(&payload.password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
|
||||
};
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO users (username, password_hash, display_name, role, must_change_password)
|
||||
VALUES (?, ?, ?, ?, TRUE)
|
||||
RETURNING id, username, display_name, role, must_change_password, created_at
|
||||
"#,
|
||||
)
|
||||
.bind(username)
|
||||
.bind(&password_hash)
|
||||
.bind(&payload.display_name)
|
||||
.bind(role)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(row) => {
|
||||
let uname: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &uname,
|
||||
"display_name": display_name.as_deref().unwrap_or(&uname),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = if e.to_string().contains("UNIQUE") {
|
||||
format!("Username '{}' already exists", username)
|
||||
} else {
|
||||
format!("Failed to create user: {}", e)
|
||||
};
|
||||
warn!("Failed to create user: {}", e);
|
||||
Json(ApiResponse::error(msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateUserRequest {
|
||||
pub(super) display_name: Option<String>,
|
||||
pub(super) role: Option<String>,
|
||||
pub(super) password: Option<String>,
|
||||
pub(super) must_change_password: Option<bool>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
Json(payload): Json<UpdateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate role if provided
|
||||
if let Some(ref role) = payload.role {
|
||||
if role != "admin" && role != "viewer" {
|
||||
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// Build dynamic update
|
||||
let mut sets = Vec::new();
|
||||
let mut string_binds: Vec<String> = Vec::new();
|
||||
let mut has_password = false;
|
||||
let mut has_must_change = false;
|
||||
|
||||
if let Some(ref display_name) = payload.display_name {
|
||||
sets.push("display_name = ?");
|
||||
string_binds.push(display_name.clone());
|
||||
}
|
||||
if let Some(ref role) = payload.role {
|
||||
sets.push("role = ?");
|
||||
string_binds.push(role.clone());
|
||||
}
|
||||
if let Some(ref password) = payload.password {
|
||||
if password.len() < 4 {
|
||||
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
|
||||
}
|
||||
let hash = match bcrypt::hash(password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
|
||||
};
|
||||
sets.push("password_hash = ?");
|
||||
string_binds.push(hash);
|
||||
has_password = true;
|
||||
}
|
||||
if let Some(mcp) = payload.must_change_password {
|
||||
sets.push("must_change_password = ?");
|
||||
has_must_change = true;
|
||||
let _ = mcp; // used below in bind
|
||||
}
|
||||
|
||||
if sets.is_empty() {
|
||||
return Json(ApiResponse::error("No fields to update".to_string()));
|
||||
}
|
||||
|
||||
let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
|
||||
let mut query = sqlx::query(&sql);
|
||||
|
||||
for b in &string_binds {
|
||||
query = query.bind(b);
|
||||
}
|
||||
if has_must_change {
|
||||
query = query.bind(payload.must_change_password.unwrap());
|
||||
}
|
||||
let _ = has_password; // consumed above via string_binds
|
||||
query = query.bind(id);
|
||||
|
||||
match query.execute(pool).await {
|
||||
Ok(result) => {
|
||||
if result.rows_affected() == 0 {
|
||||
return Json(ApiResponse::error("User not found".to_string()));
|
||||
}
|
||||
// Fetch updated user
|
||||
let row = sqlx::query(
|
||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => {
|
||||
let uname: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &uname,
|
||||
"display_name": display_name.as_deref().unwrap_or(&uname),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})))
|
||||
}
|
||||
Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update user: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update user: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_delete_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Don't allow deleting yourself
|
||||
let target_username: Option<String> =
|
||||
sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
match target_username {
|
||||
None => return Json(ApiResponse::error("User not found".to_string())),
|
||||
Some(ref uname) if uname == &session.username => {
|
||||
return Json(ApiResponse::error("Cannot delete your own account".to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM users WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
|
||||
Err(e) => {
|
||||
warn!("Failed to delete user: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
use axum::{
|
||||
extract::{
|
||||
State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde_json;
|
||||
use tracing::info;
|
||||
|
||||
use super::DashboardState;
|
||||
|
||||
// WebSocket handler
|
||||
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
|
||||
info!("WebSocket connection established");
|
||||
|
||||
// Subscribe to events from the global bus
|
||||
let mut rx = state.app_state.dashboard_tx.subscribe();
|
||||
|
||||
// Send initial connection message
|
||||
let _ = socket
|
||||
.send(Message::Text(
|
||||
serde_json::json!({
|
||||
"type": "connected",
|
||||
"message": "Connected to LLM Proxy Dashboard"
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await;
|
||||
|
||||
// Handle incoming messages and broadcast events
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive broadcast events
|
||||
Ok(event) = rx.recv() => {
|
||||
let Ok(json_str) = serde_json::to_string(&event) else {
|
||||
continue;
|
||||
};
|
||||
let message = Message::Text(json_str.into());
|
||||
if socket.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Receive WebSocket messages
|
||||
result = socket.recv() => {
|
||||
match result {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
handle_websocket_message(&text, &state).await;
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket connection closed");
|
||||
}
|
||||
|
||||
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
|
||||
// Parse and handle WebSocket messages
|
||||
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
|
||||
&& data.get("type").and_then(|v| v.as_str()) == Some("ping")
|
||||
{
|
||||
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
|
||||
"type": "pong",
|
||||
"payload": {}
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
|
||||
use std::str::FromStr;
|
||||
use tracing::info;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
|
||||
/// Connection pool type used throughout the crate (a SQLite pool).
pub type DbPool = SqlitePool;
|
||||
|
||||
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
||||
// Ensure the database directory exists
|
||||
if let Some(parent) = config.path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
let database_path = config.path.to_string_lossy().to_string();
|
||||
info!("Connecting to database at {}", database_path);
|
||||
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
||||
.create_if_missing(true)
|
||||
.pragma("foreign_keys", "ON");
|
||||
|
||||
let pool = SqlitePool::connect_with(options).await?;
|
||||
|
||||
// Run migrations
|
||||
run_migrations(&pool).await?;
|
||||
info!("Database migrations completed");
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
// Create clients table if it doesn't exist
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT UNIQUE NOT NULL,
|
||||
name TEXT,
|
||||
description TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
rate_limit_per_minute INTEGER DEFAULT 60,
|
||||
total_requests INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
total_cost REAL DEFAULT 0.0
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create llm_requests table if it doesn't exist
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS llm_requests (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
client_id TEXT,
|
||||
provider TEXT,
|
||||
model TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER,
|
||||
cost REAL,
|
||||
has_images BOOLEAN DEFAULT FALSE,
|
||||
status TEXT DEFAULT 'success',
|
||||
error_message TEXT,
|
||||
duration_ms INTEGER,
|
||||
request_body TEXT,
|
||||
response_body TEXT,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create provider_configs table
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
display_name TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
base_url TEXT,
|
||||
api_key TEXT,
|
||||
credit_balance REAL DEFAULT 0.0,
|
||||
low_credit_threshold REAL DEFAULT 5.0,
|
||||
billing_mode TEXT,
|
||||
api_key_encrypted BOOLEAN DEFAULT FALSE,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create model_configs table
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS model_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
prompt_cost_per_m REAL,
|
||||
completion_cost_per_m REAL,
|
||||
cache_read_cost_per_m REAL,
|
||||
cache_write_cost_per_m REAL,
|
||||
mapping TEXT,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create users table for dashboard access
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
role TEXT DEFAULT 'admin',
|
||||
must_change_password BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create client_tokens table for DB-based token auth
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS client_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
name TEXT DEFAULT 'default',
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at DATETIME,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Add must_change_password column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add display_name column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add cache token columns if they don't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN reasoning_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add billing_mode column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
|
||||
.execute(pool)
|
||||
.await;
|
||||
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add manual cache cost columns to model_configs if they don't exist
|
||||
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_read_cost_per_m REAL")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_write_cost_per_m REAL")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Insert default admin user if none exists (default password: admin)
|
||||
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
|
||||
|
||||
if user_count.0 == 0 {
|
||||
// 'admin' hashed with default cost (12)
|
||||
let default_admin_hash =
|
||||
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
|
||||
sqlx::query(
|
||||
"INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
|
||||
)
|
||||
.bind(default_admin_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
info!("Created default admin user with password 'admin' (must change on first login)");
|
||||
}
|
||||
|
||||
// Create indices
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Composite indexes for performance
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Insert default client if none exists
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT OR IGNORE INTO clients (client_id, name, description)
|
||||
VALUES ('default', 'Default Client', 'Default client for anonymous requests')
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn test_connection(pool: &DbPool) -> Result<()> {
|
||||
sqlx::query("SELECT 1").execute(pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum AppError {
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthError(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
DatabaseError(String),
|
||||
|
||||
#[error("Provider error: {0}")]
|
||||
ProviderError(String),
|
||||
|
||||
#[error("Validation error: {0}")]
|
||||
ValidationError(String),
|
||||
|
||||
#[error("Multimodal processing error: {0}")]
|
||||
MultimodalError(String),
|
||||
|
||||
#[error("Rate limit exceeded: {0}")]
|
||||
RateLimitError(String),
|
||||
|
||||
#[error("Internal server error: {0}")]
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for AppError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
AppError::DatabaseError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for AppError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
AppError::InternalError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl axum::response::IntoResponse for AppError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let status = match self {
|
||||
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
|
||||
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
|
||||
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
|
||||
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
let body = axum::Json(serde_json::json!({
|
||||
"error": self.to_string(),
|
||||
"type": format!("{:?}", self)
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
-328
@@ -1,328 +0,0 @@
|
||||
//! LLM Proxy Library
|
||||
//!
|
||||
//! This library provides the core functionality for the LLM proxy gateway,
|
||||
//! including provider integration, token tracking, and API endpoints.
|
||||
|
||||
pub mod auth;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod dashboard;
|
||||
pub mod database;
|
||||
pub mod errors;
|
||||
pub mod logging;
|
||||
pub mod models;
|
||||
pub mod multimodal;
|
||||
pub mod providers;
|
||||
pub mod rate_limiting;
|
||||
pub mod server;
|
||||
pub mod state;
|
||||
pub mod utils;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use auth::{AuthenticatedClient, validate_token};
|
||||
pub use config::{
|
||||
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
|
||||
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
|
||||
};
|
||||
pub use database::{DbPool, init as init_db, test_connection};
|
||||
pub use errors::AppError;
|
||||
pub use logging::{LoggingContext, RequestLog, RequestLogger};
|
||||
pub use models::{
|
||||
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
||||
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
|
||||
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
|
||||
};
|
||||
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
|
||||
pub use server::router;
|
||||
pub use state::AppState;
|
||||
|
||||
/// Test utilities for integration testing.
///
/// Only compiled for test builds. Provides an `AppState` backed by an in-memory
/// SQLite database and a plain `reqwest` client.
#[cfg(test)]
pub mod test_utils {
    use std::sync::Arc;

    use crate::{
        database::run_migrations, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState,
        utils::crypto,
    };
    use sqlx::sqlite::SqlitePool;

    /// Create a test application state.
    ///
    /// Connects to `sqlite::memory:`, runs the migrations, builds a configuration with
    /// every provider enabled (empty base URLs and default models), initializes the
    /// crypto module with a fixed test key and assembles the `AppState`.
    ///
    /// # Panics
    ///
    /// Panics if the database cannot be opened, the migrations fail, or the crypto
    /// module rejects the test key.
    pub async fn create_test_state() -> AppState {
        // Create in-memory database
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("Failed to create test database");

        // Run migrations on the pool
        run_migrations(&pool).await.expect("Failed to run migrations");

        let rate_limit_manager = RateLimitManager::new(
            crate::rate_limiting::RateLimiterConfig::default(),
            crate::rate_limiting::CircuitBreakerConfig::default(),
        );

        // Create provider manager (providers are initialized later by the individual tests)
        let provider_manager = ProviderManager::new();

        // Empty registry: tests do not depend on remote model metadata
        let model_registry = crate::models::registry::ModelRegistry {
            providers: std::collections::HashMap::new(),
        };

        let config = Arc::new(crate::config::AppConfig {
            server: crate::config::ServerConfig {
                port: 8080,
                host: "127.0.0.1".to_string(),
                auth_tokens: vec![],
            },
            database: crate::config::DatabaseConfig {
                path: std::path::PathBuf::from(":memory:"),
                max_connections: 5,
            },
            providers: crate::config::ProviderConfig {
                openai: crate::config::OpenAIConfig {
                    api_key_env: "OPENAI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                gemini: crate::config::GeminiConfig {
                    api_key_env: "GEMINI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                deepseek: crate::config::DeepSeekConfig {
                    api_key_env: "DEEPSEEK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                grok: crate::config::GrokConfig {
                    api_key_env: "GROK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                ollama: crate::config::OllamaConfig {
                    base_url: "".to_string(),
                    enabled: true,
                    models: vec![],
                },
            },
            model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
            pricing: crate::config::PricingConfig {
                openai: vec![],
                gemini: vec![],
                deepseek: vec![],
                grok: vec![],
                ollama: vec![],
            },
            config_path: None,
            encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
        });

        // Initialize encryption with the test key
        crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");

        AppState::new(
            config,
            provider_manager,
            pool,
            rate_limit_manager,
            model_registry,
            vec![], // auth_tokens
        )
    }

    /// Create a test HTTP client with a 30 second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if the client cannot be built.
    pub fn create_test_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .expect("Failed to create test HTTP client")
    }
}
|
||||
|
||||
#[cfg(test)]
mod integration_tests {
    use super::test_utils::*;
    use crate::{
        models::{ChatCompletionRequest, ChatMessage},
        server::router,
        utils::crypto,
    };
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use mockito::Server;
    use serde_json::json;
    use sqlx::Row;
    use tower::util::ServiceExt;

    /// End-to-end check that a provider API key stored encrypted in the database is
    /// decrypted and sent upstream, and that the resulting usage is logged.
    #[tokio::test]
    async fn test_encrypted_provider_key_integration() {
        // Step 1: Setup test database and state
        let state = create_test_state().await;
        let pool = state.db_pool.clone();

        // Step 2: Insert provider with encrypted API key
        let test_api_key = "test-openai-key-12345";
        let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");

        sqlx::query(
            r#"
            INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind("openai")
        .bind("OpenAI")
        .bind(true)
        .bind("http://localhost:1234") // Placeholder; replaced with the mock server URL below
        .bind(&encrypted_key)
        .bind(true) // api_key_encrypted flag
        .bind(100.0)
        .bind(5.0)
        .execute(&pool)
        .await
        .expect("Failed to insert provider config");

        // Step 3: Initialize the provider from the stored (encrypted) configuration
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to initialize provider");

        // Step 4: Mock OpenAI API server
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            // The mock only matches when the decrypted key is sent upstream
            .match_header("authorization", format!("Bearer {}", test_api_key).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-test",
                    "object": "chat.completion",
                    "created": 1234567890,
                    "model": "gpt-3.5-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello, world!"
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;

        // Update provider base URL to use mock server
        sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
            .bind(server.url())
            .execute(&pool)
            .await
            .expect("Failed to update provider URL");

        // Re-initialize provider with new URL
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to re-initialize provider");

        // Step 5: Create test router and make request
        let app = router(state);

        let request_body = ChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: crate::models::MessageContent::Text {
                    content: "Hello".to_string(),
                },
                reasoning_content: None,
                tool_calls: None,
                name: None,
                tool_call_id: None,
            }],
            temperature: None,
            top_p: None,
            top_k: None,
            n: None,
            stop: None,
            max_tokens: Some(100),
            presence_penalty: None,
            frequency_penalty: None,
            stream: Some(false),
            tools: None,
            tool_choice: None,
        };

        let request = Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .header("authorization", "Bearer test-token")
            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
            .unwrap();

        // Step 6: Execute request through proxy
        let response = app.oneshot(request).await.expect("Failed to execute request");

        let status = response.status();
        println!("Response status: {}", status);

        if status != StatusCode::OK {
            // Put the body into the panic message so the failure reason shows up in the test report
            let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
            let body_str = String::from_utf8_lossy(&body_bytes);
            panic!("Response status is not OK: {} (body: {})", status, body_str);
        }

        // Verify the mock was called
        mock.assert_async().await;

        // Step 7: Verify usage was logged in database.
        // Logging runs in a spawned task, so poll for the row (up to ~2s) instead of
        // relying on one fixed sleep. The log row and the client counters are written
        // in the same transaction, so once the row is visible both can be checked.
        let mut log_row = None;
        for _ in 0..40 {
            log_row = sqlx::query(
                "SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1",
            )
            .fetch_optional(&pool)
            .await
            .expect("Failed to query request log");

            if log_row.is_some() {
                break;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        }
        let log_row = log_row.expect("Request log not found");

        assert_eq!(log_row.get::<String, _>("provider"), "openai");
        assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
        assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
        assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
        assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
        assert_eq!(log_row.get::<String, _>("status"), "success");

        // Verify client usage was updated
        let client_row = sqlx::query(
            "SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'",
        )
        .fetch_one(&pool)
        .await
        .expect("Client not found");

        assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
        assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
    }
}
|
||||
@@ -1,242 +0,0 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use sqlx::SqlitePool;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// Request log entry for database storage.
///
/// One entry is inserted into the `llm_requests` table per logged request and is
/// also serialized as the `payload` of the dashboard broadcast message.
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
    /// Timestamp stored with the entry (UTC).
    pub timestamp: DateTime<Utc>,
    /// Client the request is attributed to; a `clients` row is created for it if missing.
    pub client_id: String,
    /// Provider identifier; matched against `provider_configs.id` when deducting credit.
    pub provider: String,
    /// Model name stored with the entry.
    pub model: String,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub reasoning_tokens: u32,
    /// Added to the client's running `total_tokens` counter.
    pub total_tokens: u32,
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
    /// Added to the client's `total_cost`; a positive value is also deducted from the
    /// provider's `credit_balance` unless the provider is postpaid.
    pub cost: f64,
    pub has_images: bool,
    pub status: String, // "success", "error"
    pub error_message: Option<String>,
    pub duration_ms: u64,
}
|
||||
|
||||
/// Database operations for request logging.
///
/// Holds the SQLite pool the entries are written to and the broadcast sender used
/// to push each entry to dashboard listeners.
pub struct RequestLogger {
    // Pool used for the logging transaction.
    db_pool: SqlitePool,
    // Sender for JSON dashboard events.
    dashboard_tx: broadcast::Sender<serde_json::Value>,
}
|
||||
|
||||
impl RequestLogger {
|
||||
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
|
||||
Self { db_pool, dashboard_tx }
|
||||
}
|
||||
|
||||
/// Log a request to the database (async, spawns a task)
|
||||
pub fn log_request(&self, log: RequestLog) {
|
||||
let pool = self.db_pool.clone();
|
||||
let tx = self.dashboard_tx.clone();
|
||||
|
||||
// Spawn async task to log without blocking response
|
||||
tokio::spawn(async move {
|
||||
// Broadcast to dashboard
|
||||
let broadcast_result = tx.send(serde_json::json!({
|
||||
"type": "request",
|
||||
"channel": "requests",
|
||||
"payload": log
|
||||
}));
|
||||
match broadcast_result {
|
||||
Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers),
|
||||
Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open
|
||||
}
|
||||
|
||||
match Self::insert_log(&pool, log).await {
|
||||
Ok(()) => info!("Request logged to database successfully"),
|
||||
Err(e) => warn!("Failed to log request to database: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Insert a log entry into the database
|
||||
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// Ensure the client row exists (FK constraint requires it)
|
||||
sqlx::query(
|
||||
"INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
|
||||
)
|
||||
.bind(&log.client_id)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO llm_requests
|
||||
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(log.timestamp)
|
||||
.bind(&log.client_id)
|
||||
.bind(&log.provider)
|
||||
.bind(&log.model)
|
||||
.bind(log.prompt_tokens as i64)
|
||||
.bind(log.completion_tokens as i64)
|
||||
.bind(log.reasoning_tokens as i64)
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cache_read_tokens as i64)
|
||||
.bind(log.cache_write_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(log.has_images)
|
||||
.bind(&log.status)
|
||||
.bind(log.error_message)
|
||||
.bind(log.duration_ms as i64)
|
||||
.bind(None::<String>) // request_body - optional, not stored to save disk space
|
||||
.bind(None::<String>) // response_body - optional, not stored to save disk space
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Update client usage statistics
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE clients SET
|
||||
total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Deduct from provider balance if successful.
|
||||
// Providers configured with billing_mode = 'postpaid' will not have their
|
||||
// credit_balance decremented. Use a conditional UPDATE so we don't need
|
||||
// a prior SELECT and avoid extra round-trips.
|
||||
if log.cost > 0.0 {
|
||||
sqlx::query(
|
||||
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
|
||||
)
|
||||
.bind(log.cost)
|
||||
.bind(&log.provider)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// /// Middleware to log LLM API requests
|
||||
// /// TODO: Implement proper middleware that can extract response body details
|
||||
// pub async fn request_logging_middleware(
|
||||
// // Extract the authenticated client (if available)
|
||||
// auth_result: Result<AuthenticatedClient, AppError>,
|
||||
// request: Request,
|
||||
// next: Next,
|
||||
// ) -> Response {
|
||||
// let start_time = std::time::Instant::now();
|
||||
//
|
||||
// // Extract client_id from auth or use "unknown"
|
||||
// let client_id = match auth_result {
|
||||
// Ok(auth) => auth.client_id,
|
||||
// Err(_) => "unknown".to_string(),
|
||||
// };
|
||||
//
|
||||
// // Try to extract request details
|
||||
// let (request_parts, request_body) = request.into_parts();
|
||||
//
|
||||
// // Clone request parts for logging
|
||||
// let path = request_parts.uri.path().to_string();
|
||||
//
|
||||
// // Check if this is a chat completion request
|
||||
// let is_chat_completion = path == "/v1/chat/completions";
|
||||
//
|
||||
// // Reconstruct request for downstream handlers
|
||||
// let request = Request::from_parts(request_parts, request_body);
|
||||
//
|
||||
// // Process request and get response
|
||||
// let response = next.run(request).await;
|
||||
//
|
||||
// // Calculate duration
|
||||
// let duration = start_time.elapsed();
|
||||
// let duration_ms = duration.as_millis() as u64;
|
||||
//
|
||||
// // Log basic request info
|
||||
// info!(
|
||||
// "Request from {} to {} - Status: {} - Duration: {}ms",
|
||||
// client_id,
|
||||
// path,
|
||||
// response.status().as_u16(),
|
||||
// duration_ms
|
||||
// );
|
||||
//
|
||||
// // TODO: Extract more details from request/response for logging
|
||||
// // For now, we'll need to modify the server handler to pass additional context
|
||||
//
|
||||
// response
|
||||
// }
|
||||
|
||||
/// Context for request logging that can be passed through extensions.
///
/// Built with `LoggingContext::new` and then refined through the `with_*` builder
/// methods. All counters start at zero and `error` starts as `None`.
#[derive(Clone)]
pub struct LoggingContext {
    pub client_id: String,
    pub provider_name: String,
    pub model: String,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    /// Set by `with_token_counts` from the prompt and completion counts.
    pub total_tokens: u32,
    pub cost: f64,
    pub has_images: bool,
    /// Error attached to the request, if any.
    pub error: Option<AppError>,
}
|
||||
|
||||
impl LoggingContext {
|
||||
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
provider_name,
|
||||
model,
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
cost: 0.0,
|
||||
has_images: false,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
|
||||
self.prompt_tokens = prompt_tokens;
|
||||
self.completion_tokens = completion_tokens;
|
||||
self.total_tokens = prompt_tokens + completion_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cost(mut self, cost: f64) -> Self {
|
||||
self.cost = cost;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_images(mut self, has_images: bool) -> Self {
|
||||
self.has_images = has_images;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_error(mut self, error: AppError) -> Self {
|
||||
self.error = Some(error);
|
||||
self
|
||||
}
|
||||
}
|
||||
-96
@@ -1,96 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use axum::{Router, routing::get};
|
||||
use std::net::SocketAddr;
|
||||
use tracing::{error, info};
|
||||
|
||||
use llm_proxy::{
|
||||
config::AppConfig,
|
||||
dashboard, database,
|
||||
providers::ProviderManager,
|
||||
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
|
||||
server,
|
||||
state::AppState,
|
||||
utils::crypto,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing (logging)
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
// Load configuration
|
||||
let config = AppConfig::load().await?;
|
||||
info!("Configuration loaded from {:?}", config.config_path);
|
||||
|
||||
// Initialize encryption
|
||||
crypto::init_with_key(&config.encryption_key)?;
|
||||
info!("Encryption initialized");
|
||||
|
||||
// Initialize database connection pool
|
||||
let db_pool = database::init(&config.database).await?;
|
||||
info!("Database initialized at {:?}", config.database.path);
|
||||
|
||||
// Initialize provider manager with configured providers
|
||||
let provider_manager = ProviderManager::new();
|
||||
|
||||
// Initialize all supported providers (they handle their own enabled check)
|
||||
let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||
for name in supported_providers {
|
||||
if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
|
||||
error!("Failed to initialize provider {}: {}", name, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create rate limit manager
|
||||
let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());
|
||||
|
||||
// Fetch model registry from models.dev
|
||||
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
|
||||
Ok(registry) => registry,
|
||||
Err(e) => {
|
||||
error!("Failed to fetch model registry: {}. Using empty registry.", e);
|
||||
llm_proxy::models::registry::ModelRegistry {
|
||||
providers: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create application state
|
||||
let state = AppState::new(
|
||||
config.clone(),
|
||||
provider_manager,
|
||||
db_pool,
|
||||
rate_limit_manager,
|
||||
model_registry,
|
||||
config.server.auth_tokens.clone(),
|
||||
);
|
||||
|
||||
// Initialize model config cache and start background refresh (every 30s)
|
||||
state.model_config_cache.refresh().await;
|
||||
state.model_config_cache.clone().start_refresh_task(30);
|
||||
info!("Model config cache initialized");
|
||||
|
||||
// Create application router
|
||||
let app = Router::new()
|
||||
.route("/health", get(health_check))
|
||||
.merge(server::router(state.clone()))
|
||||
.merge(dashboard::router(state.clone()));
|
||||
|
||||
// Start server
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
|
||||
info!("Server listening on http://{}", addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Liveness probe handler for `GET /health`; always answers with the literal `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
|
||||
@@ -1,381 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
pub mod registry;
|
||||
|
||||
// ========== OpenAI-compatible Request/Response Structs ==========
|
||||
|
||||
/// OpenAI-compatible chat completion request body.
///
/// Every field except `model` and `messages` is optional and defaults to `None`
/// when absent from the incoming JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub temperature: Option<f64>,
    #[serde(default)]
    pub top_p: Option<f64>,
    #[serde(default)]
    pub top_k: Option<u32>,
    #[serde(default)]
    pub n: Option<u32>,
    /// Stop sequence(s): either a single string or an array of strings.
    #[serde(default)]
    pub stop: Option<Value>, // Can be string or array of strings
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub presence_penalty: Option<f64>,
    #[serde(default)]
    pub frequency_penalty: Option<f64>,
    #[serde(default)]
    pub stream: Option<bool>,
    /// Tool definitions; omitted from the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Tool selection policy; omitted from the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
|
||||
|
||||
/// A single chat message in OpenAI-compatible form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String, // "system", "user", "assistant", "tool"
    /// Message body. Flattened, so the `content` key lives directly on the message object.
    // NOTE(review): confirm that `flatten` combined with the untagged `MessageContent::None`
    // variant round-trips when `content` is null or absent.
    #[serde(flatten)]
    pub content: MessageContent,
    /// Reasoning text; also accepted under the keys `reasoning` and `thought` on input.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
|
||||
|
||||
/// Content of a chat message. Untagged: the variant is chosen by the shape of the
/// `content` value, tried in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// `content` is a plain string.
    Text { content: String },
    /// `content` is an array of typed parts (text and/or image URLs).
    Parts { content: Vec<ContentPartValue> },
    None, // Handle cases where content might be null but reasoning is present
}
|
||||
|
||||
/// One element of a multi-part message body, discriminated by a `type` field whose
/// value is the snake_case variant name (`text` or `image_url`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}
|
||||
|
||||
/// Image reference carried by an `image_url` content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    pub url: String,
    /// Optional `detail` value; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
|
||||
|
||||
// ========== Tool-Calling Types ==========
|
||||
|
||||
/// A tool definition offered to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionDef,
}
|
||||
|
||||
/// Description of a callable function inside a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Arbitrary JSON describing the parameters; kept as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
|
||||
|
||||
/// Tool selection policy: either a mode string or a specific function. Untagged,
/// so a JSON string maps to `Mode` and an object to `Specific`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Mode(String), // "auto", "none", "required"
    Specific(ToolChoiceSpecific),
}
|
||||
|
||||
/// Object form of [`ToolChoice`] naming one function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub choice_type: String,
    pub function: ToolChoiceFunction,
}
|
||||
|
||||
/// Function reference (by name) used in [`ToolChoiceSpecific`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}
|
||||
|
||||
/// A complete tool call attached to a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}
|
||||
|
||||
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments carried as a string; not parsed by this type.
    pub arguments: String,
}
|
||||
|
||||
/// Incremental tool-call fragment used in streaming deltas. Everything except
/// `index` is optional and omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Serialized as `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
|
||||
|
||||
/// Incremental function-call fragment inside a [`ToolCallDelta`]; both fields are
/// optional and omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
|
||||
|
||||
// ========== OpenAI-compatible Response Structs ==========
|
||||
|
||||
/// OpenAI-compatible (non-streaming) chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    /// Token usage, when reported.
    pub usage: Option<Usage>,
}
|
||||
|
||||
/// One completion choice of a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}
|
||||
|
||||
/// Token usage counters. The three optional counters are omitted from the
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_tokens: Option<u32>,
}
|
||||
|
||||
// ========== Streaming Response Structs ==========
|
||||
|
||||
/// One chunk of an OpenAI-compatible streaming chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatStreamChoice>,
    /// Token usage, when present on the chunk; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
|
||||
|
||||
/// One choice inside a streaming chunk; carries a delta rather than a full message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatStreamDelta,
    pub finish_reason: Option<String>,
}
|
||||
|
||||
/// Incremental message update carried by a [`ChatStreamChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    /// Reasoning text; also accepted under the keys `reasoning` and `thought` on input.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
|
||||
|
||||
// ========== Unified Request Format (for internal use) ==========
|
||||
|
||||
/// Provider-independent request used internally, produced from a
/// `ChatCompletionRequest` via `TryFrom`.
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
    /// Empty after conversion; expected to be filled in afterwards (the conversion
    /// notes it is populated by the auth middleware).
    pub client_id: String,
    pub model: String,
    pub messages: Vec<UnifiedMessage>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<u32>,
    pub n: Option<u32>,
    /// Stop sequences normalized to a list (a single string becomes a one-element list).
    pub stop: Option<Vec<String>>,
    pub max_tokens: Option<u32>,
    pub presence_penalty: Option<f64>,
    pub frequency_penalty: Option<f64>,
    /// `false` when the incoming request did not specify `stream`.
    pub stream: bool,
    /// `true` when at least one message contains an image part.
    pub has_images: bool,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
}
|
||||
|
||||
/// Provider-independent chat message; the body is always a list of parts.
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
    pub role: String,
    /// Message body; empty when the source message had no content.
    pub content: Vec<ContentPart>,
    pub reasoning_content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub name: Option<String>,
    pub tool_call_id: Option<String>,
}
|
||||
|
||||
/// One part of a [`UnifiedMessage`] body: text or an image input.
#[derive(Debug, Clone)]
pub enum ContentPart {
    Text { text: String },
    Image(crate::multimodal::ImageInput),
}
|
||||
|
||||
// ========== Provider-specific Structs ==========
|
||||
|
||||
/// Serialize-only request shape for the OpenAI provider.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub stream: Option<bool>,
}
|
||||
|
||||
/// Serialize-only message for [`OpenAIRequest`]; content is always a list of parts.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: Vec<OpenAIContentPart>,
}
|
||||
|
||||
/// Serialize-only content part, discriminated by a `type` field whose value is the
/// snake_case variant name (`text` or `image_url`).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}
|
||||
|
||||
// Note: ImageUrl struct is defined earlier in the file
|
||||
|
||||
// ========== Conversion Traits ==========
|
||||
|
||||
/// Conversion of a value into an [`OpenAIRequest`].
pub trait ToOpenAI {
    /// Build the OpenAI request, or return an error if the value cannot be converted.
    fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}
|
||||
|
||||
/// Construction of a value from an [`OpenAIRequest`].
pub trait FromOpenAI {
    /// Build `Self` from the request, or return an error if it cannot be converted.
    fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}
|
||||
|
||||
impl UnifiedRequest {
|
||||
/// Hydrate all image content by fetching URLs and converting to base64/bytes
|
||||
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
|
||||
if !self.has_images {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for msg in &mut self.messages {
|
||||
for part in &mut msg.content {
|
||||
if let ContentPart::Image(image_input) = part {
|
||||
// Pre-fetch and validate if it's a URL
|
||||
if let crate::multimodal::ImageInput::Url(_url) = image_input {
|
||||
let (base64_data, mime_type) = image_input.to_base64().await?;
|
||||
*image_input = crate::multimodal::ImageInput::Base64 {
|
||||
data: base64_data,
|
||||
mime_type,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
|
||||
let mut has_images = false;
|
||||
|
||||
// Convert OpenAI-compatible request to unified format
|
||||
let messages = req
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| {
|
||||
let (content, _images_in_message) = match msg.content {
|
||||
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
|
||||
MessageContent::Parts { content } => {
|
||||
let mut unified_content = Vec::new();
|
||||
let mut has_images_in_msg = false;
|
||||
|
||||
for part in content {
|
||||
match part {
|
||||
ContentPartValue::Text { text } => {
|
||||
unified_content.push(ContentPart::Text { text });
|
||||
}
|
||||
ContentPartValue::ImageUrl { image_url } => {
|
||||
has_images_in_msg = true;
|
||||
has_images = true;
|
||||
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
|
||||
image_url.url,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(unified_content, has_images_in_msg)
|
||||
}
|
||||
MessageContent::None => (vec![], false),
|
||||
};
|
||||
|
||||
UnifiedMessage {
|
||||
role: msg.role,
|
||||
content,
|
||||
reasoning_content: msg.reasoning_content,
|
||||
tool_calls: msg.tool_calls,
|
||||
name: msg.name,
|
||||
tool_call_id: msg.tool_call_id,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let stop = match req.stop {
|
||||
Some(Value::String(s)) => Some(vec![s]),
|
||||
Some(Value::Array(a)) => Some(
|
||||
a.into_iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect(),
|
||||
),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok(UnifiedRequest {
|
||||
client_id: String::new(), // Will be populated by auth middleware
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
n: req.n,
|
||||
stop,
|
||||
max_tokens: req.max_tokens,
|
||||
presence_penalty: req.presence_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
stream: req.stream.unwrap_or(false),
|
||||
has_images,
|
||||
tools: req.tools,
|
||||
tool_choice: req.tool_choice,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelRegistry {
|
||||
#[serde(flatten)]
|
||||
pub providers: HashMap<String, ProviderInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub models: HashMap<String, ModelMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub cost: Option<ModelCost>,
|
||||
pub limit: Option<ModelLimit>,
|
||||
pub modalities: Option<ModelModalities>,
|
||||
pub tool_call: Option<bool>,
|
||||
pub reasoning: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelCost {
|
||||
pub input: f64,
|
||||
pub output: f64,
|
||||
pub cache_read: Option<f64>,
|
||||
pub cache_write: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelLimit {
|
||||
pub context: u32,
|
||||
pub output: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelModalities {
|
||||
pub input: Vec<String>,
|
||||
pub output: Vec<String>,
|
||||
}
|
||||
|
||||
/// A model entry paired with its provider ID, returned by listing/filtering methods.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelEntry<'a> {
|
||||
pub model_key: &'a str,
|
||||
pub provider_id: &'a str,
|
||||
pub provider_name: &'a str,
|
||||
pub metadata: &'a ModelMetadata,
|
||||
}
|
||||
|
||||
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
|
||||
#[derive(Debug, Default, Clone, Deserialize)]
|
||||
pub struct ModelFilter {
|
||||
/// Filter by provider ID (exact match).
|
||||
pub provider: Option<String>,
|
||||
/// Text search on model ID or name (case-insensitive substring).
|
||||
pub search: Option<String>,
|
||||
/// Filter by input modality (e.g. "image", "text").
|
||||
pub modality: Option<String>,
|
||||
/// Only models that support tool calling.
|
||||
pub tool_call: Option<bool>,
|
||||
/// Only models that support reasoning.
|
||||
pub reasoning: Option<bool>,
|
||||
/// Only models that have pricing data.
|
||||
pub has_cost: Option<bool>,
|
||||
}
|
||||
|
||||
/// Sort field for model listings.
|
||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ModelSortBy {
|
||||
#[default]
|
||||
Name,
|
||||
Id,
|
||||
Provider,
|
||||
ContextLimit,
|
||||
InputCost,
|
||||
OutputCost,
|
||||
}
|
||||
|
||||
/// Sort direction.
|
||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SortOrder {
|
||||
#[default]
|
||||
Asc,
|
||||
Desc,
|
||||
}
|
||||
|
||||
impl ModelRegistry {
|
||||
/// Find a model by its ID (searching across all providers)
|
||||
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
||||
// First try exact match if the key in models map matches the ID
|
||||
for provider in self.providers.values() {
|
||||
if let Some(model) = provider.models.get(model_id) {
|
||||
return Some(model);
|
||||
}
|
||||
}
|
||||
|
||||
// Try searching for the model ID inside the metadata if the key was different
|
||||
for provider in self.providers.values() {
|
||||
for model in provider.models.values() {
|
||||
if model.id == model_id {
|
||||
return Some(model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// List all models with optional filtering and sorting.
|
||||
pub fn list_models(
|
||||
&self,
|
||||
filter: &ModelFilter,
|
||||
sort_by: &ModelSortBy,
|
||||
sort_order: &SortOrder,
|
||||
) -> Vec<ModelEntry<'_>> {
|
||||
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
|
||||
|
||||
for (p_id, p_info) in &self.providers {
|
||||
// Provider filter
|
||||
if let Some(ref prov) = filter.provider {
|
||||
if p_id != prov {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (m_key, m_meta) in &p_info.models {
|
||||
// Text search filter
|
||||
if let Some(ref search) = filter.search {
|
||||
let search_lower = search.to_lowercase();
|
||||
if !m_meta.id.to_lowercase().contains(&search_lower)
|
||||
&& !m_meta.name.to_lowercase().contains(&search_lower)
|
||||
&& !m_key.to_lowercase().contains(&search_lower)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Modality filter
|
||||
if let Some(ref modality) = filter.modality {
|
||||
let has_modality = m_meta
|
||||
.modalities
|
||||
.as_ref()
|
||||
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
|
||||
if !has_modality {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool call filter
|
||||
if let Some(tc) = filter.tool_call {
|
||||
if m_meta.tool_call.unwrap_or(false) != tc {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Reasoning filter
|
||||
if let Some(r) = filter.reasoning {
|
||||
if m_meta.reasoning.unwrap_or(false) != r {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Has cost filter
|
||||
if let Some(hc) = filter.has_cost {
|
||||
if hc != m_meta.cost.is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
entries.push(ModelEntry {
|
||||
model_key: m_key,
|
||||
provider_id: p_id,
|
||||
provider_name: &p_info.name,
|
||||
metadata: m_meta,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort
|
||||
entries.sort_by(|a, b| {
|
||||
let cmp = match sort_by {
|
||||
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
|
||||
ModelSortBy::Id => a.model_key.cmp(b.model_key),
|
||||
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
|
||||
ModelSortBy::ContextLimit => {
|
||||
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||
a_ctx.cmp(&b_ctx)
|
||||
}
|
||||
ModelSortBy::InputCost => {
|
||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
ModelSortBy::OutputCost => {
|
||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
};
|
||||
|
||||
match sort_order {
|
||||
SortOrder::Asc => cmp,
|
||||
SortOrder::Desc => cmp.reverse(),
|
||||
}
|
||||
});
|
||||
|
||||
entries
|
||||
}
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
//! Multimodal support for image processing and conversion
|
||||
//!
|
||||
//! This module handles:
|
||||
//! 1. Image format detection and conversion
|
||||
//! 2. Base64 encoding/decoding
|
||||
//! 3. URL fetching for images
|
||||
//! 4. Provider-specific image format conversion
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use std::sync::LazyLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
|
||||
/// connection for every image URL.
|
||||
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(60))
|
||||
.build()
|
||||
.expect("Failed to build image HTTP client")
|
||||
});
|
||||
|
||||
/// An image supplied as multimodal input, in one of three representations.
#[derive(Debug, Clone)]
pub enum ImageInput {
    /// Base64-encoded image data together with its MIME type.
    Base64 { data: String, mime_type: String },
    /// Location the image must be fetched from; no MIME type is known up front.
    Url(String),
    /// Raw (undecoded) image bytes together with their MIME type.
    Bytes { data: Vec<u8>, mime_type: String },
}
|
||||
|
||||
impl ImageInput {
|
||||
/// Create ImageInput from base64 string
|
||||
pub fn from_base64(data: String, mime_type: String) -> Self {
|
||||
Self::Base64 { data, mime_type }
|
||||
}
|
||||
|
||||
/// Create ImageInput from URL
|
||||
pub fn from_url(url: String) -> Self {
|
||||
Self::Url(url)
|
||||
}
|
||||
|
||||
/// Create ImageInput from raw bytes
|
||||
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
|
||||
Self::Bytes { data, mime_type }
|
||||
}
|
||||
|
||||
/// Get MIME type if available
|
||||
pub fn mime_type(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Base64 { mime_type, .. } => Some(mime_type),
|
||||
Self::Bytes { mime_type, .. } => Some(mime_type),
|
||||
Self::Url(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to base64 if not already
|
||||
pub async fn to_base64(&self) -> Result<(String, String)> {
|
||||
match self {
|
||||
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
|
||||
Self::Bytes { data, mime_type } => {
|
||||
let base64_data = general_purpose::STANDARD.encode(data);
|
||||
Ok((base64_data, mime_type.clone()))
|
||||
}
|
||||
Self::Url(url) => {
|
||||
// Fetch image from URL using shared client
|
||||
info!("Fetching image from URL: {}", url);
|
||||
let response = IMAGE_CLIENT
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to fetch image from URL")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let mime_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("image/jpeg")
|
||||
.to_string();
|
||||
|
||||
let bytes = response.bytes().await.context("Failed to read image bytes")?;
|
||||
|
||||
let base64_data = general_purpose::STANDARD.encode(&bytes);
|
||||
Ok((base64_data, mime_type))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get image dimensions (width, height)
|
||||
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
|
||||
let bytes = match self {
|
||||
Self::Base64 { data, .. } => general_purpose::STANDARD
|
||||
.decode(data)
|
||||
.context("Failed to decode base64")?,
|
||||
Self::Bytes { data, .. } => data.clone(),
|
||||
Self::Url(_) => {
|
||||
let (base64_data, _) = self.to_base64().await?;
|
||||
general_purpose::STANDARD
|
||||
.decode(&base64_data)
|
||||
.context("Failed to decode base64")?
|
||||
}
|
||||
};
|
||||
|
||||
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
|
||||
Ok((img.width(), img.height()))
|
||||
}
|
||||
|
||||
/// Validate image size and format
|
||||
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
|
||||
let (width, height) = self.get_dimensions().await?;
|
||||
|
||||
// Check dimensions
|
||||
if width > 4096 || height > 4096 {
|
||||
warn!("Image dimensions too large: {}x{}", width, height);
|
||||
// Continue anyway, but log warning
|
||||
}
|
||||
|
||||
// Check file size
|
||||
let size_bytes = match self {
|
||||
Self::Base64 { data, .. } => {
|
||||
// Base64 size is ~4/3 of original
|
||||
(data.len() as f64 * 0.75) as usize
|
||||
}
|
||||
Self::Bytes { data, .. } => data.len(),
|
||||
Self::Url(_) => {
|
||||
// For URLs, we'd need to fetch to check size
|
||||
// Skip size check for URLs for now
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
|
||||
if size_mb > max_size_mb {
|
||||
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Namespace for provider-specific image format conversion.
///
/// A field-less unit type: all functionality lives in associated functions.
pub struct ImageConverter;
|
||||
|
||||
impl ImageConverter {
|
||||
/// Convert image to OpenAI-compatible format
|
||||
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
let (base64_data, mime_type) = image.to_base64().await?;
|
||||
|
||||
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
|
||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": data_url,
|
||||
"detail": "auto" // Can be "low", "high", or "auto"
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert image to Gemini-compatible format
|
||||
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
let (base64_data, mime_type) = image.to_base64().await?;
|
||||
|
||||
// Gemini expects inline data format
|
||||
Ok(serde_json::json!({
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert image to DeepSeek-compatible format
|
||||
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
// DeepSeek uses OpenAI-compatible format for vision models
|
||||
Self::to_openai_format(image).await
|
||||
}
|
||||
|
||||
/// Detect if a model supports multimodal input
|
||||
pub fn model_supports_multimodal(model: &str) -> bool {
|
||||
// OpenAI vision models
|
||||
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
|
||||
|| model.starts_with("o1-")
|
||||
|| model.starts_with("o3-")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Gemini vision models
|
||||
if model.starts_with("gemini") {
|
||||
// Most Gemini models support vision
|
||||
return true;
|
||||
}
|
||||
|
||||
// DeepSeek vision models
|
||||
if model.starts_with("deepseek-vl") {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse OpenAI-compatible multimodal message content
|
||||
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
if let Some(content_str) = content.as_str() {
|
||||
// Simple text content
|
||||
parts.push((content_str.to_string(), None));
|
||||
} else if let Some(content_array) = content.as_array() {
|
||||
// Array of content parts (text and/or images)
|
||||
for part in content_array {
|
||||
if let Some(part_obj) = part.as_object()
|
||||
&& let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
|
||||
{
|
||||
match part_type {
|
||||
"text" => {
|
||||
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
|
||||
parts.push((text.to_string(), None));
|
||||
}
|
||||
}
|
||||
"image_url" => {
|
||||
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
|
||||
&& let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
|
||||
{
|
||||
if url.starts_with("data:") {
|
||||
// Parse data URL
|
||||
if let Some((mime_type, data)) = parse_data_url(url) {
|
||||
let image_input = ImageInput::from_base64(data, mime_type);
|
||||
parts.push(("".to_string(), Some(image_input)));
|
||||
}
|
||||
} else {
|
||||
// Regular URL
|
||||
let image_input = ImageInput::from_url(url.to_string());
|
||||
parts.push(("".to_string(), Some(image_input)));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown content part type: {}", part_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parts)
|
||||
}
|
||||
|
||||
/// Parse data URL (data:image/jpeg;base64,{data}) into `(mime_type, data)`.
///
/// Returns `None` unless the input starts with `data:` and the remainder
/// contains the separator `;base64,` exactly once. Everything before the
/// separator is returned as the MIME type and everything after it as the
/// (still base64-encoded) payload; neither part is validated.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let payload = data_url.strip_prefix("data:")?;

    // Exactly two pieces means exactly one separator.
    let mut pieces = payload.split(";base64,");
    match (pieces.next(), pieces.next(), pieces.next()) {
        (Some(mime_type), Some(data), None) => Some((mime_type.to_string(), data.to_string())),
        _ => None,
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Both functions under test are synchronous, so plain `#[test]` is enough;
    // no async runtime is needed.

    /// A well-formed base64 data URL splits into MIME type and payload.
    #[test]
    fn test_parse_data_url() {
        let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
        let (mime_type, data) = parse_data_url(test_url).unwrap();

        assert_eq!(mime_type, "image/jpeg");
        assert_eq!(data, "SGVsbG8gV29ybGQ=");
    }

    /// Inputs without the `data:` prefix or the `;base64,` separator are rejected.
    #[test]
    fn test_parse_data_url_rejects_malformed_input() {
        assert!(parse_data_url("https://example.com/image.png").is_none());
        assert!(parse_data_url("data:image/png,not-base64").is_none());
    }

    /// Name-based multimodal detection for known vision and text-only models.
    #[test]
    fn test_model_supports_multimodal() {
        for model in ["gpt-4-vision-preview", "gpt-4o", "gemini-pro-vision", "gemini-pro"] {
            assert!(
                ImageConverter::model_supports_multimodal(model),
                "{model} should be detected as multimodal"
            );
        }

        for model in ["gpt-3.5-turbo", "claude-3-opus"] {
            assert!(
                !ImageConverter::model_supports_multimodal(model),
                "{model} should not be detected as multimodal"
            );
        }
    }
}
|
||||
@@ -1,268 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::StreamExt;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
pub struct DeepSeekProvider {
|
||||
client: reqwest::Client,
|
||||
config: crate::config::DeepSeekConfig,
|
||||
api_key: String,
|
||||
pricing: Vec<crate::config::ModelPricing>,
|
||||
}
|
||||
|
||||
impl DeepSeekProvider {
|
||||
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let api_key = app_config.get_api_key("deepseek")?;
|
||||
Self::new_with_key(config, app_config, api_key)
|
||||
}
|
||||
|
||||
pub fn new_with_key(
|
||||
config: &crate::config::DeepSeekConfig,
|
||||
app_config: &AppConfig,
|
||||
api_key: String,
|
||||
) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.deepseek.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Provider for DeepSeekProvider {
|
||||
fn name(&self) -> &str {
|
||||
"deepseek"
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
model.starts_with("deepseek-") || model.contains("deepseek")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
// Sanitize and fix for deepseek-reasoner (R1)
|
||||
if request.model == "deepseek-reasoner" {
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
// Remove unsupported parameters
|
||||
obj.remove("temperature");
|
||||
obj.remove("top_p");
|
||||
obj.remove("presence_penalty");
|
||||
obj.remove("frequency_penalty");
|
||||
obj.remove("logit_bias");
|
||||
obj.remove("logprobs");
|
||||
obj.remove("top_logprobs");
|
||||
|
||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
||||
for m in messages {
|
||||
if m["role"].as_str() == Some("assistant") {
|
||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
||||
m["reasoning_content"] = serde_json::json!(" ");
|
||||
}
|
||||
// DeepSeek R1 often requires content to be a string, not null/array
|
||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
||||
m["content"] = serde_json::json!("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
|
||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
|
||||
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
helpers::parse_openai_response(&resp_json, request.model)
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
if let Some(metadata) = registry.find_model(model) {
|
||||
if metadata.cost.is_some() {
|
||||
return helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
registry,
|
||||
&self.pricing,
|
||||
0.28,
|
||||
0.42,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Custom DeepSeek fallback that correctly handles cache hits
|
||||
let (prompt_rate, completion_rate) = self
|
||||
.pricing
|
||||
.iter()
|
||||
.find(|p| model.contains(&p.model))
|
||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||
.unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing
|
||||
|
||||
let cache_hit_rate = prompt_rate / 10.0;
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
|
||||
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
|
||||
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
|
||||
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
// DeepSeek doesn't support images in streaming, use text-only
|
||||
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
// Sanitize and fix for deepseek-reasoner (R1)
|
||||
if request.model == "deepseek-reasoner" {
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
// Keep stream_options if present (DeepSeek supports include_usage)
|
||||
|
||||
// Remove unsupported parameters
|
||||
obj.remove("temperature");
|
||||
|
||||
obj.remove("top_p");
|
||||
obj.remove("presence_penalty");
|
||||
obj.remove("frequency_penalty");
|
||||
obj.remove("logit_bias");
|
||||
obj.remove("logprobs");
|
||||
obj.remove("top_logprobs");
|
||||
|
||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
||||
for m in messages {
|
||||
if m["role"].as_str() == Some("assistant") {
|
||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
||||
m["reasoning_content"] = serde_json::json!(" ");
|
||||
}
|
||||
// DeepSeek R1 often requires content to be a string, not null/array
|
||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
||||
m["content"] = serde_json::json!("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/chat/completions", self.config.base_url);
|
||||
let api_key = self.api_key.clone();
|
||||
let probe_client = self.client.clone();
|
||||
let probe_body = body.clone();
|
||||
let model = request.model.clone();
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body),
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
||||
if msg.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
||||
yield p_chunk?;
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
// Attempt to probe for the actual error body
|
||||
let probe_resp = probe_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.json(&probe_body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match probe_resp {
|
||||
Ok(r) if !r.status().is_success() => {
|
||||
let status = r.status();
|
||||
let error_body = r.text().await.unwrap_or_default();
|
||||
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
|
||||
// Log the offending request body at ERROR level so it shows up in standard logs
|
||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
||||
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
|
||||
}
|
||||
Ok(_) => {
|
||||
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
||||
}
|
||||
Err(probe_err) => {
|
||||
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
|
||||
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,123 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
pub struct GrokProvider {
|
||||
client: reqwest::Client,
|
||||
config: crate::config::GrokConfig,
|
||||
api_key: String,
|
||||
pricing: Vec<crate::config::ModelPricing>,
|
||||
}
|
||||
|
||||
impl GrokProvider {
|
||||
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let api_key = app_config.get_api_key("grok")?;
|
||||
Self::new_with_key(config, app_config, api_key)
|
||||
}
|
||||
|
||||
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.grok.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Provider for GrokProvider {
|
||||
fn name(&self) -> &str {
|
||||
"grok"
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
model.starts_with("grok-")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
helpers::parse_openai_response(&resp_json, request.model)
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
registry,
|
||||
&self.pricing,
|
||||
5.0,
|
||||
15.0,
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body),
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||
}
|
||||
}
|
||||
@@ -1,454 +0,0 @@
|
||||
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
|
||||
use crate::errors::AppError;
|
||||
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
|
||||
///
|
||||
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
||||
/// Tokio async context. All image base64 conversions are awaited properly.
|
||||
/// Handles tool-calling messages: assistant messages with tool_calls, and
|
||||
/// tool-role messages with tool_call_id/name.
|
||||
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
||||
if m.role == "tool" {
|
||||
let text_content = m
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut msg = serde_json::json!({
|
||||
"role": "tool",
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
||||
// Gemini signatures (56 chars) must be shortened for compatibility.
|
||||
let id = if tool_call_id.len() > 40 {
|
||||
&tool_call_id[..40]
|
||||
} else {
|
||||
tool_call_id
|
||||
};
|
||||
msg["tool_call_id"] = serde_json::json!(id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
result.push(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build content parts for non-tool messages
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input
|
||||
.to_base64()
|
||||
.await
|
||||
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
||||
parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut msg = serde_json::json!({ "role": m.role });
|
||||
|
||||
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
||||
if let Some(reasoning) = &m.reasoning_content {
|
||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
||||
}
|
||||
|
||||
// For assistant messages with tool_calls, content can be empty string
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
||||
let mut sanitized = tc.clone();
|
||||
if sanitized.id.len() > 40 {
|
||||
sanitized.id = sanitized.id[..40].to_string();
|
||||
}
|
||||
sanitized
|
||||
}).collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::json!("");
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, but replace images with a
|
||||
/// text placeholder "[Image]". Useful for providers that don't support
|
||||
/// multimodal in streaming mode or at all.
|
||||
///
|
||||
/// Handles tool-calling messages identically to `messages_to_openai_json`:
|
||||
/// assistant messages with `tool_calls`, and tool-role messages with
|
||||
/// `tool_call_id`/`name`.
|
||||
pub async fn messages_to_openai_json_text_only(
|
||||
messages: &[UnifiedMessage],
|
||||
) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
||||
if m.role == "tool" {
|
||||
let text_content = m
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut msg = serde_json::json!({
|
||||
"role": "tool",
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
||||
let id = if tool_call_id.len() > 40 {
|
||||
&tool_call_id[..40]
|
||||
} else {
|
||||
tool_call_id
|
||||
};
|
||||
msg["tool_call_id"] = serde_json::json!(id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
result.push(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build content parts for non-tool messages (images become "[Image]" text)
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(_) => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut msg = serde_json::json!({ "role": m.role });
|
||||
|
||||
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
||||
if let Some(reasoning) = &m.reasoning_content {
|
||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
||||
}
|
||||
|
||||
// For assistant messages with tool_calls, content can be empty string
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
||||
let mut sanitized = tc.clone();
|
||||
if sanitized.id.len() > 40 {
|
||||
sanitized.id = sanitized.id[..40].to_string();
|
||||
}
|
||||
sanitized
|
||||
}).collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::json!("");
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
||||
/// Includes tools and tool_choice when present.
|
||||
/// When streaming, adds `stream_options.include_usage: true` so providers report
|
||||
/// token counts in the final SSE chunk.
|
||||
pub fn build_openai_body(
|
||||
request: &UnifiedRequest,
|
||||
messages_json: Vec<serde_json::Value>,
|
||||
stream: bool,
|
||||
) -> serde_json::Value {
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"messages": messages_json,
|
||||
"stream": stream,
|
||||
});
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = serde_json::json!({ "include_usage": true });
|
||||
}
|
||||
|
||||
if let Some(temp) = request.temperature {
|
||||
body["temperature"] = serde_json::json!(temp);
|
||||
}
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||
}
|
||||
if let Some(tools) = &request.tools {
|
||||
body["tools"] = serde_json::json!(tools);
|
||||
}
|
||||
if let Some(tool_choice) = &request.tool_choice {
|
||||
body["tool_choice"] = serde_json::json!(tool_choice);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
||||
/// Extracts tool_calls from the message when present.
|
||||
/// Extracts cache token counts from:
|
||||
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
|
||||
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
|
||||
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
||||
let choice = resp_json["choices"]
|
||||
.get(0)
|
||||
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||
let message = &choice["message"];
|
||||
|
||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls from the response message
|
||||
let tool_calls: Option<Vec<ToolCall>> = message
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
let usage = &resp_json["usage"];
|
||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
// Extract reasoning tokens
|
||||
let reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
|
||||
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
// DeepSeek uses a different field name
|
||||
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// DeepSeek reports prompt_cache_miss_tokens which are just regular non-cached tokens.
|
||||
// They do not incur a separate cache_write fee, so we don't map them here to avoid double-charging.
|
||||
let cache_write_tokens = 0;
|
||||
|
||||
Ok(ProviderResponse {
|
||||
content,
|
||||
reasoning_content,
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
|
||||
/// Returns None if the chunk should be skipped (e.g. promptFeedback).
|
||||
pub fn parse_openai_stream_chunk(
|
||||
chunk: &Value,
|
||||
model: &str,
|
||||
reasoning_field: Option<&'static str>,
|
||||
) -> Option<Result<ProviderStreamChunk, AppError>> {
|
||||
// Parse usage from the final chunk (sent when stream_options.include_usage is true).
|
||||
// This chunk may have an empty `choices` array.
|
||||
let stream_usage = chunk.get("usage").and_then(|u| {
|
||||
if u.is_null() {
|
||||
return None;
|
||||
}
|
||||
let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
let reasoning_tokens = u["completion_tokens_details"]["reasoning_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let cache_write_tokens = 0;
|
||||
|
||||
Some(StreamUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(choice) = chunk["choices"].get(0) {
|
||||
let delta = &choice["delta"];
|
||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||
let reasoning_content = delta["reasoning_content"]
|
||||
.as_str()
|
||||
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
|
||||
.map(|s| s.to_string());
|
||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls deltas from the stream chunk
|
||||
let tool_calls: Option<Vec<ToolCallDelta>> = delta
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
Some(Ok(ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content,
|
||||
finish_reason,
|
||||
tool_calls,
|
||||
model: model.to_string(),
|
||||
usage: stream_usage,
|
||||
}))
|
||||
} else if stream_usage.is_some() {
|
||||
// Final usage-only chunk (empty choices array) — yield it so
|
||||
// AggregatingStream can capture the real token counts.
|
||||
Some(Ok(ProviderStreamChunk {
|
||||
content: String::new(),
|
||||
reasoning_content: None,
|
||||
finish_reason: None,
|
||||
tool_calls: None,
|
||||
model: model.to_string(),
|
||||
usage: stream_usage,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
|
||||
///
|
||||
/// The optional `reasoning_field` allows overriding the field name for
|
||||
/// reasoning content (e.g., "thought" for Ollama).
|
||||
/// Parses tool_calls deltas from streaming chunks when present.
|
||||
/// When `stream_options.include_usage: true` was sent, the provider sends a
|
||||
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
|
||||
/// attached to the yielded `ProviderStreamChunk`.
|
||||
pub fn create_openai_stream(
|
||||
es: reqwest_eventsource::EventSource,
|
||||
model: String,
|
||||
reasoning_field: Option<&'static str>,
|
||||
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
|
||||
use reqwest_eventsource::Event;
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(Event::Message(msg)) => {
|
||||
if msg.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
let chunk: Value = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
|
||||
yield p_chunk?;
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
/// Calculate cost using the model registry first, then falling back to provider pricing config.
|
||||
///
|
||||
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
|
||||
/// (prompt_tokens - cache_read_tokens) * input_rate
|
||||
/// + cache_read_tokens * cache_read_rate
|
||||
/// + cache_write_tokens * cache_write_rate (if applicable)
|
||||
/// + completion_tokens * output_rate
|
||||
///
|
||||
/// All rates are per-token (the registry stores per-million-token rates).
|
||||
pub fn calculate_cost_with_registry(
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
pricing: &[crate::config::ModelPricing],
|
||||
default_prompt_rate: f64,
|
||||
default_completion_rate: f64,
|
||||
) -> f64 {
|
||||
if let Some(metadata) = registry.find_model(model)
|
||||
&& let Some(cost) = &metadata.cost
|
||||
{
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
|
||||
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
|
||||
|
||||
if let Some(cache_read_rate) = cost.cache_read {
|
||||
total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
|
||||
} else {
|
||||
// No cache_read rate — charge cached tokens at full input rate
|
||||
total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
|
||||
}
|
||||
|
||||
if let Some(cache_write_rate) = cost.cache_write {
|
||||
total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
|
||||
}
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
// Fallback: no registry entry — use provider pricing config (no cache awareness)
|
||||
let (prompt_rate, completion_rate) = pricing
|
||||
.iter()
|
||||
.find(|p| model.contains(&p.model))
|
||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||
.unwrap_or((default_prompt_rate, default_completion_rate));
|
||||
|
||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::errors::AppError;
|
||||
use crate::models::UnifiedRequest;
|
||||
|
||||
|
||||
pub mod deepseek;
|
||||
pub mod gemini;
|
||||
pub mod grok;
|
||||
pub mod helpers;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
||||
/// Common interface for every LLM backend declared in this module
/// (`deepseek`, `gemini`, `grok`, `ollama`, `openai`).
///
/// Implementors must be `Send + Sync`; `ProviderManager` stores them as
/// `Arc<dyn Provider>`.
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Get provider name (e.g., "openai", "gemini").
/// `ProviderManager` uses this name to look up, replace and remove providers.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Check if provider supports a specific model.
/// `ProviderManager::get_provider_for_model` returns the first provider for which this is true.
|
||||
fn supports_model(&self, model: &str) -> bool;
|
||||
|
||||
/// Check if provider supports multimodal (images, etc.)
|
||||
fn supports_multimodal(&self) -> bool;
|
||||
|
||||
/// Process a (non-streaming) chat completion request and return the full response.
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
|
||||
|
||||
/// Process a chat request using provider-specific "responses" style endpoint.
|
||||
/// Default implementation falls back to `chat_completion` for providers
|
||||
/// that do not implement a dedicated responses endpoint.
|
||||
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
self.chat_completion(request).await
|
||||
}
|
||||
|
||||
/// Process a streaming chat completion request, yielding chunks as they arrive.
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
|
||||
|
||||
/// Process a streaming chat request using provider-specific "responses" style endpoint.
|
||||
/// Default implementation falls back to `chat_completion_stream` for providers
|
||||
/// that do not implement a dedicated responses endpoint.
|
||||
async fn chat_responses_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
self.chat_completion_stream(request).await
|
||||
}
|
||||
|
||||
/// Estimate token count for a request (for cost calculation).
/// Note: `Result` here is `anyhow::Result`, unlike the `AppError` results above.
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
||||
|
||||
/// Calculate cost based on token usage and model using the registry.
|
||||
/// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
|
||||
/// when the registry provides `cache_read` / `cache_write` rates.
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64;
|
||||
}
|
||||
|
||||
/// Complete (non-streaming) result of a chat completion, as returned by
/// `Provider::chat_completion` / `Provider::chat_responses`.
pub struct ProviderResponse {
|
||||
/// Assistant message text.
pub content: String,
|
||||
/// Separate reasoning text, when the provider returned one.
pub reasoning_content: Option<String>,
|
||||
/// Tool calls requested by the model, if any.
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
|
||||
/// Number of prompt (input) tokens.
pub prompt_tokens: u32,
|
||||
/// Number of completion (output) tokens.
pub completion_tokens: u32,
|
||||
/// Number of reasoning tokens.
pub reasoning_tokens: u32,
|
||||
/// Total token count.
pub total_tokens: u32,
|
||||
/// Prompt tokens read from the provider's cache.
pub cache_read_tokens: u32,
|
||||
/// Tokens billed as cache writes.
pub cache_write_tokens: u32,
|
||||
/// Model name attached to this response.
pub model: String,
|
||||
}
|
||||
|
||||
/// Usage data from the final streaming chunk (when providers report real token counts).
/// `Default` yields all counters set to zero.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StreamUsage {
|
||||
/// Number of prompt (input) tokens.
pub prompt_tokens: u32,
|
||||
/// Number of completion (output) tokens.
pub completion_tokens: u32,
|
||||
/// Number of reasoning tokens.
pub reasoning_tokens: u32,
|
||||
/// Total token count.
pub total_tokens: u32,
|
||||
/// Prompt tokens read from the provider's cache.
pub cache_read_tokens: u32,
|
||||
/// Tokens billed as cache writes.
pub cache_write_tokens: u32,
|
||||
}
|
||||
|
||||
/// One incremental piece of a streaming chat completion, as yielded by
/// `Provider::chat_completion_stream` / `Provider::chat_responses_stream`.
#[derive(Debug, Clone)]
|
||||
pub struct ProviderStreamChunk {
|
||||
/// Text carried by this chunk; may be empty.
pub content: String,
|
||||
/// Reasoning text carried by this chunk, if any.
pub reasoning_content: Option<String>,
|
||||
/// Finish reason, when the chunk carries one.
pub finish_reason: Option<String>,
|
||||
/// Tool-call deltas carried by this chunk, if any.
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
|
||||
/// Model name attached to this chunk.
pub model: String,
|
||||
/// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
|
||||
pub usage: Option<StreamUsage>,
|
||||
}
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::providers::{
|
||||
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
||||
openai::OpenAIProvider,
|
||||
};
|
||||
|
||||
/// Registry of the currently active providers.
///
/// Cloning is cheap and clones share state: the provider list lives behind an
/// `Arc<RwLock<..>>`, so every clone sees the same providers.
#[derive(Clone)]
|
||||
pub struct ProviderManager {
|
||||
/// Active providers; the manager's methods keep at most one entry per provider name.
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a provider by name using config and database overrides
|
||||
pub async fn initialize_provider(
|
||||
&self,
|
||||
name: &str,
|
||||
app_config: &AppConfig,
|
||||
db_pool: &crate::database::DbPool,
|
||||
) -> Result<()> {
|
||||
// Load override from database
|
||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
|
||||
.bind(name)
|
||||
.fetch_optional(db_pool)
|
||||
.await?;
|
||||
|
||||
let (enabled, base_url, api_key) = if let Some(row) = db_config {
|
||||
let enabled = row.get::<bool, _>("enabled");
|
||||
let base_url = row.get::<Option<String>, _>("base_url");
|
||||
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
|
||||
let api_key = row.get::<Option<String>, _>("api_key");
|
||||
// Decrypt API key if encrypted
|
||||
let api_key = match (api_key, api_key_encrypted) {
|
||||
(Some(key), true) => {
|
||||
match crate::utils::crypto::decrypt(&key) {
|
||||
Ok(decrypted) => Some(decrypted),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
(Some(key), false) => {
|
||||
// Plaintext key - optionally encrypt and update database (lazy migration)
|
||||
// For now, just use plaintext
|
||||
Some(key)
|
||||
}
|
||||
(None, _) => None,
|
||||
};
|
||||
(enabled, base_url, api_key)
|
||||
} else {
|
||||
// No database override, use defaults from AppConfig
|
||||
match name {
|
||||
"openai" => (
|
||||
app_config.providers.openai.enabled,
|
||||
Some(app_config.providers.openai.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"gemini" => (
|
||||
app_config.providers.gemini.enabled,
|
||||
Some(app_config.providers.gemini.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"deepseek" => (
|
||||
app_config.providers.deepseek.enabled,
|
||||
Some(app_config.providers.deepseek.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"grok" => (
|
||||
app_config.providers.grok.enabled,
|
||||
Some(app_config.providers.grok.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"ollama" => (
|
||||
app_config.providers.ollama.enabled,
|
||||
Some(app_config.providers.ollama.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
_ => (false, None, None),
|
||||
}
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
self.remove_provider(name).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create provider instance with merged config
|
||||
let provider: Arc<dyn Provider> = match name {
|
||||
"openai" => {
|
||||
let mut cfg = app_config.providers.openai.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
// Handle API key override if present
|
||||
let p = if let Some(key) = api_key {
|
||||
// We need a way to create a provider with an explicit key
|
||||
// Let's modify the providers to allow this
|
||||
OpenAIProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
OpenAIProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"ollama" => {
|
||||
let mut cfg = app_config.providers.ollama.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
Arc::new(OllamaProvider::new(&cfg, app_config)?)
|
||||
}
|
||||
"gemini" => {
|
||||
let mut cfg = app_config.providers.gemini.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GeminiProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GeminiProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"deepseek" => {
|
||||
let mut cfg = app_config.providers.deepseek.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
DeepSeekProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"grok" => {
|
||||
let mut cfg = app_config.providers.grok.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GrokProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GrokProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||
};
|
||||
|
||||
self.add_provider(provider).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
|
||||
let mut providers = self.providers.write().await;
|
||||
// If provider with same name exists, replace it
|
||||
if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
|
||||
providers[index] = provider;
|
||||
} else {
|
||||
providers.push(provider);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_provider(&self, name: &str) {
|
||||
let mut providers = self.providers.write().await;
|
||||
providers.retain(|p| p.name() != name);
|
||||
}
|
||||
|
||||
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter().find(|p| p.name() == name).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// Create placeholder provider implementations
|
||||
pub mod placeholder {
|
||||
use super::*;
|
||||
|
||||
pub struct PlaceholderProvider {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl PlaceholderProvider {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self { name: name.to_string() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for PlaceholderProvider {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn supports_model(&self, _model: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
_request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
Err(AppError::ProviderError(
|
||||
"Streaming not supported for placeholder provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
Err(AppError::ProviderError(format!(
|
||||
"Provider {} not implemented",
|
||||
self.name
|
||||
)))
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
_model: &str,
|
||||
_prompt_tokens: u32,
|
||||
_completion_tokens: u32,
|
||||
_cache_read_tokens: u32,
|
||||
_cache_write_tokens: u32,
|
||||
_registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
/// Provider that talks to an Ollama server through its OpenAI-compatible
/// `/chat/completions` endpoint. No API key is stored or sent.
pub struct OllamaProvider {
|
||||
/// HTTP client used for all requests to the server.
client: reqwest::Client,
|
||||
/// Ollama settings; `base_url` and `models` are read by the provider methods.
config: crate::config::OllamaConfig,
|
||||
/// Fallback per-model pricing, used when the model registry has no cost entry.
pricing: Vec<crate::config::ModelPricing>,
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
pricing: app_config.pricing.ollama.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Provider for OllamaProvider {
|
||||
fn name(&self) -> &str {
|
||||
"ollama"
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
// Strip "ollama/" prefix if present for the API call
|
||||
let api_model = request
|
||||
.model
|
||||
.strip_prefix("ollama/")
|
||||
.unwrap_or(&request.model)
|
||||
.to_string();
|
||||
let original_model = request.model.clone();
|
||||
request.model = api_model;
|
||||
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
// Ollama also supports "thought" as an alias for reasoning_content
|
||||
let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
|
||||
if result.reasoning_content.is_none() {
|
||||
result.reasoning_content = resp_json["choices"]
|
||||
.get(0)
|
||||
.and_then(|c| c["message"]["thought"].as_str())
|
||||
.map(|s| s.to_string());
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
registry,
|
||||
&self.pricing,
|
||||
0.0,
|
||||
0.0,
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
mut request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
let api_model = request
|
||||
.model
|
||||
.strip_prefix("ollama/")
|
||||
.unwrap_or(&request.model)
|
||||
.to_string();
|
||||
let original_model = request.model.clone();
|
||||
request.model = api_model;
|
||||
|
||||
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.json(&body),
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
// Ollama uses "thought" as an alternative field for reasoning content
|
||||
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user