refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup)
Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder)
Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy)
Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images)
Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules)
Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
2026-03-02 00:35:45 -05:00
parent ba643dd2b0
commit 2cdc49d7f2
42 changed files with 2800 additions and 2747 deletions

61
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
env:
CARGO_TERM_COLOR: always
RUST_BACKTRACE: 1
jobs:
check:
name: Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- run: cargo check --all-targets
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- uses: Swatinem/rust-cache@v2
- run: cargo clippy --all-targets -- -D warnings
fmt:
name: Formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt --all -- --check
test:
name: Test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- run: cargo test --all-targets
build-release:
name: Release Build
runs-on: ubuntu-latest
needs: [check, clippy, test]
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- run: cargo build --release

270
Cargo.lock generated
View File

@@ -92,46 +92,6 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "async-openai"
version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3ef6d66e53b3ec8ed4bae8e6dcdda75c0f5b4f67faba78fac1b09434cb4f2fc"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"getrandom 0.3.4",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"serde_urlencoded",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"url",
]
[[package]]
name = "async-openai-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "async-stream" name = "async-stream"
version = "0.3.6" version = "0.3.6"
@@ -275,20 +235,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.17",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.13.1" version = "0.13.1"
@@ -598,54 +544,6 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "darling"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "data-encoding" name = "data-encoding"
version = "2.10.0" version = "2.10.0"
@@ -663,37 +561,6 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn",
]
[[package]] [[package]]
name = "difflib" name = "difflib"
version = "0.4.0" version = "0.4.0"
@@ -1028,26 +895,6 @@ dependencies = [
"wasip3", "wasip3",
] ]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec",
"spinning_top",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.13" version = "0.4.13"
@@ -1076,12 +923,6 @@ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.5" version = "0.15.5"
@@ -1396,12 +1237,6 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]] [[package]]
name = "idna" name = "idna"
version = "1.1.0" version = "1.1.0"
@@ -1482,15 +1317,6 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "ipnet" name = "ipnet"
version = "2.11.0" version = "2.11.0"
@@ -1607,7 +1433,6 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assert_cmd", "assert_cmd",
"async-openai",
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum", "axum",
@@ -1618,16 +1443,11 @@ dependencies = [
"config", "config",
"dotenvy", "dotenvy",
"futures", "futures",
"governor",
"headers", "headers",
"image", "image",
"insta", "insta",
"mime", "mime",
"mime_guess",
"mockito", "mockito",
"once_cell",
"rand 0.8.5",
"regex",
"reqwest", "reqwest",
"reqwest-eventsource", "reqwest-eventsource",
"serde", "serde",
@@ -1776,12 +1596,6 @@ dependencies = [
"pxfm", "pxfm",
] ]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@@ -1792,12 +1606,6 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.50.3" version = "0.50.3"
@@ -2014,12 +1822,6 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]] [[package]]
name = "potential_utf" name = "potential_utf"
version = "0.1.4" version = "0.1.4"
@@ -2093,21 +1895,6 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "2.0.1" version = "2.0.1"
@@ -2243,15 +2030,6 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
] ]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.11.0",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.18" version = "0.5.18"
@@ -2317,7 +2095,6 @@ dependencies = [
"hyper-util", "hyper-util",
"js-sys", "js-sys",
"log", "log",
"mime_guess",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"quinn", "quinn",
@@ -2490,16 +2267,6 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.27" version = "1.0.27"
@@ -2684,15 +2451,6 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
@@ -2912,12 +2670,6 @@ dependencies = [
"unicode-properties", "unicode-properties",
] ]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
@@ -3657,28 +3409,6 @@ dependencies = [
"wasite", "wasite",
] ]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-core" name = "windows-core"
version = "0.62.2" version = "0.62.2"

View File

@@ -16,7 +16,6 @@ tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip",
# ========== HTTP Clients ========== # ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
async-openai = { version = "0.33", default-features = false, features = ["_api", "chat-completion"] }
tiktoken-rs = "0.9" tiktoken-rs = "0.9"
# ========== Database & ORM ========== # ========== Database & ORM ==========
@@ -41,7 +40,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
base64 = "0.21" base64 = "0.21"
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] } image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
mime = "0.3" mime = "0.3"
mime_guess = "2.0"
# ========== Error Handling & Utilities ========== # ========== Error Handling & Utilities ==========
anyhow = "1.0" anyhow = "1.0"
@@ -53,12 +51,6 @@ futures = "0.3"
async-trait = "0.1" async-trait = "0.1"
async-stream = "0.3" async-stream = "0.3"
reqwest-eventsource = "0.6" reqwest-eventsource = "0.6"
once_cell = "1.19"
regex = "1.10"
rand = "0.8"
# ========== Rate Limiting & Circuit Breaking ==========
governor = "0.6"
[dev-dependencies] [dev-dependencies]
tokio-test = "0.4" tokio-test = "0.4"

35
Dockerfile Normal file
View File

@@ -0,0 +1,35 @@
# ── Build stage ──────────────────────────────────────────────
FROM rust:1-bookworm AS builder
WORKDIR /app
# Cache dependency build
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
cargo build --release && \
rm -rf src
# Build the actual binary
COPY src/ src/
RUN touch src/main.rs && cargo build --release
# ── Runtime stage ────────────────────────────────────────────
FROM debian:bookworm-slim
RUN apt-get update && \
apt-get install -y --no-install-recommends ca-certificates && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
COPY static/ /app/static/
# Default config location
VOLUME ["/app/config", "/app/data"]
EXPOSE 8080
ENV RUST_LOG=info
ENTRYPOINT ["/app/llm-proxy"]

1
clippy.toml Normal file
View File

@@ -0,0 +1 @@
too-many-arguments-threshold = 8

2
rustfmt.toml Normal file
View File

@@ -0,0 +1,2 @@
max_width = 120
use_field_init_shorthand = true

View File

@@ -1,6 +1,6 @@
use axum::{extract::FromRequestParts, http::request::Parts}; use axum::{extract::FromRequestParts, http::request::Parts};
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader; use axum_extra::TypedHeader;
use axum_extra::headers::Authorization;
use headers::authorization::Bearer; use headers::authorization::Bearer;
use crate::errors::AppError; use crate::errors::AppError;
@@ -16,33 +16,19 @@ where
{ {
type Rejection = AppError; type Rejection = AppError;
fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// We need access to the AppState to get valid tokens
// Since state is generic here, we try to cast it or assume it's available via extensions
// In this project, AppState is cloned into Axum state.
async move {
// Extract bearer token from Authorization header // Extract bearer token from Authorization header
let TypedHeader(Authorization(bearer)) = let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await .await
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?; .map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
let token = bearer.token().to_string(); let token = bearer.token().to_string();
// For a proxy, we want to check if this token is in our allowed list // Derive client_id from the token prefix
// The list is stored in AppState which is available in Parts extensions let client_id = format!("client_{}", &token[..8.min(token.len())]);
let client_id = {
// In main.rs, we set up the router with State(state).
// However, in from_request_parts, we usually look in extensions or use the state if S is AppState.
// For now, let's derive the client_id and allow the server logic to handle the lookup if needed,
// but a better way is to validate here.
format!("client_{}", &token[..8.min(token.len())])
};
Ok(AuthenticatedClient { token, client_id }) Ok(AuthenticatedClient { token, client_id })
} }
}
} }
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool { pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {

View File

@@ -5,10 +5,10 @@
//! 2. Client usage tracking //! 2. Client usage tracking
//! 3. Client rate limit configuration //! 3. Client rate limit configuration
use anyhow::Result;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{SqlitePool, Row}; use sqlx::{Row, SqlitePool};
use anyhow::Result;
use tracing::{info, warn}; use tracing::{info, warn};
/// Client information /// Client information
@@ -74,7 +74,9 @@ impl ClientManager {
.await?; .await?;
// Then fetch the created client // Then fetch the created client
let client = self.get_client(&request.client_id).await? let client = self
.get_client(&request.client_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?; .ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
info!("Created client: {} ({})", client.name, client.client_id); info!("Created client: {} ({})", client.name, client.client_id);
@@ -126,12 +128,11 @@ impl ClientManager {
} }
// Build update query dynamically based on provided fields // Build update query dynamically based on provided fields
let mut updates = Vec::new();
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET "); let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
let mut has_updates = false; let mut has_updates = false;
if let Some(name) = &request.name { if let Some(name) = &request.name {
updates.push("name = "); query_builder.push("name = ");
query_builder.push_bind(name); query_builder.push_bind(name);
has_updates = true; has_updates = true;
} }
@@ -140,7 +141,7 @@ impl ClientManager {
if has_updates { if has_updates {
query_builder.push(", "); query_builder.push(", ");
} }
updates.push("description = "); query_builder.push("description = ");
query_builder.push_bind(description); query_builder.push_bind(description);
has_updates = true; has_updates = true;
} }
@@ -149,7 +150,7 @@ impl ClientManager {
if has_updates { if has_updates {
query_builder.push(", "); query_builder.push(", ");
} }
updates.push("is_active = "); query_builder.push("is_active = ");
query_builder.push_bind(is_active); query_builder.push_bind(is_active);
has_updates = true; has_updates = true;
} }
@@ -158,7 +159,7 @@ impl ClientManager {
if has_updates { if has_updates {
query_builder.push(", "); query_builder.push(", ");
} }
updates.push("rate_limit_per_minute = "); query_builder.push("rate_limit_per_minute = ");
query_builder.push_bind(rate_limit); query_builder.push_bind(rate_limit);
has_updates = true; has_updates = true;
} }
@@ -204,7 +205,7 @@ impl ClientManager {
FROM clients FROM clients
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
"# "#,
) )
.bind(limit) .bind(limit)
.bind(offset) .bind(offset)
@@ -234,9 +235,7 @@ impl ClientManager {
/// Delete a client /// Delete a client
pub async fn delete_client(&self, client_id: &str) -> Result<bool> { pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
let result = sqlx::query( let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
"DELETE FROM clients WHERE client_id = ?"
)
.bind(client_id) .bind(client_id)
.execute(&self.db_pool) .execute(&self.db_pool)
.await?; .await?;
@@ -253,12 +252,7 @@ impl ClientManager {
} }
/// Update client usage statistics after a request /// Update client usage statistics after a request
pub async fn update_client_usage( pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
&self,
client_id: &str,
tokens: i64,
cost: f64,
) -> Result<()> {
sqlx::query( sqlx::query(
r#" r#"
UPDATE clients UPDATE clients
@@ -268,7 +262,7 @@ impl ClientManager {
total_cost = total_cost + ?, total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP updated_at = CURRENT_TIMESTAMP
WHERE client_id = ? WHERE client_id = ?
"# "#,
) )
.bind(tokens) .bind(tokens)
.bind(cost) .bind(cost)
@@ -286,7 +280,7 @@ impl ClientManager {
SELECT total_requests, total_tokens, total_cost SELECT total_requests, total_tokens, total_cost
FROM clients FROM clients
WHERE client_id = ? WHERE client_id = ?
"# "#,
) )
.bind(client_id) .bind(client_id)
.fetch_optional(&self.db_pool) .fetch_optional(&self.db_pool)

View File

@@ -95,7 +95,7 @@ pub struct AppConfig {
pub providers: ProviderConfig, pub providers: ProviderConfig,
pub model_mapping: ModelMappingConfig, pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig, pub pricing: PricingConfig,
pub config_path: PathBuf, pub config_path: Option<PathBuf>,
} }
impl AppConfig { impl AppConfig {
@@ -120,7 +120,10 @@ impl AppConfig {
.set_default("providers.openai.default_model", "gpt-4o")? .set_default("providers.openai.default_model", "gpt-4o")?
.set_default("providers.openai.enabled", true)? .set_default("providers.openai.enabled", true)?
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")? .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
.set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")? .set_default(
"providers.gemini.base_url",
"https://generativelanguage.googleapis.com/v1",
)?
.set_default("providers.gemini.default_model", "gemini-2.0-flash")? .set_default("providers.gemini.default_model", "gemini-2.0-flash")?
.set_default("providers.gemini.enabled", true)? .set_default("providers.gemini.enabled", true)?
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")? .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
@@ -136,7 +139,11 @@ impl AppConfig {
.set_default("providers.ollama.models", Vec::<String>::new())?; .set_default("providers.ollama.models", Vec::<String>::new())?;
// Load from config file if exists // Load from config file if exists
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml")); let config_path = config_path.unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join("config.toml")
});
if config_path.exists() { if config_path.exists() {
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml)); config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
} }
@@ -174,7 +181,7 @@ impl AppConfig {
providers, providers,
model_mapping, model_mapping,
pricing, pricing,
config_path, config_path: Some(config_path),
})) }))
} }
@@ -187,16 +194,15 @@ impl AppConfig {
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)), _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
}; };
std::env::var(env_var) std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
.map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
}
} }
}
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string /// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error> fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where where
D: serde::Deserializer<'de>, D: serde::Deserializer<'de>,
{ {
struct VecOrString; struct VecOrString;
impl<'de> serde::de::Visitor<'de> for VecOrString { impl<'de> serde::de::Visitor<'de> for VecOrString {
@@ -230,5 +236,4 @@ impl AppConfig {
} }
deserializer.deserialize_any(VecOrString) deserializer.deserialize_any(VecOrString)
} }

130
src/dashboard/auth.rs Normal file
View File

@@ -0,0 +1,130 @@
use axum::{extract::State, response::Json};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
// Authentication handlers
#[derive(Deserialize)]
pub(super) struct LoginRequest {
pub(super) username: String,
pub(super) password: String,
}
pub(super) async fn handle_login(
State(state): State<DashboardState>,
Json(payload): Json<LoginRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let user_result =
sqlx::query("SELECT username, password_hash, role, must_change_password FROM users WHERE username = ?")
.bind(&payload.username)
.fetch_optional(pool)
.await;
match user_result {
Ok(Some(row)) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
let username = row.get::<String, _>("username");
let role = row.get::<String, _>("role");
let must_change_password = row.get::<bool, _>("must_change_password");
let token = state
.session_manager
.create_session(username.clone(), role.clone())
.await;
Json(ApiResponse::success(serde_json::json!({
"token": token,
"must_change_password": must_change_password,
"user": {
"username": username,
"name": "Administrator",
"role": role
}
})))
} else {
Json(ApiResponse::error("Invalid username or password".to_string()))
}
}
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
Err(e) => {
warn!("Database error during login: {}", e);
Json(ApiResponse::error("Login failed due to system error".to_string()))
}
}
}
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some(session) = state.session_manager.validate_session(token).await
{
return Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": "Administrator",
"role": session.role
}
})));
}
Json(ApiResponse::error("Not authenticated".to_string()))
}
#[derive(Deserialize)]
pub(super) struct ChangePasswordRequest {
pub(super) current_password: String,
pub(super) new_password: String,
}
pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// For now, always change 'admin' user
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'")
.fetch_one(pool)
.await;
match user_result {
Ok(row) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
};
let update_result = sqlx::query(
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = 'admin'",
)
.bind(new_hash)
.execute(pool)
.await;
match update_result {
Ok(_) => Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
)),
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
}
} else {
Json(ApiResponse::error("Current password incorrect".to_string()))
}
}
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
}
}

227
src/dashboard/clients.rs Normal file
View File

@@ -0,0 +1,227 @@
use axum::{
extract::{Path, State},
response::Json,
};
use chrono;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use uuid;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
pub(super) name: String,
pub(super) client_id: Option<String>,
}
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id as id,
name,
created_at,
total_requests,
total_tokens,
total_cost,
is_active
FROM clients
ORDER BY created_at DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let clients: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"requests_count": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(clients)))
}
Err(e) => {
warn!("Failed to fetch clients: {}", e);
Json(ApiResponse::error("Failed to fetch clients".to_string()))
}
}
}
pub(super) async fn handle_create_client(
State(state): State<DashboardState>,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let client_id = payload
.client_id
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
let result = sqlx::query(
r#"
INSERT INTO clients (client_id, name, is_active)
VALUES (?, ?, TRUE)
RETURNING *
"#,
)
.bind(&client_id)
.bind(&payload.name)
.fetch_one(pool)
.await;
match result {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("client_id"),
"name": row.get::<Option<String>, _>("name"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"status": "active",
}))),
Err(e) => {
warn!("Failed to create client: {}", e);
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
}
}
}
pub(super) async fn handle_get_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
c.client_id as id,
c.name,
c.is_active,
c.created_at,
COALESCE(c.total_tokens, 0) as total_tokens,
COALESCE(c.total_cost, 0.0) as total_cost,
COUNT(r.id) as total_requests,
MAX(r.timestamp) as last_request
FROM clients c
LEFT JOIN llm_requests r ON c.client_id = r.client_id
WHERE c.client_id = ?
GROUP BY c.client_id
"#,
)
.bind(&id)
.fetch_optional(pool)
.await;
match result {
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"is_active": row.get::<bool, _>("is_active"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"total_requests": row.get::<i64, _>("total_requests"),
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
Err(e) => {
warn!("Failed to fetch client: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
}
}
}
pub(super) async fn handle_delete_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Don't allow deleting the default client
if id == "default" {
return Json(ApiResponse::error("Cannot delete default client".to_string()));
}
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
}
}
pub(super) async fn handle_client_usage(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Get per-model breakdown for this client
let result = sqlx::query(
r#"
SELECT
model,
provider,
COUNT(*) as request_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(duration_ms) as avg_duration_ms
FROM llm_requests
WHERE client_id = ?
GROUP BY model, provider
ORDER BY total_cost DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let breakdown: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"model": row.get::<String, _>("model"),
"provider": row.get::<String, _>("provider"),
"request_count": row.get::<i64, _>("request_count"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!({
"client_id": id,
"breakdown": breakdown,
})))
}
Err(e) => {
warn!("Failed to fetch client usage: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
}
}
}

File diff suppressed because it is too large Load Diff

116
src/dashboard/models.rs Normal file
View File

@@ -0,0 +1,116 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
pub(super) enabled: bool,
pub(super) prompt_cost: Option<f64>,
pub(super) completion_cost: Option<f64>,
pub(super) mapping: Option<String>,
}
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
for (p_id, p_info) in &registry.providers {
for (m_id, m_meta) in &p_info.models {
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_id) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_id,
"provider": p_id,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
pub(super) async fn handle_update_model(
State(state): State<DashboardState>,
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state
.app_state
.model_registry
.providers
.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}

346
src/dashboard/providers.rs Normal file
View File

@@ -0,0 +1,346 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
pub(super) enabled: bool,
pub(super) base_url: Option<String>,
pub(super) api_key: Option<String>,
pub(super) credit_balance: Option<f64>,
pub(super) low_credit_threshold: Option<f64>,
}
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database
let db_configs_result =
sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs")
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
let balance: f64 = row.get("credit_balance");
let threshold: f64 = row.get("low_credit_threshold");
db_configs.insert(id, (enabled, base_url, balance, threshold));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => (false, "".to_string(), "Unknown"),
};
let mut balance = 0.0;
let mut threshold = 5.0;
// Apply database overrides
if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
balance = *db_balance;
threshold = *db_threshold;
}
// Find models for this provider in registry
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(id) {
models = p_info.models.keys().cloned().collect();
} else if id == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state
.app_state
.rate_limit_manager
.check_provider_request(id)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error" // Enabled but failed to initialize (e.g. missing API key)
}
};
providers_json.push(serde_json::json!({
"id": id,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"last_used": None::<String>,
}));
}
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
pub(super) async fn handle_get_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Validate provider name
let (mut enabled, mut base_url, display_name) = match name.as_str() {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
};
let mut balance = 0.0;
let mut threshold = 5.0;
// Apply database overrides
let db_config = sqlx::query(
"SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?",
)
.bind(&name)
.fetch_optional(pool)
.await;
if let Ok(Some(row)) = db_config {
enabled = row.get::<bool, _>("enabled");
if let Some(url) = row.get::<Option<String>, _>("base_url") {
base_url = url;
}
balance = row.get::<f64, _>("credit_balance");
threshold = row.get::<f64, _>("low_credit_threshold");
}
// Find models for this provider
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(name.as_str()) {
models = p_info.models.keys().cloned().collect();
} else if name == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
if state
.app_state
.rate_limit_manager
.check_provider_request(&name)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error"
};
Json(ApiResponse::success(serde_json::json!({
"id": name,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"last_used": None::<String>,
})))
}
pub(super) async fn handle_update_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Update or insert into database
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
updated_at = CURRENT_TIMESTAMP
"#
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&payload.api_key)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state
.app_state
.provider_manager
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
.await
{
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!(
"Provider settings saved but initialization failed: {}",
e
)));
}
Json(ApiResponse::success(
serde_json::json!({ "message": "Provider updated and re-initialized" }),
))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
pub(super) async fn handle_test_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let start = std::time::Instant::now();
let provider = match state.app_state.provider_manager.get_provider(&name).await {
Some(p) => p,
None => {
return Json(ApiResponse::error(format!(
"Provider '{}' not found or not enabled",
name
)));
}
};
// Pick a real model for this provider from the registry
let test_model = state
.app_state
.model_registry
.providers
.get(&name)
.and_then(|p| p.models.keys().next().cloned())
.unwrap_or_else(|| name.clone());
let test_request = crate::models::UnifiedRequest {
client_id: "system-test".to_string(),
model: test_model,
messages: vec![crate::models::UnifiedMessage {
role: "user".to_string(),
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
}],
temperature: None,
max_tokens: Some(5),
stream: false,
has_images: false,
};
match provider.chat_completion(test_request).await {
Ok(_) => {
let latency = start.elapsed().as_millis();
Json(ApiResponse::success(serde_json::json!({
"success": true,
"latency": latency,
"message": "Connection test successful"
})))
}
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
}
}

64
src/dashboard/sessions.rs Normal file
View File

@@ -0,0 +1,64 @@
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone, Debug)]
pub struct Session {
pub username: String,
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>,
ttl_hours: i64,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
}
}
/// Create a new session and return the session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4());
let now = Utc::now();
let session = Session {
username,
role,
created_at: now,
expires_at: now + Duration::hours(self.ttl_hours),
};
self.sessions.write().await.insert(token.clone(), session);
token
}
/// Validate a session token and return the session if valid and not expired.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some(s.clone())
} else {
None
}
})
}
/// Revoke (delete) a session by token.
pub async fn revoke_session(&self, token: &str) {
self.sessions.write().await.remove(token);
}
/// Remove all expired sessions from the store.
pub async fn cleanup_expired(&self) {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
}

193
src/dashboard/system.rs Normal file
View File

@@ -0,0 +1,193 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let mut components = HashMap::new();
components.insert("api_server".to_string(), "online".to_string());
components.insert("database".to_string(), "online".to_string());
// Check provider health via circuit breakers
let provider_ids: Vec<String> = state
.app_state
.provider_manager
.get_all_providers()
.await
.iter()
.map(|p| p.name().to_string())
.collect();
for p_id in provider_ids {
if state
.app_state
.rate_limit_manager
.check_provider_request(&p_id)
.await
.unwrap_or(true)
{
components.insert(p_id, "online".to_string());
} else {
components.insert(p_id, "degraded".to_string());
}
}
// Read real memory usage from /proc/self/status
let memory_mb = std::fs::read_to_string("/proc/self/status")
.ok()
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| kb / 1024.0)
.unwrap_or(0.0);
// Get real database pool stats
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
Json(ApiResponse::success(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"components": components,
"metrics": {
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
"db_connections_active": db_pool_size - db_pool_idle as u32,
"db_connections_idle": db_pool_idle,
}
})))
}
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
id,
timestamp,
client_id,
provider,
model,
prompt_tokens,
completion_tokens,
total_tokens,
cost,
status,
error_message,
duration_ms
FROM llm_requests
ORDER BY timestamp DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let logs: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<i64, _>("id"),
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
"client_id": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"tokens": row.get::<i64, _>("total_tokens"),
"cost": row.get::<f64, _>("cost"),
"status": row.get::<String, _>("status"),
"error": row.get::<Option<String>, _>("error_message"),
"duration": row.get::<i64, _>("duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(logs)))
}
Err(e) => {
warn!("Failed to fetch system logs: {}", e);
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
}
}
}
pub(super) async fn handle_system_backup(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
let backup_path = format!("data/{}.db", backup_id);
// Ensure the data directory exists
if let Err(e) = std::fs::create_dir_all("data") {
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
}
// Use SQLite VACUUM INTO for a consistent backup
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
.execute(pool)
.await;
match result {
Ok(_) => {
// Get backup file size
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Backup completed successfully",
"backup_id": backup_id,
"backup_path": backup_path,
"size_bytes": size_bytes,
})))
}
Err(e) => {
warn!("Database backup failed: {}", e);
Json(ApiResponse::error(format!("Backup failed: {}", e)))
}
}
}
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let provider_count = registry.providers.len();
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
Json(ApiResponse::success(serde_json::json!({
"server": {
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
"version": env!("CARGO_PKG_VERSION"),
},
"registry": {
"provider_count": provider_count,
"model_count": model_count,
},
"database": {
"type": "SQLite",
}
})))
}
pub(super) async fn handle_update_settings(
State(_state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
.to_string(),
))
}
// Helper functions
fn mask_token(token: &str) -> String {
if token.len() <= 8 {
return "*****".to_string();
}
let masked_len = token.len().min(12);
let visible_len = 4;
let mask_len = masked_len - visible_len;
format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}

330
src/dashboard/usage.rs Normal file
View File

@@ -0,0 +1,330 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
pub(super) async fn handle_usage_summary(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Total stats
let total_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients
FROM llm_requests
"#,
)
.fetch_one(pool);
// Today's stats
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_stats = sqlx::query(
r#"
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(total_tokens), 0) as today_tokens,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE strftime('%Y-%m-%d', timestamp) = ?
"#,
)
.bind(today)
.fetch_one(pool);
// Error stats
let error_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
FROM llm_requests
"#,
)
.fetch_one(pool);
// Average response time
let avg_response = sqlx::query(
r#"
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
FROM llm_requests
WHERE status = 'success'
"#,
)
.fetch_one(pool);
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
let total_requests: i64 = t.get("total_requests");
let total_tokens: i64 = t.get("total_tokens");
let total_cost: f64 = t.get("total_cost");
let active_clients: i64 = t.get("active_clients");
let today_requests: i64 = d.get("today_requests");
let today_cost: f64 = d.get("today_cost");
let total_count: i64 = e.get("total");
let error_count: i64 = e.get("errors");
let error_rate = if total_count > 0 {
(error_count as f64 / total_count as f64) * 100.0
} else {
0.0
};
let avg_response_time: f64 = a.get("avg_duration");
Json(ApiResponse::success(serde_json::json!({
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"active_clients": active_clients,
"today_requests": today_requests,
"today_cost": today_cost,
"error_rate": error_rate,
"avg_response_time": avg_response_time,
})))
}
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
}
}
pub(super) async fn handle_time_series(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let now = chrono::Utc::now();
let twenty_four_hours_ago = now - chrono::Duration::hours(24);
let result = sqlx::query(
r#"
SELECT
strftime('%H:00', timestamp) as hour,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost
FROM llm_requests
WHERE timestamp >= ?
GROUP BY hour
ORDER BY hour
"#,
)
.bind(twenty_four_hours_ago)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut series = Vec::new();
for row in rows {
let hour: String = row.get("hour");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
series.push(serde_json::json!({
"time": hour,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!({
"series": series,
"period": "24h"
})))
}
Err(e) => {
warn!("Failed to fetch time series data: {}", e);
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
}
}
}
pub(super) async fn handle_clients_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
// Query database for client usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost,
MAX(timestamp) as last_request
FROM llm_requests
GROUP BY client_id
ORDER BY requests DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut client_usage = Vec::new();
for row in rows {
let client_id: String = row.get("client_id");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
client_usage.push(serde_json::json!({
"client_id": client_id,
"client_name": client_id,
"requests": requests,
"tokens": tokens,
"cost": cost,
"last_request": last_request,
}));
}
Json(ApiResponse::success(serde_json::json!(client_usage)))
}
Err(e) => {
warn!("Failed to fetch client usage data: {}", e);
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
}
}
}
pub(super) async fn handle_providers_usage(
State(state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
// Query database for provider usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
provider,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
GROUP BY provider
ORDER BY requests DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut provider_usage = Vec::new();
for row in rows {
let provider: String = row.get("provider");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
provider_usage.push(serde_json::json!({
"provider": provider,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!(provider_usage)))
}
Err(e) => {
warn!("Failed to fetch provider usage data: {}", e);
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
}
}
}
pub(super) async fn handle_detailed_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
strftime('%Y-%m-%d', timestamp) as date,
client_id,
provider,
model,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost
FROM llm_requests
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let usage: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"date": row.get::<String, _>("date"),
"client": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"requests": row.get::<i64, _>("requests"),
"tokens": row.get::<i64, _>("tokens"),
"cost": row.get::<f64, _>("cost"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(usage)))
}
Err(e) => {
warn!("Failed to fetch detailed usage: {}", e);
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
}
}
}
pub(super) async fn handle_analytics_breakdown(
State(state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Model breakdown
let models =
sqlx::query("SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC")
.fetch_all(pool);
// Client breakdown
let clients = sqlx::query(
"SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC",
)
.fetch_all(pool);
match tokio::join!(models, clients) {
(Ok(m_rows), Ok(c_rows)) => {
let model_breakdown: Vec<serde_json::Value> = m_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
let client_breakdown: Vec<serde_json::Value> = c_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
Json(ApiResponse::success(serde_json::json!({
"models": model_breakdown,
"clients": client_breakdown
})))
}
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
}
}

View File

@@ -0,0 +1,75 @@
use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::IntoResponse,
};
use serde_json;
use tracing::info;
use super::DashboardState;
// WebSocket handler
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
info!("WebSocket connection established");
// Subscribe to events from the global bus
let mut rx = state.app_state.dashboard_tx.subscribe();
// Send initial connection message
let _ = socket
.send(Message::Text(
serde_json::json!({
"type": "connected",
"message": "Connected to LLM Proxy Dashboard"
})
.to_string()
.into(),
))
.await;
// Handle incoming messages and broadcast events
loop {
tokio::select! {
// Receive broadcast events
Ok(event) = rx.recv() => {
let Ok(json_str) = serde_json::to_string(&event) else {
continue;
};
let message = Message::Text(json_str.into());
if socket.send(message).await.is_err() {
break;
}
}
// Receive WebSocket messages
result = socket.recv() => {
match result {
Some(Ok(Message::Text(text))) => {
handle_websocket_message(&text, &state).await;
}
_ => break,
}
}
}
}
info!("WebSocket connection closed");
}
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
// Parse and handle WebSocket messages
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
&& let Some("ping") = data.get("type").and_then(|v| v.as_str())
{
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
"type": "pong",
"payload": {}
}));
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::Result; use anyhow::Result;
use sqlx::sqlite::{SqlitePool, SqliteConnectOptions}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
use std::str::FromStr; use std::str::FromStr;
use tracing::info; use tracing::info;
@@ -9,17 +9,16 @@ pub type DbPool = SqlitePool;
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> { pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
// Ensure the database directory exists // Ensure the database directory exists
if let Some(parent) = config.path.parent() { if let Some(parent) = config.path.parent()
if !parent.as_os_str().is_empty() { && !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await?; tokio::fs::create_dir_all(parent).await?;
} }
}
let database_path = config.path.to_string_lossy().to_string(); let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path); info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))? let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
.create_if_missing(true);
let pool = SqlitePool::connect_with(options).await?; let pool = SqlitePool::connect_with(options).await?;
@@ -91,7 +90,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
low_credit_threshold REAL DEFAULT 5.0, low_credit_threshold REAL DEFAULT 5.0,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
) )
"# "#,
) )
.execute(pool) .execute(pool)
.await?; .await?;
@@ -110,7 +109,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
) )
"# "#,
) )
.execute(pool) .execute(pool)
.await?; .await?;
@@ -123,64 +122,57 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
username TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
role TEXT DEFAULT 'admin', role TEXT DEFAULT 'admin',
must_change_password BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP created_at DATETIME DEFAULT CURRENT_TIMESTAMP
) )
"# "#,
) )
.execute(pool) .execute(pool)
.await?; .await?;
// Add must_change_password column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin) // Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users") let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
.fetch_one(pool)
.await?;
if user_count.0 == 0 { if user_count.0 == 0 {
// 'admin' hashed with default cost (12) // 'admin' hashed with default cost (12)
let default_admin_hash = bcrypt::hash("admin", 12).unwrap(); let default_admin_hash =
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
sqlx::query( sqlx::query(
"INSERT INTO users (username, password_hash, role) VALUES ('admin', ?, 'admin')" "INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
) )
.bind(default_admin_hash) .bind(default_admin_hash)
.execute(pool) .execute(pool)
.await?; .await?;
info!("Created default admin user with password 'admin'"); info!("Created default admin user with password 'admin' (must change on first login)");
} }
// Create indices // Create indices
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)"
)
.execute(pool) .execute(pool)
.await?; .await?;
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)"
)
.execute(pool) .execute(pool)
.await?; .await?;
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)"
)
.execute(pool) .execute(pool)
.await?; .await?;
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)"
)
.execute(pool) .execute(pool)
.await?; .await?;
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)"
)
.execute(pool) .execute(pool)
.await?; .await?;
sqlx::query( sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)"
)
.execute(pool) .execute(pool)
.await?; .await?;

View File

@@ -6,8 +6,8 @@
pub mod auth; pub mod auth;
pub mod client; pub mod client;
pub mod config; pub mod config;
pub mod database;
pub mod dashboard; pub mod dashboard;
pub mod database;
pub mod errors; pub mod errors;
pub mod logging; pub mod logging;
pub mod models; pub mod models;
@@ -19,27 +19,29 @@ pub mod state;
pub mod utils; pub mod utils;
// Re-exports for convenience // Re-exports for convenience
pub use auth::*; pub use auth::{AuthenticatedClient, validate_token};
pub use config::*; pub use config::{
pub use database::*; AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
pub use errors::*; OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
pub use logging::*; };
pub use models::*; pub use database::{DbPool, init as init_db, test_connection};
pub use providers::*; pub use errors::AppError;
pub use server::*; pub use logging::{LoggingContext, RequestLog, RequestLogger};
pub use state::*; pub use models::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
};
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
pub use server::router;
pub use state::AppState;
/// Test utilities for integration testing /// Test utilities for integration testing
#[cfg(test)] #[cfg(test)]
pub mod test_utils { pub mod test_utils {
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
state::AppState,
rate_limiting::RateLimitManager,
client::ClientManager,
providers::ProviderManager,
};
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
/// Create a test application state /// Create a test application state
@@ -53,7 +55,9 @@ pub mod test_utils {
crate::database::init(&crate::config::DatabaseConfig { crate::database::init(&crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"), path: std::path::PathBuf::from(":memory:"),
max_connections: 5, max_connections: 5,
}).await.expect("Failed to initialize test database"); })
.await
.expect("Failed to initialize test database");
let rate_limit_manager = RateLimitManager::new( let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(), crate::rate_limiting::RateLimiterConfig::default(),
@@ -82,11 +86,35 @@ pub mod test_utils {
max_connections: 5, max_connections: 5,
}, },
providers: crate::config::ProviderConfig { providers: crate::config::ProviderConfig {
openai: crate::config::OpenAIConfig { api_key_env: "OPENAI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, openai: crate::config::OpenAIConfig {
gemini: crate::config::GeminiConfig { api_key_env: "GEMINI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, api_key_env: "OPENAI_API_KEY".to_string(),
deepseek: crate::config::DeepSeekConfig { api_key_env: "DEEPSEEK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, base_url: "".to_string(),
grok: crate::config::GrokConfig { api_key_env: "GROK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, default_model: "".to_string(),
ollama: crate::config::OllamaConfig { base_url: "".to_string(), enabled: true, models: vec![] }, enabled: true,
},
gemini: crate::config::GeminiConfig {
api_key_env: "GEMINI_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
deepseek: crate::config::DeepSeekConfig {
api_key_env: "DEEPSEEK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
grok: crate::config::GrokConfig {
api_key_env: "GROK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
ollama: crate::config::OllamaConfig {
base_url: "".to_string(),
enabled: true,
models: vec![],
},
}, },
model_mapping: crate::config::ModelMappingConfig { patterns: vec![] }, model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
pricing: crate::config::PricingConfig { pricing: crate::config::PricingConfig {

View File

@@ -1,8 +1,8 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::Serialize;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use tracing::warn; use tracing::warn;
use serde::Serialize;
use crate::errors::AppError; use crate::errors::AppError;
@@ -77,16 +77,14 @@ impl RequestLogger {
.bind(log.status) .bind(log.status)
.bind(log.error_message) .bind(log.error_message)
.bind(log.duration_ms as i64) .bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - TODO: store serialized request .bind(None::<String>) // request_body - optional, not stored to save disk space
.bind(None::<String>) // response_body - TODO: store serialized response or error .bind(None::<String>) // response_body - optional, not stored to save disk space
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
// Deduct from provider balance if successful // Deduct from provider balance if successful
if log.cost > 0.0 { if log.cost > 0.0 {
sqlx::query( sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?")
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?"
)
.bind(log.cost) .bind(log.cost)
.bind(&log.provider) .bind(&log.provider)
.execute(&mut *tx) .execute(&mut *tx)

View File

@@ -1,16 +1,15 @@
use anyhow::Result; use anyhow::Result;
use axum::{Router, routing::get}; use axum::{Router, routing::get};
use std::net::SocketAddr; use std::net::SocketAddr;
use tracing::{info, error}; use tracing::{error, info};
use llm_proxy::{ use llm_proxy::{
config::AppConfig, config::AppConfig,
state::AppState, dashboard, database,
providers::ProviderManager, providers::ProviderManager,
database, rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server, server,
dashboard, state::AppState,
rate_limiting::{RateLimitManager, RateLimiterConfig, CircuitBreakerConfig},
}; };
#[tokio::main] #[tokio::main]
@@ -43,22 +42,28 @@ async fn main() -> Result<()> {
} }
// Create rate limit manager // Create rate limit manager
let rate_limit_manager = RateLimitManager::new( let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());
RateLimiterConfig::default(),
CircuitBreakerConfig::default(),
);
// Fetch model registry from models.dev // Fetch model registry from models.dev
let model_registry = match llm_proxy::utils::registry::fetch_registry().await { let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
Ok(registry) => registry, Ok(registry) => registry,
Err(e) => { Err(e) => {
error!("Failed to fetch model registry: {}. Using empty registry.", e); error!("Failed to fetch model registry: {}. Using empty registry.", e);
llm_proxy::models::registry::ModelRegistry { providers: std::collections::HashMap::new() } llm_proxy::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
}
} }
}; };
// Create application state // Create application state
let state = AppState::new(config.clone(), provider_manager, db_pool, rate_limit_manager, model_registry, config.server.auth_tokens.clone()); let state = AppState::new(
config.clone(),
provider_manager,
db_pool,
rate_limit_manager,
model_registry,
config.server.auth_tokens.clone(),
);
// Create application router // Create application router
let app = Router::new() let app = Router::new()

View File

@@ -198,9 +198,7 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
.into_iter() .into_iter()
.map(|msg| { .map(|msg| {
let (content, _images_in_message) = match msg.content { let (content, _images_in_message) = match msg.content {
MessageContent::Text { content } => { MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
(vec![ContentPart::Text { text: content }], false)
}
MessageContent::Parts { content } => { MessageContent::Parts { content } => {
let mut unified_content = Vec::new(); let mut unified_content = Vec::new();
let mut has_images_in_msg = false; let mut has_images_in_msg = false;
@@ -213,18 +211,16 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
ContentPartValue::ImageUrl { image_url } => { ContentPartValue::ImageUrl { image_url } => {
has_images_in_msg = true; has_images_in_msg = true;
has_images = true; has_images = true;
unified_content.push(ContentPart::Image( unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
crate::multimodal::ImageInput::from_url(image_url.url) image_url.url,
)); )));
} }
} }
} }
(unified_content, has_images_in_msg) (unified_content, has_images_in_msg)
} }
MessageContent::None => { MessageContent::None => (vec![], false),
(vec![], false)
}
}; };
UnifiedMessage { UnifiedMessage {

View File

@@ -7,24 +7,18 @@
//! 4. Provider-specific image format conversion //! 4. Provider-specific image format conversion
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use base64::{engine::general_purpose, Engine as _}; use base64::{Engine as _, engine::general_purpose};
use tracing::{info, warn}; use tracing::{info, warn};
/// Supported image formats for multimodal input /// Supported image formats for multimodal input
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ImageInput { pub enum ImageInput {
/// Base64-encoded image data with MIME type /// Base64-encoded image data with MIME type
Base64 { Base64 { data: String, mime_type: String },
data: String,
mime_type: String,
},
/// URL to fetch image from /// URL to fetch image from
Url(String), Url(String),
/// Raw bytes with MIME type /// Raw bytes with MIME type
Bytes { Bytes { data: Vec<u8>, mime_type: String },
data: Vec<u8>,
mime_type: String,
},
} }
impl ImageInput { impl ImageInput {
@@ -63,9 +57,7 @@ impl ImageInput {
Self::Url(url) => { Self::Url(url) => {
// Fetch image from URL // Fetch image from URL
info!("Fetching image from URL: {}", url); info!("Fetching image from URL: {}", url);
let response = reqwest::get(url) let response = reqwest::get(url).await.context("Failed to fetch image from URL")?;
.await
.context("Failed to fetch image from URL")?;
if !response.status().is_success() { if !response.status().is_success() {
anyhow::bail!("Failed to fetch image: HTTP {}", response.status()); anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
@@ -89,13 +81,15 @@ impl ImageInput {
/// Get image dimensions (width, height) /// Get image dimensions (width, height)
pub async fn get_dimensions(&self) -> Result<(u32, u32)> { pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
let bytes = match self { let bytes = match self {
Self::Base64 { data, .. } => { Self::Base64 { data, .. } => general_purpose::STANDARD
general_purpose::STANDARD.decode(data).context("Failed to decode base64")? .decode(data)
} .context("Failed to decode base64")?,
Self::Bytes { data, .. } => data.clone(), Self::Bytes { data, .. } => data.clone(),
Self::Url(_) => { Self::Url(_) => {
let (base64_data, _) = self.to_base64().await?; let (base64_data, _) = self.to_base64().await?;
general_purpose::STANDARD.decode(&base64_data).context("Failed to decode base64")? general_purpose::STANDARD
.decode(&base64_data)
.context("Failed to decode base64")?
} }
}; };
@@ -178,8 +172,10 @@ impl ImageConverter {
/// Detect if a model supports multimodal input /// Detect if a model supports multimodal input
pub fn model_supports_multimodal(model: &str) -> bool { pub fn model_supports_multimodal(model: &str) -> bool {
// OpenAI vision models // OpenAI vision models
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) || if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
model.starts_with("o1-") || model.starts_with("o3-") { || model.starts_with("o1-")
|| model.starts_with("o3-")
{
return true; return true;
} }
@@ -208,8 +204,9 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
} else if let Some(content_array) = content.as_array() { } else if let Some(content_array) = content.as_array() {
// Array of content parts (text and/or images) // Array of content parts (text and/or images)
for part in content_array { for part in content_array {
if let Some(part_obj) = part.as_object() { if let Some(part_obj) = part.as_object()
if let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) { && let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
{
match part_type { match part_type {
"text" => { "text" => {
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) { if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
@@ -217,8 +214,9 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
} }
} }
"image_url" => { "image_url" => {
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) { if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
if let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) { && let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
{
if url.starts_with("data:") { if url.starts_with("data:") {
// Parse data URL // Parse data URL
if let Some((mime_type, data)) = parse_data_url(url) { if let Some((mime_type, data)) = parse_data_url(url) {
@@ -232,7 +230,6 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
} }
} }
} }
}
_ => { _ => {
warn!("Unknown content part type: {}", part_type); warn!("Unknown content part type: {}", part_type);
} }
@@ -240,7 +237,6 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
} }
} }
} }
}
Ok(parts) Ok(parts)
} }
@@ -278,8 +274,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_model_supports_multimodal() { async fn test_model_supports_multimodal() {
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview")); assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision")); assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo")); assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
assert!(!ImageConverter::model_supports_multimodal("gemini-pro")); assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
} }
} }

View File

@@ -1,14 +1,10 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use futures::stream::{BoxStream, StreamExt}; use async_trait::async_trait;
use serde_json::Value; use futures::stream::BoxStream;
use crate::{ use super::helpers;
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct DeepSeekProvider { pub struct DeepSeekProvider {
client: reqwest::Client, client: reqwest::Client,
@@ -23,7 +19,11 @@ impl DeepSeekProvider {
Self::new_with_key(config, app_config, api_key) Self::new_with_key(config, app_config, api_key)
} }
pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(
config: &crate::config::DeepSeekConfig,
app_config: &AppConfig,
api_key: String,
) -> Result<Self> {
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client: reqwest::Client::new(),
config: config.clone(), config: config.clone(),
@@ -47,42 +47,13 @@ impl super::Provider for DeepSeekProvider {
false false
} }
async fn chat_completion( async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self, let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
request: UnifiedRequest, let body = helpers::build_openai_body(&request, messages_json, false);
) -> Result<ProviderResponse, AppError> {
// Build the OpenAI-compatible body
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature { let response = self
body["temperature"] = serde_json::json!(temp); .client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body) .json(&body)
.send() .send()
@@ -94,119 +65,52 @@ impl super::Provider for DeepSeekProvider {
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text))); return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
} }
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; helpers::parse_openai_response(&resp_json, request.model)
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
} }
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> { fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
} }
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
if let Some(metadata) = registry.find_model(model) { &self,
if let Some(cost) = &metadata.cost { model: &str,
return (prompt_tokens as f64 * cost.input / 1_000_000.0) + prompt_tokens: u32,
(completion_tokens as f64 * cost.output / 1_000_000.0); completion_tokens: u32,
} registry: &crate::models::registry::ModelRegistry,
} ) -> f64 {
helpers::calculate_cost_with_registry(
let (prompt_rate, completion_rate) = self.pricing.iter() model,
.find(|p| model.contains(&p.model)) prompt_tokens,
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) completion_tokens,
.unwrap_or((0.14, 0.28)); registry,
&self.pricing,
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) 0.14,
0.28,
)
} }
async fn chat_completion_stream( async fn chat_completion_stream(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({ // DeepSeek doesn't support images in streaming, use text-only
"model": request.model, let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
"messages": request.messages.iter().map(|m| { let body = helpers::build_openai_body(&request, messages_json, true);
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature { let es = reqwest_eventsource::EventSource::new(
body["temperature"] = serde_json::json!(temp); self.client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)) .json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone(); Ok(helpers::create_openai_stream(es, request.model, None))
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
} }
} }

View File

@@ -1,14 +1,10 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use async_trait::async_trait;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct GeminiRequest { struct GeminiRequest {
@@ -61,8 +57,6 @@ struct GeminiResponse {
usage_metadata: Option<GeminiUsageMetadata>, usage_metadata: Option<GeminiUsageMetadata>,
} }
pub struct GeminiProvider { pub struct GeminiProvider {
client: reqwest::Client, client: reqwest::Client,
config: crate::config::GeminiConfig, config: crate::config::GeminiConfig,
@@ -104,10 +98,7 @@ impl super::Provider for GeminiProvider {
true // Gemini supports vision true // Gemini supports vision
} }
async fn chat_completion( async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
// Convert UnifiedRequest to Gemini request // Convert UnifiedRequest to Gemini request
let mut contents = Vec::with_capacity(request.messages.len()); let mut contents = Vec::with_capacity(request.messages.len());
@@ -123,7 +114,9 @@ impl super::Provider for GeminiProvider {
}); });
} }
crate::models::ContentPart::Image(image_input) => { crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart { parts.push(GeminiPart {
@@ -143,10 +136,7 @@ impl super::Provider for GeminiProvider {
_ => "user".to_string(), _ => "user".to_string(),
}; };
contents.push(GeminiContent { contents.push(GeminiContent { parts, role });
parts,
role,
});
} }
if contents.is_empty() { if contents.is_empty() {
@@ -169,15 +159,13 @@ impl super::Provider for GeminiProvider {
}; };
// Build URL // Build URL
let url = format!("{}/models/{}:generateContent?key={}", let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,);
self.config.base_url,
request.model,
self.api_key
);
// Send request // Send request
let response = self.client let response = self
.client
.post(&url) .post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request) .json(&gemini_request)
.send() .send()
.await .await
@@ -187,7 +175,10 @@ impl super::Provider for GeminiProvider {
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
let error_text = response.text().await.unwrap_or_default(); let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text))); return Err(AppError::ProviderError(format!(
"Gemini API error ({}): {}",
status, error_text
)));
} }
let gemini_response: GeminiResponse = response let gemini_response: GeminiResponse = response
@@ -196,16 +187,29 @@ impl super::Provider for GeminiProvider {
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
// Extract content from first candidate // Extract content from first candidate
let content = gemini_response.candidates let content = gemini_response
.candidates
.first() .first()
.and_then(|c| c.content.parts.first()) .and_then(|c| c.content.parts.first())
.and_then(|p| p.text.clone()) .and_then(|p| p.text.clone())
.unwrap_or_default(); .unwrap_or_default();
// Extract token usage // Extract token usage
let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0); let prompt_tokens = gemini_response
let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0); .usage_metadata
let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0); .as_ref()
.map(|u| u.prompt_token_count)
.unwrap_or(0);
let completion_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.candidates_token_count)
.unwrap_or(0);
let total_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.total_token_count)
.unwrap_or(0);
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
@@ -221,20 +225,22 @@ impl super::Provider for GeminiProvider {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
} }
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
if let Some(metadata) = registry.find_model(model) { &self,
if let Some(cost) = &metadata.cost { model: &str,
return (prompt_tokens as f64 * cost.input / 1_000_000.0) + prompt_tokens: u32,
(completion_tokens as f64 * cost.output / 1_000_000.0); completion_tokens: u32,
} registry: &crate::models::registry::ModelRegistry,
} ) -> f64 {
super::helpers::calculate_cost_with_registry(
let (prompt_rate, completion_rate) = self.pricing.iter() model,
.find(|p| model.contains(&p.model)) prompt_tokens,
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) completion_tokens,
.unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found registry,
&self.pricing,
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) 0.075,
0.30,
)
} }
async fn chat_completion_stream( async fn chat_completion_stream(
@@ -256,7 +262,9 @@ impl super::Provider for GeminiProvider {
}); });
} }
crate::models::ContentPart::Image(image_input) => { crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart { parts.push(GeminiPart {
@@ -276,10 +284,7 @@ impl super::Provider for GeminiProvider {
_ => "user".to_string(), _ => "user".to_string(),
}; };
contents.push(GeminiContent { contents.push(GeminiContent { parts, role });
parts,
role,
});
} }
// Build generation config // Build generation config
@@ -298,17 +303,21 @@ impl super::Provider for GeminiProvider {
}; };
// Build URL for streaming // Build URL for streaming
let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}", let url = format!(
self.config.base_url, "{}/models/{}:streamGenerateContent?alt=sse",
request.model, self.config.base_url, request.model,
self.api_key
); );
// Create eventsource stream // Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
use futures::StreamExt; use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource};
let es = EventSource::new(self.client.post(&url).json(&gemini_request)) let es = EventSource::new(
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone(); let model = request.model.clone();

View File

@@ -1,18 +1,14 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use futures::stream::{BoxStream, StreamExt}; use async_trait::async_trait;
use serde_json::Value; use futures::stream::BoxStream;
use crate::{ use super::helpers;
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct GrokProvider { pub struct GrokProvider {
client: reqwest::Client, client: reqwest::Client,
_config: crate::config::GrokConfig, config: crate::config::GrokConfig,
api_key: String, api_key: String,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -26,7 +22,7 @@ impl GrokProvider {
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client: reqwest::Client::new(),
_config: config.clone(), config: config.clone(),
api_key, api_key,
pricing: app_config.pricing.grok.clone(), pricing: app_config.pricing.grok.clone(),
}) })
@@ -47,40 +43,13 @@ impl super::Provider for GrokProvider {
true true
} }
async fn chat_completion( async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self, let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
request: UnifiedRequest, let body = helpers::build_openai_body(&request, messages_json, false);
) -> Result<ProviderResponse, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature { let response = self
body["temperature"] = serde_json::json!(temp); .client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body) .json(&body)
.send() .send()
@@ -92,125 +61,51 @@ impl super::Provider for GrokProvider {
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text))); return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
} }
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; helpers::parse_openai_response(&resp_json, request.model)
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
} }
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> { fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
} }
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
if let Some(metadata) = registry.find_model(model) { &self,
if let Some(cost) = &metadata.cost { model: &str,
return (prompt_tokens as f64 * cost.input / 1_000_000.0) + prompt_tokens: u32,
(completion_tokens as f64 * cost.output / 1_000_000.0); completion_tokens: u32,
} registry: &crate::models::registry::ModelRegistry,
} ) -> f64 {
helpers::calculate_cost_with_registry(
let (prompt_rate, completion_rate) = self.pricing.iter() model,
.find(|p| model.contains(&p.model)) prompt_tokens,
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) completion_tokens,
.unwrap_or((5.0, 15.0)); registry,
&self.pricing,
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) 5.0,
15.0,
)
} }
async fn chat_completion_stream( async fn chat_completion_stream(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({ let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
"model": request.model, let body = helpers::build_openai_body(&request, messages_json, true);
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature { let es = reqwest_eventsource::EventSource::new(
body["temperature"] = serde_json::json!(temp); self.client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)) .json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone(); Ok(helpers::create_openai_stream(es, request.model, None))
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
} }
} }

189
src/providers/helpers.rs Normal file
View File

@@ -0,0 +1,189 @@
use super::{ProviderResponse, ProviderStreamChunk};
use crate::errors::AppError;
use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
///
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
/// Tokio async context. All image base64 conversions are awaited properly.
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
parts.push(serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
}));
}
}
}
result.push(serde_json::json!({
"role": m.role,
"content": parts
}));
}
Ok(result)
}
/// Convert messages to OpenAI-compatible JSON, but replace images with a
/// text placeholder "[Image]". Useful for providers that don't support
/// multimodal in streaming mode or at all.
pub async fn messages_to_openai_json_text_only(
messages: &[UnifiedMessage],
) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(_) => {
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
}
}
}
result.push(serde_json::json!({
"role": m.role,
"content": parts
}));
}
Ok(result)
}
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
pub fn build_openai_body(
request: &UnifiedRequest,
messages_json: Vec<serde_json::Value>,
stream: bool,
) -> serde_json::Value {
let mut body = serde_json::json!({
"model": request.model,
"messages": messages_json,
"stream": stream,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
body
}
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
let choice = resp_json["choices"]
.get(0)
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model,
})
}
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
///
/// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama).
pub fn create_openai_stream(
es: reqwest_eventsource::EventSource,
model: String,
reasoning_field: Option<&'static str>,
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
use reqwest_eventsource::Event;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"]
.as_str()
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
.map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Box::pin(stream)
}
/// Calculate cost using the model registry first, then falling back to provider pricing config.
pub fn calculate_cost_with_registry(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
pricing: &[crate::config::ModelPricing],
default_prompt_rate: f64,
default_completion_rate: f64,
) -> f64 {
if let Some(metadata) = registry.find_model(model)
&& let Some(cost) = &metadata.cost
{
return (prompt_tokens as f64 * cost.input / 1_000_000.0)
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
}
let (prompt_rate, completion_rate) = pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((default_prompt_rate, default_completion_rate));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}

View File

@@ -1,17 +1,18 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use std::sync::Arc; use async_trait::async_trait;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use sqlx::Row; use sqlx::Row;
use std::sync::Arc;
use crate::models::UnifiedRequest;
use crate::errors::AppError; use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod openai;
pub mod gemini;
pub mod deepseek; pub mod deepseek;
pub mod gemini;
pub mod grok; pub mod grok;
pub mod helpers;
pub mod ollama; pub mod ollama;
pub mod openai;
#[async_trait] #[async_trait]
pub trait Provider: Send + Sync { pub trait Provider: Send + Sync {
@@ -25,10 +26,7 @@ pub trait Provider: Send + Sync {
fn supports_multimodal(&self) -> bool; fn supports_multimodal(&self) -> bool;
/// Process a chat completion request /// Process a chat completion request
async fn chat_completion( async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError>;
/// Process a streaming chat completion request /// Process a streaming chat completion request
async fn chat_completion_stream( async fn chat_completion_stream(
@@ -40,7 +38,13 @@ pub trait Provider: Send + Sync {
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>; fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry /// Calculate cost based on token usage and model using the registry
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64; fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64;
} }
pub struct ProviderResponse { pub struct ProviderResponse {
@@ -64,11 +68,8 @@ use tokio::sync::RwLock;
use crate::config::AppConfig; use crate::config::AppConfig;
use crate::providers::{ use crate::providers::{
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
openai::OpenAIProvider, openai::OpenAIProvider,
gemini::GeminiProvider,
deepseek::DeepSeekProvider,
grok::GrokProvider,
ollama::OllamaProvider,
}; };
#[derive(Clone)] #[derive(Clone)]
@@ -76,6 +77,12 @@ pub struct ProviderManager {
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>, providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
} }
impl Default for ProviderManager {
fn default() -> Self {
Self::new()
}
}
impl ProviderManager { impl ProviderManager {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@@ -84,7 +91,12 @@ impl ProviderManager {
} }
/// Initialize a provider by name using config and database overrides /// Initialize a provider by name using config and database overrides
pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> { pub async fn initialize_provider(
&self,
name: &str,
app_config: &AppConfig,
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database // Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?") let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
.bind(name) .bind(name)
@@ -100,11 +112,31 @@ impl ProviderManager {
} else { } else {
// No database override, use defaults from AppConfig // No database override, use defaults from AppConfig
match name { match name {
"openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None), "openai" => (
"gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None), app_config.providers.openai.enabled,
"deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None), Some(app_config.providers.openai.base_url.clone()),
"grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None), None,
"ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None), ),
"gemini" => (
app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()),
None,
),
"deepseek" => (
app_config.providers.deepseek.enabled,
Some(app_config.providers.deepseek.base_url.clone()),
None,
),
"grok" => (
app_config.providers.grok.enabled,
Some(app_config.providers.grok.base_url.clone()),
None,
),
"ollama" => (
app_config.providers.ollama.enabled,
Some(app_config.providers.ollama.base_url.clone()),
None,
),
_ => (false, None, None), _ => (false, None, None),
} }
}; };
@@ -118,7 +150,9 @@ impl ProviderManager {
let provider: Arc<dyn Provider> = match name { let provider: Arc<dyn Provider> = match name {
"openai" => { "openai" => {
let mut cfg = app_config.providers.openai.clone(); let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url { cfg.base_url = url; } if let Some(url) = base_url {
cfg.base_url = url;
}
// Handle API key override if present // Handle API key override if present
let p = if let Some(key) = api_key { let p = if let Some(key) = api_key {
// We need a way to create a provider with an explicit key // We need a way to create a provider with an explicit key
@@ -128,42 +162,50 @@ impl ProviderManager {
OpenAIProvider::new(&cfg, app_config)? OpenAIProvider::new(&cfg, app_config)?
}; };
Arc::new(p) Arc::new(p)
}, }
"ollama" => { "ollama" => {
let mut cfg = app_config.providers.ollama.clone(); let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url { cfg.base_url = url; } if let Some(url) = base_url {
cfg.base_url = url;
}
Arc::new(OllamaProvider::new(&cfg, app_config)?) Arc::new(OllamaProvider::new(&cfg, app_config)?)
}, }
"gemini" => { "gemini" => {
let mut cfg = app_config.providers.gemini.clone(); let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url { cfg.base_url = url; } if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key { let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)? GeminiProvider::new_with_key(&cfg, app_config, key)?
} else { } else {
GeminiProvider::new(&cfg, app_config)? GeminiProvider::new(&cfg, app_config)?
}; };
Arc::new(p) Arc::new(p)
}, }
"deepseek" => { "deepseek" => {
let mut cfg = app_config.providers.deepseek.clone(); let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url { cfg.base_url = url; } if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key { let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)? DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else { } else {
DeepSeekProvider::new(&cfg, app_config)? DeepSeekProvider::new(&cfg, app_config)?
}; };
Arc::new(p) Arc::new(p)
}, }
"grok" => { "grok" => {
let mut cfg = app_config.providers.grok.clone(); let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url { cfg.base_url = url; } if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key { let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)? GrokProvider::new_with_key(&cfg, app_config, key)?
} else { } else {
GrokProvider::new(&cfg, app_config)? GrokProvider::new(&cfg, app_config)?
}; };
Arc::new(p) Arc::new(p)
}, }
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)), _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
}; };
@@ -188,16 +230,12 @@ impl ProviderManager {
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> { pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await; let providers = self.providers.read().await;
providers.iter() providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
.find(|p| p.supports_model(model))
.map(|p| Arc::clone(p))
} }
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> { pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await; let providers = self.providers.read().await;
providers.iter() providers.iter().find(|p| p.name() == name).map(Arc::clone)
.find(|p| p.name() == name)
.map(|p| Arc::clone(p))
} }
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> { pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
@@ -238,21 +276,29 @@ pub mod placeholder {
&self, &self,
_request: UnifiedRequest, _request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string())) Err(AppError::ProviderError(
"Streaming not supported for placeholder provider".to_string(),
))
} }
async fn chat_completion( async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self, Err(AppError::ProviderError(format!(
_request: UnifiedRequest, "Provider {} not implemented",
) -> Result<ProviderResponse, AppError> { self.name
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name))) )))
} }
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> { fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0) Ok(0)
} }
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
&self,
_model: &str,
_prompt_tokens: u32,
_completion_tokens: u32,
_registry: &crate::models::registry::ModelRegistry,
) -> f64 {
0.0 0.0
} }
} }

View File

@@ -1,18 +1,14 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use futures::stream::{BoxStream, StreamExt}; use async_trait::async_trait;
use serde_json::Value; use futures::stream::BoxStream;
use crate::{ use super::helpers;
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OllamaProvider { pub struct OllamaProvider {
client: reqwest::Client, client: reqwest::Client,
_config: crate::config::OllamaConfig, config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -20,7 +16,7 @@ impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client: reqwest::Client::new(),
_config: config.clone(), config: config.clone(),
pricing: app_config.pricing.ollama.clone(), pricing: app_config.pricing.ollama.clone(),
}) })
} }
@@ -33,49 +29,29 @@ impl super::Provider for OllamaProvider {
} }
fn supports_model(&self, model: &str) -> bool { fn supports_model(&self, model: &str) -> bool {
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/") self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {
true true
} }
async fn chat_completion( async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self, // Strip "ollama/" prefix if present for the API call
request: UnifiedRequest, let api_model = request
) -> Result<ProviderResponse, AppError> { .model
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); .strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let mut body = serde_json::json!({ let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
"model": model, let body = helpers::build_openai_body(&request, messages_json, false);
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature { let response = self
body["temperature"] = serde_json::json!(temp); .client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body) .json(&body)
.send() .send()
.await .await
@@ -86,120 +62,67 @@ impl super::Provider for OllamaProvider {
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text))); return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
} }
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; // Ollama also supports "thought" as an alias for reasoning_content
let message = &choice["message"]; let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
if result.reasoning_content.is_none() {
let content = message["content"].as_str().unwrap_or_default().to_string(); result.reasoning_content = resp_json["choices"]
let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string()); .get(0)
.and_then(|c| c["message"]["thought"].as_str())
let usage = &resp_json["usage"]; .map(|s| s.to_string());
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; }
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; Ok(result)
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
} }
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> { fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
} }
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
if let Some(metadata) = registry.find_model(model) { &self,
if let Some(cost) = &metadata.cost { model: &str,
return (prompt_tokens as f64 * cost.input / 1_000_000.0) + prompt_tokens: u32,
(completion_tokens as f64 * cost.output / 1_000_000.0); completion_tokens: u32,
} registry: &crate::models::registry::ModelRegistry,
} ) -> f64 {
helpers::calculate_cost_with_registry(
let (prompt_rate, completion_rate) = self.pricing.iter() model,
.find(|p| model.contains(&p.model)) prompt_tokens,
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) completion_tokens,
.unwrap_or((0.0, 0.0)); registry,
&self.pricing,
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) 0.0,
0.0,
)
} }
async fn chat_completion_stream( async fn chat_completion_stream(
&self, &self,
request: UnifiedRequest, mut request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); let api_model = request
.model
.strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let mut body = serde_json::json!({ let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
"model": model, let body = helpers::build_openai_body(&request, messages_json, true);
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature { let es = reqwest_eventsource::EventSource::new(
body["temperature"] = serde_json::json!(temp); self.client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens { .json(&body),
body["max_tokens"] = serde_json::json!(max_tokens); )
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model_name = request.model.clone(); // Ollama uses "thought" as an alternative field for reasoning content
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model_name.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
} }
} }

View File

@@ -1,18 +1,14 @@
use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use futures::stream::{BoxStream, StreamExt}; use async_trait::async_trait;
use serde_json::Value; use futures::stream::BoxStream;
use crate::{ use super::helpers;
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OpenAIProvider { pub struct OpenAIProvider {
client: reqwest::Client, client: reqwest::Client,
_config: crate::config::OpenAIConfig, config: crate::config::OpenAIConfig,
api_key: String, api_key: String,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -26,7 +22,7 @@ impl OpenAIProvider {
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client: reqwest::Client::new(),
_config: config.clone(), config: config.clone(),
api_key, api_key,
pricing: app_config.pricing.openai.clone(), pricing: app_config.pricing.openai.clone(),
}) })
@@ -47,40 +43,13 @@ impl super::Provider for OpenAIProvider {
true true
} }
async fn chat_completion( async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
&self, let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
request: UnifiedRequest, let body = helpers::build_openai_body(&request, messages_json, false);
) -> Result<ProviderResponse, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature { let response = self
body["temperature"] = serde_json::json!(temp); .client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body) .json(&body)
.send() .send()
@@ -92,125 +61,51 @@ impl super::Provider for OpenAIProvider {
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text))); return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
} }
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; helpers::parse_openai_response(&resp_json, request.model)
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
} }
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> { fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
} }
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { fn calculate_cost(
if let Some(metadata) = registry.find_model(model) { &self,
if let Some(cost) = &metadata.cost { model: &str,
return (prompt_tokens as f64 * cost.input / 1_000_000.0) + prompt_tokens: u32,
(completion_tokens as f64 * cost.output / 1_000_000.0); completion_tokens: u32,
} registry: &crate::models::registry::ModelRegistry,
} ) -> f64 {
helpers::calculate_cost_with_registry(
let (prompt_rate, completion_rate) = self.pricing.iter() model,
.find(|p| model.contains(&p.model)) prompt_tokens,
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) completion_tokens,
.unwrap_or((0.15, 0.60)); registry,
&self.pricing,
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) 0.15,
0.60,
)
} }
async fn chat_completion_stream( async fn chat_completion_stream(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({ let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
"model": request.model, let body = helpers::build_openai_body(&request, messages_json, true);
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature { let es = reqwest_eventsource::EventSource::new(
body["temperature"] = serde_json::json!(temp); self.client
} .post(format!("{}/chat/completions", self.config.base_url))
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)) .json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone(); Ok(helpers::create_openai_stream(es, request.model, None))
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
} }
} }

View File

@@ -5,12 +5,12 @@
//! 2. Provider circuit breaking to handle API failures //! 2. Provider circuit breaking to handle API failures
//! 3. Global rate limiting for overall system protection //! 3. Global rate limiting for overall system protection
use std::sync::Arc; use anyhow::Result;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{info, warn}; use tracing::{info, warn};
use anyhow::Result;
/// Rate limiter configuration /// Rate limiter configuration
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -177,12 +177,12 @@ impl ProviderCircuitBreaker {
let now = std::time::Instant::now(); let now = std::time::Instant::now();
// Check if failure window has expired // Check if failure window has expired
if let Some(last_failure) = self.last_failure_time { if let Some(last_failure) = self.last_failure_time
if now.duration_since(last_failure).as_secs() > self.config.failure_window_secs { && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
{
// Reset failure count if window expired // Reset failure count if window expired
self.failure_count = 0; self.failure_count = 0;
} }
}
self.failure_count += 1; self.failure_count += 1;
self.last_failure_time = Some(now); self.last_failure_time = Some(now);
@@ -246,9 +246,7 @@ impl RateLimitManager {
// Check client-specific rate limit // Check client-specific rate limit
let mut buckets = self.client_buckets.write().await; let mut buckets = self.client_buckets.write().await;
let bucket = buckets let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
.entry(client_id.to_string())
.or_insert_with(|| {
TokenBucket::new( TokenBucket::new(
self.config.burst_size as f64, self.config.burst_size as f64,
self.config.requests_per_minute as f64 / 60.0, self.config.requests_per_minute as f64 / 60.0,
@@ -299,14 +297,13 @@ impl RateLimitManager {
/// Axum middleware for rate limiting /// Axum middleware for rate limiting
pub mod middleware { pub mod middleware {
use super::*; use super::*;
use crate::errors::AppError;
use crate::state::AppState;
use axum::{ use axum::{
extract::{Request, State}, extract::{Request, State},
middleware::Next, middleware::Next,
response::Response, response::Response,
}; };
use crate::errors::AppError;
use crate::state::AppState;
/// Rate limiting middleware /// Rate limiting middleware
pub async fn rate_limit_middleware( pub async fn rate_limit_middleware(
@@ -319,9 +316,7 @@ pub mod middleware {
// Check rate limits // Check rate limits
if !state.rate_limit_manager.check_client_request(&client_id).await? { if !state.rate_limit_manager.check_client_request(&client_id).await? {
return Err(AppError::RateLimitError( return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
"Rate limit exceeded".to_string()
));
} }
Ok(next.run(request).await) Ok(next.run(request).await)
@@ -330,29 +325,25 @@ pub mod middleware {
/// Extract client ID from request (helper function) /// Extract client ID from request (helper function)
fn extract_client_id_from_request(request: &Request) -> String { fn extract_client_id_from_request(request: &Request) -> String {
// Try to extract from Authorization header // Try to extract from Authorization header
if let Some(auth_header) = request.headers().get("Authorization") { if let Some(auth_header) = request.headers().get("Authorization")
if let Ok(auth_str) = auth_header.to_str() { && let Ok(auth_str) = auth_header.to_str()
if auth_str.starts_with("Bearer ") { && let Some(token) = auth_str.strip_prefix("Bearer ")
let token = &auth_str[7..]; {
// Use token hash as client ID (same logic as auth module) // Use token hash as client ID (same logic as auth module)
return format!("client_{}", &token[..8.min(token.len())]); return format!("client_{}", &token[..8.min(token.len())]);
} }
}
}
// Fallback to anonymous // Fallback to anonymous
"anonymous".to_string() "anonymous".to_string()
} }
/// Circuit breaker middleware for provider requests /// Circuit breaker middleware for provider requests
pub async fn circuit_breaker_middleware( pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
provider_name: &str,
state: &AppState,
) -> Result<(), AppError> {
if !state.rate_limit_manager.check_provider_request(provider_name).await? { if !state.rate_limit_manager.check_provider_request(provider_name).await? {
return Err(AppError::ProviderError( return Err(AppError::ProviderError(format!(
format!("Provider {} is currently unavailable (circuit breaker open)", provider_name) "Provider {} is currently unavailable (circuit breaker open)",
)); provider_name
)));
} }
Ok(()) Ok(())
} }

View File

@@ -1,22 +1,25 @@
use std::sync::Arc;
use sqlx::Row;
use uuid::Uuid;
use axum::{ use axum::{
extract::State,
routing::post,
Json, Router, Json, Router,
response::sse::{Event, Sse}, extract::State,
response::IntoResponse, response::IntoResponse,
response::sse::{Event, Sse},
routing::post,
}; };
use futures::stream::StreamExt; use futures::stream::StreamExt;
use sqlx::Row;
use std::sync::Arc;
use tracing::{info, warn}; use tracing::{info, warn};
use uuid::Uuid;
use crate::{ use crate::{
auth::AuthenticatedClient, auth::AuthenticatedClient,
errors::AppError, errors::AppError,
models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, ChatStreamDelta, ChatMessage, ChatChoice, Usage}, models::{
state::AppState, ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, Usage,
},
rate_limiting, rate_limiting,
state::AppState,
}; };
pub fn router(state: AppState) -> Router { pub fn router(state: AppState) -> Router {
@@ -85,7 +88,10 @@ async fn chat_completions(
}; };
if !model_enabled { if !model_enabled {
return Err(AppError::ValidationError(format!("Model {} is currently disabled", model))); return Err(AppError::ValidationError(format!(
"Model {} is currently disabled",
model
)));
} }
// Apply mapping if present // Apply mapping if present
@@ -95,7 +101,10 @@ async fn chat_completions(
} }
// Find appropriate provider for the model // Find appropriate provider for the model
let provider = state.provider_manager.get_provider_for_model(&request.model).await let provider = state
.provider_manager
.get_provider_for_model(&request.model)
.await
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?; .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
let provider_name = provider.name().to_string(); let provider_name = provider.name().to_string();
@@ -104,23 +113,26 @@ async fn chat_completions(
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?; rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
// Convert to unified request format // Convert to unified request format
let mut unified_request = crate::models::UnifiedRequest::try_from(request) let mut unified_request =
.map_err(|e| AppError::ValidationError(e.to_string()))?; crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;
// Set client_id from authentication // Set client_id from authentication
unified_request.client_id = client_id.clone(); unified_request.client_id = client_id.clone();
// Hydrate images if present // Hydrate images if present
if unified_request.has_images { if unified_request.has_images {
unified_request.hydrate_images().await unified_request
.hydrate_images()
.await
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?; .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
} }
let has_images = unified_request.has_images;
// Check if streaming is requested // Check if streaming is requested
if unified_request.stream { if unified_request.stream {
// Estimate prompt tokens for logging later // Estimate prompt tokens for logging later
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request); let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
let has_images = unified_request.has_images;
// Handle streaming response // Handle streaming response
let stream_result = provider.chat_completion_stream(unified_request).await; let stream_result = provider.chat_completion_stream(unified_request).await;
@@ -133,15 +145,17 @@ async fn chat_completions(
// Wrap with AggregatingStream for token counting and database logging // Wrap with AggregatingStream for token counting and database logging
let aggregating_stream = crate::utils::streaming::AggregatingStream::new( let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
stream, stream,
client_id.clone(), crate::utils::streaming::StreamConfig {
provider.clone(), client_id: client_id.clone(),
model.clone(), provider: provider.clone(),
model: model.clone(),
prompt_tokens, prompt_tokens,
has_images, has_images,
state.request_logger.clone(), logger: state.request_logger.clone(),
state.client_manager.clone(), client_manager: state.client_manager.clone(),
state.model_registry.clone(), model_registry: state.model_registry.clone(),
state.db_pool.clone(), db_pool: state.db_pool.clone(),
},
); );
// Create SSE stream from aggregating stream // Create SSE stream from aggregating stream
@@ -165,7 +179,13 @@ async fn chat_completions(
}], }],
}; };
Ok(Event::default().json_data(response).unwrap()) match Event::default().json_data(response) {
Ok(event) => Ok(event),
Err(e) => {
warn!("Failed to serialize SSE event: {}", e);
Err(AppError::InternalError("SSE serialization failed".to_string()))
}
}
} }
Err(e) => { Err(e) => {
warn!("Error in streaming response: {}", e); warn!("Error in streaming response: {}", e);
@@ -197,7 +217,14 @@ async fn chat_completions(
state.rate_limit_manager.record_provider_success(&provider_name).await; state.rate_limit_manager.record_provider_success(&provider_name).await;
let duration = start_time.elapsed(); let duration = start_time.elapsed();
let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await; let cost = get_model_cost(
&response.model,
response.prompt_tokens,
response.completion_tokens,
&provider,
&state,
)
.await;
// Log request to database // Log request to database
state.request_logger.log_request(crate::logging::RequestLog { state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(), timestamp: chrono::Utc::now(),
@@ -208,18 +235,17 @@ async fn chat_completions(
completion_tokens: response.completion_tokens, completion_tokens: response.completion_tokens,
total_tokens: response.total_tokens, total_tokens: response.total_tokens,
cost, cost,
has_images: false, // TODO: check images has_images,
status: "success".to_string(), status: "success".to_string(),
error_message: None, error_message: None,
duration_ms: duration.as_millis() as u64, duration_ms: duration.as_millis() as u64,
}); });
// Update client usage // Update client usage
let _ = state.client_manager.update_client_usage( let _ = state
&client_id, .client_manager
response.total_tokens as i64, .update_client_usage(&client_id, response.total_tokens as i64, cost)
cost, .await;
).await;
// Convert ProviderResponse to ChatCompletionResponse // Convert ProviderResponse to ChatCompletionResponse
let chat_response = ChatCompletionResponse { let chat_response = ChatCompletionResponse {
@@ -232,7 +258,7 @@ async fn chat_completions(
message: ChatMessage { message: ChatMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: crate::models::MessageContent::Text { content: crate::models::MessageContent::Text {
content: response.content content: response.content,
}, },
reasoning_content: response.reasoning_content, reasoning_content: response.reasoning_content,
}, },

View File

@@ -2,9 +2,8 @@ use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use crate::{ use crate::{
client::ClientManager, database::DbPool, providers::ProviderManager, client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
rate_limiting::RateLimitManager, logging::RequestLogger, models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
models::registry::ModelRegistry, config::AppConfig,
}; };
/// Shared application state /// Shared application state

View File

@@ -1,3 +1,3 @@
pub mod tokens;
pub mod registry; pub mod registry;
pub mod streaming; pub mod streaming;
pub mod tokens;

View File

@@ -1,6 +1,6 @@
use crate::models::registry::ModelRegistry;
use anyhow::Result; use anyhow::Result;
use tracing::info; use tracing::info;
use crate::models::registry::ModelRegistry;
const MODELS_DEV_URL: &str = "https://models.dev/api.json"; const MODELS_DEV_URL: &str = "https://models.dev/api.json";

View File

@@ -1,13 +1,26 @@
use futures::stream::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::Arc;
use sqlx::Row;
use crate::logging::{RequestLogger, RequestLog};
use crate::client::ClientManager; use crate::client::ClientManager;
use crate::providers::{Provider, ProviderStreamChunk};
use crate::errors::AppError; use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger};
use crate::providers::{Provider, ProviderStreamChunk};
use crate::utils::tokens::estimate_completion_tokens; use crate::utils::tokens::estimate_completion_tokens;
use futures::stream::Stream;
use sqlx::Row;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Configuration for creating an AggregatingStream.
pub struct StreamConfig {
pub client_id: String,
pub provider: Arc<dyn Provider>,
pub model: String,
pub prompt_tokens: u32,
pub has_images: bool,
pub logger: Arc<RequestLogger>,
pub client_manager: Arc<ClientManager>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub db_pool: crate::database::DbPool,
}
pub struct AggregatingStream<S> { pub struct AggregatingStream<S> {
inner: S, inner: S,
@@ -28,33 +41,22 @@ pub struct AggregatingStream<S> {
impl<S> AggregatingStream<S> impl<S> AggregatingStream<S>
where where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{ {
pub fn new( pub fn new(inner: S, config: StreamConfig) -> Self {
inner: S,
client_id: String,
provider: Arc<dyn Provider>,
model: String,
prompt_tokens: u32,
has_images: bool,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
db_pool: crate::database::DbPool,
) -> Self {
Self { Self {
inner, inner,
client_id, client_id: config.client_id,
provider, provider: config.provider,
model, model: config.model,
prompt_tokens, prompt_tokens: config.prompt_tokens,
has_images, has_images: config.has_images,
accumulated_content: String::new(), accumulated_content: String::new(),
accumulated_reasoning: String::new(), accumulated_reasoning: String::new(),
logger, logger: config.logger,
client_manager, client_manager: config.client_manager,
model_registry, model_registry: config.model_registry,
db_pool, db_pool: config.db_pool,
start_time: std::time::Instant::now(), start_time: std::time::Instant::now(),
has_logged: false, has_logged: false,
} }
@@ -92,7 +94,8 @@ where
// Spawn a background task to log the completion // Spawn a background task to log the completion
tokio::spawn(async move { tokio::spawn(async move {
// Check database for cost overrides // Check database for cost overrides
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") let db_cost =
sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
.bind(&model) .bind(&model)
.fetch_optional(&pool) .fetch_optional(&pool)
.await .await
@@ -128,18 +131,16 @@ where
}); });
// Update client usage // Update client usage
let _ = client_manager.update_client_usage( let _ = client_manager
&client_id, .update_client_usage(&client_id, total_tokens as i64, cost)
total_tokens as i64, .await;
cost,
).await;
}); });
} }
} }
impl<S> Stream for AggregatingStream<S> impl<S> Stream for AggregatingStream<S>
where where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{ {
type Item = Result<ProviderStreamChunk, AppError>; type Item = Result<ProviderStreamChunk, AppError>;
@@ -173,46 +174,81 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::stream::{self, StreamExt};
use anyhow::Result; use anyhow::Result;
use futures::stream::{self, StreamExt};
// Simple mock provider for testing // Simple mock provider for testing
struct MockProvider; struct MockProvider;
#[async_trait::async_trait] #[async_trait::async_trait]
impl Provider for MockProvider { impl Provider for MockProvider {
fn name(&self) -> &str { "mock" } fn name(&self) -> &str {
fn supports_model(&self, _model: &str) -> bool { true } "mock"
fn supports_multimodal(&self) -> bool { false } }
async fn chat_completion(&self, _req: crate::models::UnifiedRequest) -> Result<crate::providers::ProviderResponse, AppError> { unimplemented!() } fn supports_model(&self, _model: &str) -> bool {
async fn chat_completion_stream(&self, _req: crate::models::UnifiedRequest) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { unimplemented!() } true
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> { Ok(10) } }
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { 0.05 } fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<crate::providers::ProviderResponse, AppError> {
unimplemented!()
}
async fn chat_completion_stream(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
unimplemented!()
}
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
Ok(10)
}
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
0.05
}
} }
#[tokio::test] #[tokio::test]
async fn test_aggregating_stream() { async fn test_aggregating_stream() {
let chunks = vec![ let chunks = vec![
Ok(ProviderStreamChunk { content: "Hello".to_string(), finish_reason: None, model: "test".to_string() }), Ok(ProviderStreamChunk {
Ok(ProviderStreamChunk { content: " World".to_string(), finish_reason: Some("stop".to_string()), model: "test".to_string() }), content: "Hello".to_string(),
reasoning_content: None,
finish_reason: None,
model: "test".to_string(),
}),
Ok(ProviderStreamChunk {
content: " World".to_string(),
reasoning_content: None,
finish_reason: Some("stop".to_string()),
model: "test".to_string(),
}),
]; ];
let inner_stream = stream::iter(chunks); let inner_stream = stream::iter(chunks);
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let logger = Arc::new(RequestLogger::new(pool.clone())); let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let client_manager = Arc::new(ClientManager::new(pool.clone())); let client_manager = Arc::new(ClientManager::new(pool.clone()));
let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new() }); let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
});
let mut agg_stream = AggregatingStream::new( let mut agg_stream = AggregatingStream::new(
inner_stream, inner_stream,
"client_1".to_string(), StreamConfig {
Arc::new(MockProvider), client_id: "client_1".to_string(),
"test".to_string(), provider: Arc::new(MockProvider),
10, model: "test".to_string(),
false, prompt_tokens: 10,
has_images: false,
logger, logger,
client_manager, client_manager,
registry, model_registry: registry,
pool.clone(), db_pool: pool.clone(),
},
); );
while let Some(item) = agg_stream.next().await { while let Some(item) = agg_stream.next().await {

View File

@@ -1,12 +1,11 @@
use tiktoken_rs::get_bpe_from_model;
use crate::models::UnifiedRequest; use crate::models::UnifiedRequest;
use tiktoken_rs::get_bpe_from_model;
/// Count tokens for a given model and text /// Count tokens for a given model and text
pub fn count_tokens(model: &str, text: &str) -> u32 { pub fn count_tokens(model: &str, text: &str) -> u32 {
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1) // If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
let bpe = get_bpe_from_model(model).unwrap_or_else(|_| { let bpe = get_bpe_from_model(model)
tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding") .unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
});
bpe.encode_with_special_tokens(text).len() as u32 bpe.encode_with_special_tokens(text).len() as u32
} }

View File

@@ -30,7 +30,7 @@ curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON r
echo "" echo ""
echo "Dashboard should be available at: http://localhost:8080" echo "Dashboard should be available at: http://localhost:8080"
echo "Default login: admin / admin123" echo "Default login: admin / admin"
echo "" echo ""
echo "Press Ctrl+C to stop the server" echo "Press Ctrl+C to stop the server"

View File

@@ -1,188 +0,0 @@
// Integration tests for LLM Proxy Gateway
use llm_proxy::config::Config;
use llm_proxy::database::Database;
use llm_proxy::state::AppState;
use llm_proxy::rate_limiting::RateLimitManager;
use tempfile::TempDir;
use std::fs;
#[tokio::test]
async fn test_config_loading() {
// Create a temporary config file
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let config_content = r#"
[server]
port = 8080
host = "0.0.0.0"
[database]
path = "./data/test.db"
max_connections = 5
[providers.openai]
enabled = true
base_url = "https://api.openai.com/v1"
[providers.gemini]
enabled = true
base_url = "https://generativelanguage.googleapis.com/v1"
[providers.deepseek]
enabled = true
base_url = "https://api.deepseek.com"
[providers.grok]
enabled = false
base_url = "https://api.x.ai/v1"
[model_mapping]
"gpt-*" = "openai"
"gemini-*" = "gemini"
"deepseek-*" = "deepseek"
"grok-*" = "grok"
[pricing]
openai = { input = 0.01, output = 0.03 }
gemini = { input = 0.0005, output = 0.0015 }
deepseek = { input = 0.00014, output = 0.00028 }
grok = { input = 0.001, output = 0.003 }
"#;
fs::write(&config_path, config_content).unwrap();
// Test loading config
let config = Config::load_from_path(&config_path);
assert!(config.is_ok());
let config = config.unwrap();
assert_eq!(config.server.port, 8080);
assert!(config.providers.openai.is_some());
assert!(config.providers.grok.is_none());
}
#[tokio::test]
async fn test_database_initialization() {
// Create a temporary database file
let temp_dir = TempDir::new().unwrap();
let db_path = temp_dir.path().join("test.db");
// Test database initialization
let database = Database::new(&db_path).await;
assert!(database.is_ok());
let database = database.unwrap();
// Test connection
let test_result = database.test_connection().await;
assert!(test_result.is_ok());
}
#[tokio::test]
async fn test_provider_manager() {
// Create a provider manager
use llm_proxy::providers::{ProviderManager, Provider};
use llm_proxy::config::OpenAIConfig;
let mut manager = ProviderManager::new();
assert_eq!(manager.providers.len(), 0);
// Test adding providers (we can't actually add real providers without API keys)
// This test just verifies the manager structure works
assert!(manager.get_provider_for_model("gpt-4").is_none());
assert!(manager.get_provider("openai").is_none());
}
#[tokio::test]
async fn test_rate_limit_manager() {
let manager = RateLimitManager::new(60, 10);
// Test client rate limiting
let allowed = manager.check_request("test-client").await;
assert!(allowed); // First request should be allowed
// Test provider circuit breaker
let allowed = manager.check_provider("openai").await;
assert!(allowed); // Circuit should be closed initially
// Record some failures
manager.record_provider_failure("openai").await;
manager.record_provider_failure("openai").await;
manager.record_provider_failure("openai").await;
manager.record_provider_failure("openai").await;
manager.record_provider_failure("openai").await;
// After 5 failures, circuit should be open
let allowed = manager.check_provider("openai").await;
assert!(!allowed); // Circuit should be open
// Record success to close circuit
manager.record_provider_success("openai").await;
manager.record_provider_success("openai").await;
manager.record_provider_success("openai").await;
// After 3 successes in half-open state, circuit should be closed
let allowed = manager.check_provider("openai").await;
assert!(allowed); // Circuit should be closed again
}
#[tokio::test]
async fn test_app_state_creation() {
// Create a temporary database
let temp_dir = TempDir::new().unwrap();
let db_path = temp_dir.path().join("test.db");
let database = Database::new(&db_path).await.unwrap();
// Test AppState creation using test utilities
use llm_proxy::test_utils::create_test_state;
let state = create_test_state().await;
// Verify state components are initialized
assert!(state.database.test_connection().await.is_ok());
}
#[tokio::test]
async fn test_multimodal_image_converter() {
use llm_proxy::multimodal::{ImageConverter, ImageInput};
// Test model detection
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
assert!(!ImageConverter::model_supports_multimodal("gemini-pro"));
// Test data URL parsing (utility function)
let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ=";
let parts: Vec<&str> = test_url[5..].split(";base64,").collect();
assert_eq!(parts.len(), 2);
assert_eq!(parts[0], "image/jpeg");
assert_eq!(parts[1], "SGVsbG8gV29ybGQ=");
}
#[tokio::test]
async fn test_error_conversions() {
use llm_proxy::errors::AppError;
use anyhow::anyhow;
// Test anyhow error conversion
let anyhow_error = anyhow!("Test error");
let app_error: AppError = anyhow_error.into();
match app_error {
AppError::InternalError(msg) => assert_eq!(msg, "Test error"),
_ => panic!("Expected InternalError"),
}
// Test sqlx error conversion
use sqlx::Error as SqlxError;
let sqlx_error = SqlxError::PoolClosed;
let app_error: AppError = sqlx_error.into();
match app_error {
AppError::DatabaseError(msg) => assert!(msg.contains("pool closed")),
_ => panic!("Expected DatabaseError"),
}
}