refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup) Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder) Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy) Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images) Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules) Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
61
.github/workflows/ci.yml
vendored
Normal file
61
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_TERM_COLOR: always
|
||||||
|
RUST_BACKTRACE: 1
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check:
|
||||||
|
name: Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
- run: cargo check --all-targets
|
||||||
|
|
||||||
|
clippy:
|
||||||
|
name: Clippy
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
components: clippy
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
- run: cargo clippy --all-targets -- -D warnings
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
name: Formatting
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
components: rustfmt
|
||||||
|
- run: cargo fmt --all -- --check
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
- run: cargo test --all-targets
|
||||||
|
|
||||||
|
build-release:
|
||||||
|
name: Release Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [check, clippy, test]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
- run: cargo build --release
|
||||||
270
Cargo.lock
generated
270
Cargo.lock
generated
@@ -92,46 +92,6 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "async-openai"
|
|
||||||
version = "0.33.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b3ef6d66e53b3ec8ed4bae8e6dcdda75c0f5b4f67faba78fac1b09434cb4f2fc"
|
|
||||||
dependencies = [
|
|
||||||
"async-openai-macros",
|
|
||||||
"backoff",
|
|
||||||
"base64 0.22.1",
|
|
||||||
"bytes",
|
|
||||||
"derive_builder",
|
|
||||||
"eventsource-stream",
|
|
||||||
"futures",
|
|
||||||
"getrandom 0.3.4",
|
|
||||||
"rand 0.9.2",
|
|
||||||
"reqwest",
|
|
||||||
"reqwest-eventsource",
|
|
||||||
"secrecy",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"serde_urlencoded",
|
|
||||||
"thiserror 2.0.18",
|
|
||||||
"tokio",
|
|
||||||
"tokio-stream",
|
|
||||||
"tokio-util",
|
|
||||||
"tracing",
|
|
||||||
"url",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "async-openai-macros"
|
|
||||||
version = "0.1.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-stream"
|
name = "async-stream"
|
||||||
version = "0.3.6"
|
version = "0.3.6"
|
||||||
@@ -275,20 +235,6 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "backoff"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
|
|
||||||
dependencies = [
|
|
||||||
"futures-core",
|
|
||||||
"getrandom 0.2.17",
|
|
||||||
"instant",
|
|
||||||
"pin-project-lite",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"tokio",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64"
|
name = "base64"
|
||||||
version = "0.13.1"
|
version = "0.13.1"
|
||||||
@@ -598,54 +544,6 @@ dependencies = [
|
|||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling"
|
|
||||||
version = "0.20.11"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
|
|
||||||
dependencies = [
|
|
||||||
"darling_core",
|
|
||||||
"darling_macro",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling_core"
|
|
||||||
version = "0.20.11"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
|
|
||||||
dependencies = [
|
|
||||||
"fnv",
|
|
||||||
"ident_case",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"strsim",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling_macro"
|
|
||||||
version = "0.20.11"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
|
||||||
dependencies = [
|
|
||||||
"darling_core",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "dashmap"
|
|
||||||
version = "5.5.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"hashbrown 0.14.5",
|
|
||||||
"lock_api",
|
|
||||||
"once_cell",
|
|
||||||
"parking_lot_core",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "data-encoding"
|
name = "data-encoding"
|
||||||
version = "2.10.0"
|
version = "2.10.0"
|
||||||
@@ -663,37 +561,6 @@ dependencies = [
|
|||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder"
|
|
||||||
version = "0.20.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
|
|
||||||
dependencies = [
|
|
||||||
"derive_builder_macro",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder_core"
|
|
||||||
version = "0.20.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
|
|
||||||
dependencies = [
|
|
||||||
"darling",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder_macro"
|
|
||||||
version = "0.20.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
|
|
||||||
dependencies = [
|
|
||||||
"derive_builder_core",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "difflib"
|
name = "difflib"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
@@ -1028,26 +895,6 @@ dependencies = [
|
|||||||
"wasip3",
|
"wasip3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "governor"
|
|
||||||
version = "0.6.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"dashmap",
|
|
||||||
"futures",
|
|
||||||
"futures-timer",
|
|
||||||
"no-std-compat",
|
|
||||||
"nonzero_ext",
|
|
||||||
"parking_lot",
|
|
||||||
"portable-atomic",
|
|
||||||
"quanta",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"smallvec",
|
|
||||||
"spinning_top",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.4.13"
|
version = "0.4.13"
|
||||||
@@ -1076,12 +923,6 @@ dependencies = [
|
|||||||
"ahash",
|
"ahash",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "hashbrown"
|
|
||||||
version = "0.14.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.15.5"
|
version = "0.15.5"
|
||||||
@@ -1396,12 +1237,6 @@ version = "2.3.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
|
checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ident_case"
|
|
||||||
version = "1.0.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
@@ -1482,15 +1317,6 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "instant"
|
|
||||||
version = "0.1.13"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipnet"
|
name = "ipnet"
|
||||||
version = "2.11.0"
|
version = "2.11.0"
|
||||||
@@ -1607,7 +1433,6 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assert_cmd",
|
"assert_cmd",
|
||||||
"async-openai",
|
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
@@ -1618,16 +1443,11 @@ dependencies = [
|
|||||||
"config",
|
"config",
|
||||||
"dotenvy",
|
"dotenvy",
|
||||||
"futures",
|
"futures",
|
||||||
"governor",
|
|
||||||
"headers",
|
"headers",
|
||||||
"image",
|
"image",
|
||||||
"insta",
|
"insta",
|
||||||
"mime",
|
"mime",
|
||||||
"mime_guess",
|
|
||||||
"mockito",
|
"mockito",
|
||||||
"once_cell",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"regex",
|
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"reqwest-eventsource",
|
"reqwest-eventsource",
|
||||||
"serde",
|
"serde",
|
||||||
@@ -1776,12 +1596,6 @@ dependencies = [
|
|||||||
"pxfm",
|
"pxfm",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "no-std-compat"
|
|
||||||
version = "0.4.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nom"
|
name = "nom"
|
||||||
version = "7.1.3"
|
version = "7.1.3"
|
||||||
@@ -1792,12 +1606,6 @@ dependencies = [
|
|||||||
"minimal-lexical",
|
"minimal-lexical",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "nonzero_ext"
|
|
||||||
version = "0.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nu-ansi-term"
|
name = "nu-ansi-term"
|
||||||
version = "0.50.3"
|
version = "0.50.3"
|
||||||
@@ -2014,12 +1822,6 @@ dependencies = [
|
|||||||
"miniz_oxide",
|
"miniz_oxide",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "portable-atomic"
|
|
||||||
version = "1.13.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "potential_utf"
|
name = "potential_utf"
|
||||||
version = "0.1.4"
|
version = "0.1.4"
|
||||||
@@ -2093,21 +1895,6 @@ dependencies = [
|
|||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "quanta"
|
|
||||||
version = "0.12.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
|
|
||||||
dependencies = [
|
|
||||||
"crossbeam-utils",
|
|
||||||
"libc",
|
|
||||||
"once_cell",
|
|
||||||
"raw-cpuid",
|
|
||||||
"wasi",
|
|
||||||
"web-sys",
|
|
||||||
"winapi",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quick-error"
|
name = "quick-error"
|
||||||
version = "2.0.1"
|
version = "2.0.1"
|
||||||
@@ -2243,15 +2030,6 @@ dependencies = [
|
|||||||
"getrandom 0.3.4",
|
"getrandom 0.3.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "raw-cpuid"
|
|
||||||
version = "11.6.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags 2.11.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.18"
|
version = "0.5.18"
|
||||||
@@ -2317,7 +2095,6 @@ dependencies = [
|
|||||||
"hyper-util",
|
"hyper-util",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"log",
|
"log",
|
||||||
"mime_guess",
|
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"quinn",
|
"quinn",
|
||||||
@@ -2490,16 +2267,6 @@ version = "1.2.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "secrecy"
|
|
||||||
version = "0.10.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
|
|
||||||
dependencies = [
|
|
||||||
"serde",
|
|
||||||
"zeroize",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "semver"
|
name = "semver"
|
||||||
version = "1.0.27"
|
version = "1.0.27"
|
||||||
@@ -2684,15 +2451,6 @@ dependencies = [
|
|||||||
"lock_api",
|
"lock_api",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "spinning_top"
|
|
||||||
version = "0.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
|
||||||
dependencies = [
|
|
||||||
"lock_api",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spki"
|
name = "spki"
|
||||||
version = "0.7.3"
|
version = "0.7.3"
|
||||||
@@ -2912,12 +2670,6 @@ dependencies = [
|
|||||||
"unicode-properties",
|
"unicode-properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "strsim"
|
|
||||||
version = "0.11.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.6.1"
|
version = "2.6.1"
|
||||||
@@ -3657,28 +3409,6 @@ dependencies = [
|
|||||||
"wasite",
|
"wasite",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi"
|
|
||||||
version = "0.3.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
|
||||||
dependencies = [
|
|
||||||
"winapi-i686-pc-windows-gnu",
|
|
||||||
"winapi-x86_64-pc-windows-gnu",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi-i686-pc-windows-gnu"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "winapi-x86_64-pc-windows-gnu"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-core"
|
name = "windows-core"
|
||||||
version = "0.62.2"
|
version = "0.62.2"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip",
|
|||||||
|
|
||||||
# ========== HTTP Clients ==========
|
# ========== HTTP Clients ==========
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||||
async-openai = { version = "0.33", default-features = false, features = ["_api", "chat-completion"] }
|
|
||||||
tiktoken-rs = "0.9"
|
tiktoken-rs = "0.9"
|
||||||
|
|
||||||
# ========== Database & ORM ==========
|
# ========== Database & ORM ==========
|
||||||
@@ -41,7 +40,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
|||||||
base64 = "0.21"
|
base64 = "0.21"
|
||||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
|
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
|
||||||
mime = "0.3"
|
mime = "0.3"
|
||||||
mime_guess = "2.0"
|
|
||||||
|
|
||||||
# ========== Error Handling & Utilities ==========
|
# ========== Error Handling & Utilities ==========
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
@@ -53,12 +51,6 @@ futures = "0.3"
|
|||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
reqwest-eventsource = "0.6"
|
reqwest-eventsource = "0.6"
|
||||||
once_cell = "1.19"
|
|
||||||
regex = "1.10"
|
|
||||||
rand = "0.8"
|
|
||||||
|
|
||||||
# ========== Rate Limiting & Circuit Breaking ==========
|
|
||||||
governor = "0.6"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio-test = "0.4"
|
tokio-test = "0.4"
|
||||||
|
|||||||
35
Dockerfile
Normal file
35
Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# ── Build stage ──────────────────────────────────────────────
|
||||||
|
FROM rust:1-bookworm AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Cache dependency build
|
||||||
|
COPY Cargo.toml Cargo.lock ./
|
||||||
|
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
|
||||||
|
cargo build --release && \
|
||||||
|
rm -rf src
|
||||||
|
|
||||||
|
# Build the actual binary
|
||||||
|
COPY src/ src/
|
||||||
|
RUN touch src/main.rs && cargo build --release
|
||||||
|
|
||||||
|
# ── Runtime stage ────────────────────────────────────────────
|
||||||
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends ca-certificates && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
|
||||||
|
COPY static/ /app/static/
|
||||||
|
|
||||||
|
# Default config location
|
||||||
|
VOLUME ["/app/config", "/app/data"]
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
ENV RUST_LOG=info
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/llm-proxy"]
|
||||||
1
clippy.toml
Normal file
1
clippy.toml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
too-many-arguments-threshold = 8
|
||||||
2
rustfmt.toml
Normal file
2
rustfmt.toml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
max_width = 120
|
||||||
|
use_field_init_shorthand = true
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||||
use axum_extra::headers::Authorization;
|
|
||||||
use axum_extra::TypedHeader;
|
use axum_extra::TypedHeader;
|
||||||
|
use axum_extra::headers::Authorization;
|
||||||
use headers::authorization::Bearer;
|
use headers::authorization::Bearer;
|
||||||
|
|
||||||
use crate::errors::AppError;
|
use crate::errors::AppError;
|
||||||
@@ -16,33 +16,19 @@ where
|
|||||||
{
|
{
|
||||||
type Rejection = AppError;
|
type Rejection = AppError;
|
||||||
|
|
||||||
fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
// We need access to the AppState to get valid tokens
|
|
||||||
// Since state is generic here, we try to cast it or assume it's available via extensions
|
|
||||||
// In this project, AppState is cloned into Axum state.
|
|
||||||
|
|
||||||
async move {
|
|
||||||
// Extract bearer token from Authorization header
|
// Extract bearer token from Authorization header
|
||||||
let TypedHeader(Authorization(bearer)) =
|
let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
||||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
|
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
|
||||||
|
|
||||||
let token = bearer.token().to_string();
|
let token = bearer.token().to_string();
|
||||||
|
|
||||||
// For a proxy, we want to check if this token is in our allowed list
|
// Derive client_id from the token prefix
|
||||||
// The list is stored in AppState which is available in Parts extensions
|
let client_id = format!("client_{}", &token[..8.min(token.len())]);
|
||||||
let client_id = {
|
|
||||||
// In main.rs, we set up the router with State(state).
|
|
||||||
// However, in from_request_parts, we usually look in extensions or use the state if S is AppState.
|
|
||||||
// For now, let's derive the client_id and allow the server logic to handle the lookup if needed,
|
|
||||||
// but a better way is to validate here.
|
|
||||||
format!("client_{}", &token[..8.min(token.len())])
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(AuthenticatedClient { token, client_id })
|
Ok(AuthenticatedClient { token, client_id })
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
|
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
|
||||||
|
|||||||
@@ -5,10 +5,10 @@
|
|||||||
//! 2. Client usage tracking
|
//! 2. Client usage tracking
|
||||||
//! 3. Client rate limit configuration
|
//! 3. Client rate limit configuration
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{SqlitePool, Row};
|
use sqlx::{Row, SqlitePool};
|
||||||
use anyhow::Result;
|
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
/// Client information
|
/// Client information
|
||||||
@@ -74,7 +74,9 @@ impl ClientManager {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Then fetch the created client
|
// Then fetch the created client
|
||||||
let client = self.get_client(&request.client_id).await?
|
let client = self
|
||||||
|
.get_client(&request.client_id)
|
||||||
|
.await?
|
||||||
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
|
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
|
||||||
|
|
||||||
info!("Created client: {} ({})", client.name, client.client_id);
|
info!("Created client: {} ({})", client.name, client.client_id);
|
||||||
@@ -126,12 +128,11 @@ impl ClientManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build update query dynamically based on provided fields
|
// Build update query dynamically based on provided fields
|
||||||
let mut updates = Vec::new();
|
|
||||||
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
||||||
let mut has_updates = false;
|
let mut has_updates = false;
|
||||||
|
|
||||||
if let Some(name) = &request.name {
|
if let Some(name) = &request.name {
|
||||||
updates.push("name = ");
|
query_builder.push("name = ");
|
||||||
query_builder.push_bind(name);
|
query_builder.push_bind(name);
|
||||||
has_updates = true;
|
has_updates = true;
|
||||||
}
|
}
|
||||||
@@ -140,7 +141,7 @@ impl ClientManager {
|
|||||||
if has_updates {
|
if has_updates {
|
||||||
query_builder.push(", ");
|
query_builder.push(", ");
|
||||||
}
|
}
|
||||||
updates.push("description = ");
|
query_builder.push("description = ");
|
||||||
query_builder.push_bind(description);
|
query_builder.push_bind(description);
|
||||||
has_updates = true;
|
has_updates = true;
|
||||||
}
|
}
|
||||||
@@ -149,7 +150,7 @@ impl ClientManager {
|
|||||||
if has_updates {
|
if has_updates {
|
||||||
query_builder.push(", ");
|
query_builder.push(", ");
|
||||||
}
|
}
|
||||||
updates.push("is_active = ");
|
query_builder.push("is_active = ");
|
||||||
query_builder.push_bind(is_active);
|
query_builder.push_bind(is_active);
|
||||||
has_updates = true;
|
has_updates = true;
|
||||||
}
|
}
|
||||||
@@ -158,7 +159,7 @@ impl ClientManager {
|
|||||||
if has_updates {
|
if has_updates {
|
||||||
query_builder.push(", ");
|
query_builder.push(", ");
|
||||||
}
|
}
|
||||||
updates.push("rate_limit_per_minute = ");
|
query_builder.push("rate_limit_per_minute = ");
|
||||||
query_builder.push_bind(rate_limit);
|
query_builder.push_bind(rate_limit);
|
||||||
has_updates = true;
|
has_updates = true;
|
||||||
}
|
}
|
||||||
@@ -204,7 +205,7 @@ impl ClientManager {
|
|||||||
FROM clients
|
FROM clients
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.bind(limit)
|
.bind(limit)
|
||||||
.bind(offset)
|
.bind(offset)
|
||||||
@@ -234,9 +235,7 @@ impl ClientManager {
|
|||||||
|
|
||||||
/// Delete a client
|
/// Delete a client
|
||||||
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
|
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
|
||||||
let result = sqlx::query(
|
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||||
"DELETE FROM clients WHERE client_id = ?"
|
|
||||||
)
|
|
||||||
.bind(client_id)
|
.bind(client_id)
|
||||||
.execute(&self.db_pool)
|
.execute(&self.db_pool)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -253,12 +252,7 @@ impl ClientManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Update client usage statistics after a request
|
/// Update client usage statistics after a request
|
||||||
pub async fn update_client_usage(
|
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
|
||||||
&self,
|
|
||||||
client_id: &str,
|
|
||||||
tokens: i64,
|
|
||||||
cost: f64,
|
|
||||||
) -> Result<()> {
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
UPDATE clients
|
UPDATE clients
|
||||||
@@ -268,7 +262,7 @@ impl ClientManager {
|
|||||||
total_cost = total_cost + ?,
|
total_cost = total_cost + ?,
|
||||||
updated_at = CURRENT_TIMESTAMP
|
updated_at = CURRENT_TIMESTAMP
|
||||||
WHERE client_id = ?
|
WHERE client_id = ?
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.bind(tokens)
|
.bind(tokens)
|
||||||
.bind(cost)
|
.bind(cost)
|
||||||
@@ -286,7 +280,7 @@ impl ClientManager {
|
|||||||
SELECT total_requests, total_tokens, total_cost
|
SELECT total_requests, total_tokens, total_cost
|
||||||
FROM clients
|
FROM clients
|
||||||
WHERE client_id = ?
|
WHERE client_id = ?
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.bind(client_id)
|
.bind(client_id)
|
||||||
.fetch_optional(&self.db_pool)
|
.fetch_optional(&self.db_pool)
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ pub struct AppConfig {
|
|||||||
pub providers: ProviderConfig,
|
pub providers: ProviderConfig,
|
||||||
pub model_mapping: ModelMappingConfig,
|
pub model_mapping: ModelMappingConfig,
|
||||||
pub pricing: PricingConfig,
|
pub pricing: PricingConfig,
|
||||||
pub config_path: PathBuf,
|
pub config_path: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppConfig {
|
impl AppConfig {
|
||||||
@@ -120,7 +120,10 @@ impl AppConfig {
|
|||||||
.set_default("providers.openai.default_model", "gpt-4o")?
|
.set_default("providers.openai.default_model", "gpt-4o")?
|
||||||
.set_default("providers.openai.enabled", true)?
|
.set_default("providers.openai.enabled", true)?
|
||||||
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
|
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
|
||||||
.set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")?
|
.set_default(
|
||||||
|
"providers.gemini.base_url",
|
||||||
|
"https://generativelanguage.googleapis.com/v1",
|
||||||
|
)?
|
||||||
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
|
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
|
||||||
.set_default("providers.gemini.enabled", true)?
|
.set_default("providers.gemini.enabled", true)?
|
||||||
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
|
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
|
||||||
@@ -136,7 +139,11 @@ impl AppConfig {
|
|||||||
.set_default("providers.ollama.models", Vec::<String>::new())?;
|
.set_default("providers.ollama.models", Vec::<String>::new())?;
|
||||||
|
|
||||||
// Load from config file if exists
|
// Load from config file if exists
|
||||||
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml"));
|
let config_path = config_path.unwrap_or_else(|| {
|
||||||
|
std::env::current_dir()
|
||||||
|
.unwrap_or_else(|_| PathBuf::from("."))
|
||||||
|
.join("config.toml")
|
||||||
|
});
|
||||||
if config_path.exists() {
|
if config_path.exists() {
|
||||||
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
|
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
|
||||||
}
|
}
|
||||||
@@ -174,7 +181,7 @@ impl AppConfig {
|
|||||||
providers,
|
providers,
|
||||||
model_mapping,
|
model_mapping,
|
||||||
pricing,
|
pricing,
|
||||||
config_path,
|
config_path: Some(config_path),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,16 +194,15 @@ impl AppConfig {
|
|||||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
|
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
|
||||||
};
|
};
|
||||||
|
|
||||||
std::env::var(env_var)
|
std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
|
||||||
.map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
|
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
|
||||||
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||||
where
|
where
|
||||||
D: serde::Deserializer<'de>,
|
D: serde::Deserializer<'de>,
|
||||||
{
|
{
|
||||||
struct VecOrString;
|
struct VecOrString;
|
||||||
|
|
||||||
impl<'de> serde::de::Visitor<'de> for VecOrString {
|
impl<'de> serde::de::Visitor<'de> for VecOrString {
|
||||||
@@ -230,5 +236,4 @@ impl AppConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
deserializer.deserialize_any(VecOrString)
|
deserializer.deserialize_any(VecOrString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
130
src/dashboard/auth.rs
Normal file
130
src/dashboard/auth.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
use axum::{extract::State, response::Json};
|
||||||
|
use bcrypt;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use sqlx::Row;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
// Authentication handlers
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(super) struct LoginRequest {
|
||||||
|
pub(super) username: String,
|
||||||
|
pub(super) password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_login(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Json(payload): Json<LoginRequest>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let user_result =
|
||||||
|
sqlx::query("SELECT username, password_hash, role, must_change_password FROM users WHERE username = ?")
|
||||||
|
.bind(&payload.username)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match user_result {
|
||||||
|
Ok(Some(row)) => {
|
||||||
|
let hash = row.get::<String, _>("password_hash");
|
||||||
|
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
|
||||||
|
let username = row.get::<String, _>("username");
|
||||||
|
let role = row.get::<String, _>("role");
|
||||||
|
let must_change_password = row.get::<bool, _>("must_change_password");
|
||||||
|
let token = state
|
||||||
|
.session_manager
|
||||||
|
.create_session(username.clone(), role.clone())
|
||||||
|
.await;
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"token": token,
|
||||||
|
"must_change_password": must_change_password,
|
||||||
|
"user": {
|
||||||
|
"username": username,
|
||||||
|
"name": "Administrator",
|
||||||
|
"role": role
|
||||||
|
}
|
||||||
|
})))
|
||||||
|
} else {
|
||||||
|
Json(ApiResponse::error("Invalid username or password".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Database error during login: {}", e);
|
||||||
|
Json(ApiResponse::error("Login failed due to system error".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_auth_status(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
headers: axum::http::HeaderMap,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let token = headers
|
||||||
|
.get("Authorization")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.strip_prefix("Bearer "));
|
||||||
|
|
||||||
|
if let Some(token) = token
|
||||||
|
&& let Some(session) = state.session_manager.validate_session(token).await
|
||||||
|
{
|
||||||
|
return Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"authenticated": true,
|
||||||
|
"user": {
|
||||||
|
"username": session.username,
|
||||||
|
"name": "Administrator",
|
||||||
|
"role": session.role
|
||||||
|
}
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(super) struct ChangePasswordRequest {
|
||||||
|
pub(super) current_password: String,
|
||||||
|
pub(super) new_password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_change_password(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Json(payload): Json<ChangePasswordRequest>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// For now, always change 'admin' user
|
||||||
|
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'")
|
||||||
|
.fetch_one(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match user_result {
|
||||||
|
Ok(row) => {
|
||||||
|
let hash = row.get::<String, _>("password_hash");
|
||||||
|
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
|
||||||
|
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
|
||||||
|
Ok(h) => h,
|
||||||
|
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
|
let update_result = sqlx::query(
|
||||||
|
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = 'admin'",
|
||||||
|
)
|
||||||
|
.bind(new_hash)
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match update_result {
|
||||||
|
Ok(_) => Json(ApiResponse::success(
|
||||||
|
serde_json::json!({ "message": "Password updated successfully" }),
|
||||||
|
)),
|
||||||
|
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Json(ApiResponse::error("Current password incorrect".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
227
src/dashboard/clients.rs
Normal file
227
src/dashboard/clients.rs
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::Json,
|
||||||
|
};
|
||||||
|
use chrono;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json;
|
||||||
|
use sqlx::Row;
|
||||||
|
use tracing::warn;
|
||||||
|
use uuid;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(super) struct CreateClientRequest {
|
||||||
|
pub(super) name: String,
|
||||||
|
pub(super) client_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
client_id as id,
|
||||||
|
name,
|
||||||
|
created_at,
|
||||||
|
total_requests,
|
||||||
|
total_tokens,
|
||||||
|
total_cost,
|
||||||
|
is_active
|
||||||
|
FROM clients
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let clients: Vec<serde_json::Value> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| {
|
||||||
|
serde_json::json!({
|
||||||
|
"id": row.get::<String, _>("id"),
|
||||||
|
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||||
|
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||||
|
"requests_count": row.get::<i64, _>("total_requests"),
|
||||||
|
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||||
|
"total_cost": row.get::<f64, _>("total_cost"),
|
||||||
|
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(clients)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch clients: {}", e);
|
||||||
|
Json(ApiResponse::error("Failed to fetch clients".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_create_client(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Json(payload): Json<CreateClientRequest>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let client_id = payload
|
||||||
|
.client_id
|
||||||
|
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO clients (client_id, name, is_active)
|
||||||
|
VALUES (?, ?, TRUE)
|
||||||
|
RETURNING *
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&client_id)
|
||||||
|
.bind(&payload.name)
|
||||||
|
.fetch_one(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(row) => Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"id": row.get::<String, _>("client_id"),
|
||||||
|
"name": row.get::<Option<String>, _>("name"),
|
||||||
|
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||||
|
"status": "active",
|
||||||
|
}))),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to create client: {}", e);
|
||||||
|
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_client(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
c.client_id as id,
|
||||||
|
c.name,
|
||||||
|
c.is_active,
|
||||||
|
c.created_at,
|
||||||
|
COALESCE(c.total_tokens, 0) as total_tokens,
|
||||||
|
COALESCE(c.total_cost, 0.0) as total_cost,
|
||||||
|
COUNT(r.id) as total_requests,
|
||||||
|
MAX(r.timestamp) as last_request
|
||||||
|
FROM clients c
|
||||||
|
LEFT JOIN llm_requests r ON c.client_id = r.client_id
|
||||||
|
WHERE c.client_id = ?
|
||||||
|
GROUP BY c.client_id
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"id": row.get::<String, _>("id"),
|
||||||
|
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||||
|
"is_active": row.get::<bool, _>("is_active"),
|
||||||
|
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||||
|
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||||
|
"total_cost": row.get::<f64, _>("total_cost"),
|
||||||
|
"total_requests": row.get::<i64, _>("total_requests"),
|
||||||
|
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
|
||||||
|
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||||
|
}))),
|
||||||
|
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch client: {}", e);
|
||||||
|
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_delete_client(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Don't allow deleting the default client
|
||||||
|
if id == "default" {
|
||||||
|
return Json(ApiResponse::error("Cannot delete default client".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||||
|
.bind(id)
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
|
||||||
|
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_client_usage(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Get per-model breakdown for this client
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
COUNT(*) as request_count,
|
||||||
|
SUM(prompt_tokens) as prompt_tokens,
|
||||||
|
SUM(completion_tokens) as completion_tokens,
|
||||||
|
SUM(total_tokens) as total_tokens,
|
||||||
|
SUM(cost) as total_cost,
|
||||||
|
AVG(duration_ms) as avg_duration_ms
|
||||||
|
FROM llm_requests
|
||||||
|
WHERE client_id = ?
|
||||||
|
GROUP BY model, provider
|
||||||
|
ORDER BY total_cost DESC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let breakdown: Vec<serde_json::Value> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| {
|
||||||
|
serde_json::json!({
|
||||||
|
"model": row.get::<String, _>("model"),
|
||||||
|
"provider": row.get::<String, _>("provider"),
|
||||||
|
"request_count": row.get::<i64, _>("request_count"),
|
||||||
|
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
|
||||||
|
"completion_tokens": row.get::<i64, _>("completion_tokens"),
|
||||||
|
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||||
|
"total_cost": row.get::<f64, _>("total_cost"),
|
||||||
|
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"client_id": id,
|
||||||
|
"breakdown": breakdown,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch client usage: {}", e);
|
||||||
|
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1073
src/dashboard/mod.rs
1073
src/dashboard/mod.rs
File diff suppressed because it is too large
Load Diff
116
src/dashboard/models.rs
Normal file
116
src/dashboard/models.rs
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::Json,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json;
|
||||||
|
use sqlx::Row;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(super) struct UpdateModelRequest {
|
||||||
|
pub(super) enabled: bool,
|
||||||
|
pub(super) prompt_cost: Option<f64>,
|
||||||
|
pub(super) completion_cost: Option<f64>,
|
||||||
|
pub(super) mapping: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let registry = &state.app_state.model_registry;
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Load overrides from database
|
||||||
|
let db_models_result =
|
||||||
|
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut db_models = HashMap::new();
|
||||||
|
if let Ok(rows) = db_models_result {
|
||||||
|
for row in rows {
|
||||||
|
let id: String = row.get("id");
|
||||||
|
db_models.insert(id, row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut models_json = Vec::new();
|
||||||
|
|
||||||
|
for (p_id, p_info) in ®istry.providers {
|
||||||
|
for (m_id, m_meta) in &p_info.models {
|
||||||
|
let mut enabled = true;
|
||||||
|
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||||
|
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||||
|
let mut mapping = None::<String>;
|
||||||
|
|
||||||
|
if let Some(row) = db_models.get(m_id) {
|
||||||
|
enabled = row.get("enabled");
|
||||||
|
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||||
|
prompt_cost = p;
|
||||||
|
}
|
||||||
|
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||||
|
completion_cost = c;
|
||||||
|
}
|
||||||
|
mapping = row.get("mapping");
|
||||||
|
}
|
||||||
|
|
||||||
|
models_json.push(serde_json::json!({
|
||||||
|
"id": m_id,
|
||||||
|
"provider": p_id,
|
||||||
|
"name": m_meta.name,
|
||||||
|
"enabled": enabled,
|
||||||
|
"prompt_cost": prompt_cost,
|
||||||
|
"completion_cost": completion_cost,
|
||||||
|
"mapping": mapping,
|
||||||
|
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_update_model(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Json(payload): Json<UpdateModelRequest>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Find provider_id for this model in registry
|
||||||
|
let provider_id = state
|
||||||
|
.app_state
|
||||||
|
.model_registry
|
||||||
|
.providers
|
||||||
|
.iter()
|
||||||
|
.find(|(_, p)| p.models.contains_key(&id))
|
||||||
|
.map(|(id, _)| id.clone())
|
||||||
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
enabled = excluded.enabled,
|
||||||
|
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||||
|
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||||
|
mapping = excluded.mapping,
|
||||||
|
updated_at = CURRENT_TIMESTAMP
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(provider_id)
|
||||||
|
.bind(payload.enabled)
|
||||||
|
.bind(payload.prompt_cost)
|
||||||
|
.bind(payload.completion_cost)
|
||||||
|
.bind(payload.mapping)
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
|
||||||
|
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
346
src/dashboard/providers.rs
Normal file
346
src/dashboard/providers.rs
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::Json,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json;
|
||||||
|
use sqlx::Row;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(super) struct UpdateProviderRequest {
|
||||||
|
pub(super) enabled: bool,
|
||||||
|
pub(super) base_url: Option<String>,
|
||||||
|
pub(super) api_key: Option<String>,
|
||||||
|
pub(super) credit_balance: Option<f64>,
|
||||||
|
pub(super) low_credit_threshold: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let registry = &state.app_state.model_registry;
|
||||||
|
let config = &state.app_state.config;
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Load all overrides from database
|
||||||
|
let db_configs_result =
|
||||||
|
sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs")
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut db_configs = HashMap::new();
|
||||||
|
if let Ok(rows) = db_configs_result {
|
||||||
|
for row in rows {
|
||||||
|
let id: String = row.get("id");
|
||||||
|
let enabled: bool = row.get("enabled");
|
||||||
|
let base_url: Option<String> = row.get("base_url");
|
||||||
|
let balance: f64 = row.get("credit_balance");
|
||||||
|
let threshold: f64 = row.get("low_credit_threshold");
|
||||||
|
db_configs.insert(id, (enabled, base_url, balance, threshold));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut providers_json = Vec::new();
|
||||||
|
|
||||||
|
// Define the list of providers we support
|
||||||
|
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||||
|
|
||||||
|
for id in provider_ids {
|
||||||
|
// Get base config
|
||||||
|
let (mut enabled, mut base_url, display_name) = match id {
|
||||||
|
"openai" => (
|
||||||
|
config.providers.openai.enabled,
|
||||||
|
config.providers.openai.base_url.clone(),
|
||||||
|
"OpenAI",
|
||||||
|
),
|
||||||
|
"gemini" => (
|
||||||
|
config.providers.gemini.enabled,
|
||||||
|
config.providers.gemini.base_url.clone(),
|
||||||
|
"Google Gemini",
|
||||||
|
),
|
||||||
|
"deepseek" => (
|
||||||
|
config.providers.deepseek.enabled,
|
||||||
|
config.providers.deepseek.base_url.clone(),
|
||||||
|
"DeepSeek",
|
||||||
|
),
|
||||||
|
"grok" => (
|
||||||
|
config.providers.grok.enabled,
|
||||||
|
config.providers.grok.base_url.clone(),
|
||||||
|
"xAI Grok",
|
||||||
|
),
|
||||||
|
"ollama" => (
|
||||||
|
config.providers.ollama.enabled,
|
||||||
|
config.providers.ollama.base_url.clone(),
|
||||||
|
"Ollama",
|
||||||
|
),
|
||||||
|
_ => (false, "".to_string(), "Unknown"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut balance = 0.0;
|
||||||
|
let mut threshold = 5.0;
|
||||||
|
|
||||||
|
// Apply database overrides
|
||||||
|
if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) {
|
||||||
|
enabled = *db_enabled;
|
||||||
|
if let Some(url) = db_url {
|
||||||
|
base_url = url.clone();
|
||||||
|
}
|
||||||
|
balance = *db_balance;
|
||||||
|
threshold = *db_threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find models for this provider in registry
|
||||||
|
let mut models = Vec::new();
|
||||||
|
if let Some(p_info) = registry.providers.get(id) {
|
||||||
|
models = p_info.models.keys().cloned().collect();
|
||||||
|
} else if id == "ollama" {
|
||||||
|
models = config.providers.ollama.models.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine status
|
||||||
|
let status = if !enabled {
|
||||||
|
"disabled"
|
||||||
|
} else {
|
||||||
|
// Check if it's actually initialized in the provider manager
|
||||||
|
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
||||||
|
// Check circuit breaker
|
||||||
|
if state
|
||||||
|
.app_state
|
||||||
|
.rate_limit_manager
|
||||||
|
.check_provider_request(id)
|
||||||
|
.await
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
"online"
|
||||||
|
} else {
|
||||||
|
"degraded"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
"error" // Enabled but failed to initialize (e.g. missing API key)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
providers_json.push(serde_json::json!({
|
||||||
|
"id": id,
|
||||||
|
"name": display_name,
|
||||||
|
"enabled": enabled,
|
||||||
|
"status": status,
|
||||||
|
"models": models,
|
||||||
|
"base_url": base_url,
|
||||||
|
"credit_balance": balance,
|
||||||
|
"low_credit_threshold": threshold,
|
||||||
|
"last_used": None::<String>,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_provider(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let registry = &state.app_state.model_registry;
|
||||||
|
let config = &state.app_state.config;
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Validate provider name
|
||||||
|
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
||||||
|
"openai" => (
|
||||||
|
config.providers.openai.enabled,
|
||||||
|
config.providers.openai.base_url.clone(),
|
||||||
|
"OpenAI",
|
||||||
|
),
|
||||||
|
"gemini" => (
|
||||||
|
config.providers.gemini.enabled,
|
||||||
|
config.providers.gemini.base_url.clone(),
|
||||||
|
"Google Gemini",
|
||||||
|
),
|
||||||
|
"deepseek" => (
|
||||||
|
config.providers.deepseek.enabled,
|
||||||
|
config.providers.deepseek.base_url.clone(),
|
||||||
|
"DeepSeek",
|
||||||
|
),
|
||||||
|
"grok" => (
|
||||||
|
config.providers.grok.enabled,
|
||||||
|
config.providers.grok.base_url.clone(),
|
||||||
|
"xAI Grok",
|
||||||
|
),
|
||||||
|
"ollama" => (
|
||||||
|
config.providers.ollama.enabled,
|
||||||
|
config.providers.ollama.base_url.clone(),
|
||||||
|
"Ollama",
|
||||||
|
),
|
||||||
|
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut balance = 0.0;
|
||||||
|
let mut threshold = 5.0;
|
||||||
|
|
||||||
|
// Apply database overrides
|
||||||
|
let db_config = sqlx::query(
|
||||||
|
"SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?",
|
||||||
|
)
|
||||||
|
.bind(&name)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Ok(Some(row)) = db_config {
|
||||||
|
enabled = row.get::<bool, _>("enabled");
|
||||||
|
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
||||||
|
base_url = url;
|
||||||
|
}
|
||||||
|
balance = row.get::<f64, _>("credit_balance");
|
||||||
|
threshold = row.get::<f64, _>("low_credit_threshold");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find models for this provider
|
||||||
|
let mut models = Vec::new();
|
||||||
|
if let Some(p_info) = registry.providers.get(name.as_str()) {
|
||||||
|
models = p_info.models.keys().cloned().collect();
|
||||||
|
} else if name == "ollama" {
|
||||||
|
models = config.providers.ollama.models.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine status
|
||||||
|
let status = if !enabled {
|
||||||
|
"disabled"
|
||||||
|
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
||||||
|
if state
|
||||||
|
.app_state
|
||||||
|
.rate_limit_manager
|
||||||
|
.check_provider_request(&name)
|
||||||
|
.await
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
"online"
|
||||||
|
} else {
|
||||||
|
"degraded"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
"error"
|
||||||
|
};
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"id": name,
|
||||||
|
"name": display_name,
|
||||||
|
"enabled": enabled,
|
||||||
|
"status": status,
|
||||||
|
"models": models,
|
||||||
|
"base_url": base_url,
|
||||||
|
"credit_balance": balance,
|
||||||
|
"low_credit_threshold": threshold,
|
||||||
|
"last_used": None::<String>,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_update_provider(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
Json(payload): Json<UpdateProviderRequest>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Update or insert into database
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
enabled = excluded.enabled,
|
||||||
|
base_url = excluded.base_url,
|
||||||
|
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||||
|
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||||
|
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||||
|
updated_at = CURRENT_TIMESTAMP
|
||||||
|
"#
|
||||||
|
)
|
||||||
|
.bind(&name)
|
||||||
|
.bind(name.to_uppercase())
|
||||||
|
.bind(payload.enabled)
|
||||||
|
.bind(&payload.base_url)
|
||||||
|
.bind(&payload.api_key)
|
||||||
|
.bind(payload.credit_balance)
|
||||||
|
.bind(payload.low_credit_threshold)
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => {
|
||||||
|
// Re-initialize provider in manager
|
||||||
|
if let Err(e) = state
|
||||||
|
.app_state
|
||||||
|
.provider_manager
|
||||||
|
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!("Failed to re-initialize provider {}: {}", name, e);
|
||||||
|
return Json(ApiResponse::error(format!(
|
||||||
|
"Provider settings saved but initialization failed: {}",
|
||||||
|
e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(
|
||||||
|
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to update provider config: {}", e);
|
||||||
|
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_test_provider(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
||||||
|
Some(p) => p,
|
||||||
|
None => {
|
||||||
|
return Json(ApiResponse::error(format!(
|
||||||
|
"Provider '{}' not found or not enabled",
|
||||||
|
name
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pick a real model for this provider from the registry
|
||||||
|
let test_model = state
|
||||||
|
.app_state
|
||||||
|
.model_registry
|
||||||
|
.providers
|
||||||
|
.get(&name)
|
||||||
|
.and_then(|p| p.models.keys().next().cloned())
|
||||||
|
.unwrap_or_else(|| name.clone());
|
||||||
|
|
||||||
|
let test_request = crate::models::UnifiedRequest {
|
||||||
|
client_id: "system-test".to_string(),
|
||||||
|
model: test_model,
|
||||||
|
messages: vec![crate::models::UnifiedMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: Some(5),
|
||||||
|
stream: false,
|
||||||
|
has_images: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
match provider.chat_completion(test_request).await {
|
||||||
|
Ok(_) => {
|
||||||
|
let latency = start.elapsed().as_millis();
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"latency": latency,
|
||||||
|
"message": "Connection test successful"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
64
src/dashboard/sessions.rs
Normal file
64
src/dashboard/sessions.rs
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
/// An authenticated dashboard session tied to a login token.
#[derive(Clone, Debug)]
pub struct Session {
    /// Username the session was issued for.
    pub username: String,
    /// Role string (e.g. "admin") copied from the user record at login time.
    pub role: String,
    /// When the session was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Sessions past this instant are rejected by `validate_session`.
    pub expires_at: DateTime<Utc>,
}
|
||||||
|
|
||||||
|
/// In-memory session store shared across dashboard handlers.
///
/// Cloning is cheap: the session map lives behind an `Arc<RwLock<…>>`, so
/// every clone observes the same set of sessions. Note that sessions are
/// not persisted — a server restart logs everyone out.
#[derive(Clone)]
pub struct SessionManager {
    /// Session token -> session record.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// Lifetime of newly created sessions, in hours.
    ttl_hours: i64,
}
|
||||||
|
|
||||||
|
impl SessionManager {
|
||||||
|
pub fn new(ttl_hours: i64) -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
ttl_hours,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new session and return the session token.
|
||||||
|
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||||
|
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||||
|
let now = Utc::now();
|
||||||
|
let session = Session {
|
||||||
|
username,
|
||||||
|
role,
|
||||||
|
created_at: now,
|
||||||
|
expires_at: now + Duration::hours(self.ttl_hours),
|
||||||
|
};
|
||||||
|
self.sessions.write().await.insert(token.clone(), session);
|
||||||
|
token
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a session token and return the session if valid and not expired.
|
||||||
|
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||||
|
let sessions = self.sessions.read().await;
|
||||||
|
sessions.get(token).and_then(|s| {
|
||||||
|
if s.expires_at > Utc::now() {
|
||||||
|
Some(s.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Revoke (delete) a session by token.
|
||||||
|
pub async fn revoke_session(&self, token: &str) {
|
||||||
|
self.sessions.write().await.remove(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove all expired sessions from the store.
|
||||||
|
pub async fn cleanup_expired(&self) {
|
||||||
|
let now = Utc::now();
|
||||||
|
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||||
|
}
|
||||||
|
}
|
||||||
193
src/dashboard/system.rs
Normal file
193
src/dashboard/system.rs
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
use axum::{extract::State, response::Json};
|
||||||
|
use chrono;
|
||||||
|
use serde_json;
|
||||||
|
use sqlx::Row;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let mut components = HashMap::new();
|
||||||
|
components.insert("api_server".to_string(), "online".to_string());
|
||||||
|
components.insert("database".to_string(), "online".to_string());
|
||||||
|
|
||||||
|
// Check provider health via circuit breakers
|
||||||
|
let provider_ids: Vec<String> = state
|
||||||
|
.app_state
|
||||||
|
.provider_manager
|
||||||
|
.get_all_providers()
|
||||||
|
.await
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.name().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for p_id in provider_ids {
|
||||||
|
if state
|
||||||
|
.app_state
|
||||||
|
.rate_limit_manager
|
||||||
|
.check_provider_request(&p_id)
|
||||||
|
.await
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
components.insert(p_id, "online".to_string());
|
||||||
|
} else {
|
||||||
|
components.insert(p_id, "degraded".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read real memory usage from /proc/self/status
|
||||||
|
let memory_mb = std::fs::read_to_string("/proc/self/status")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||||
|
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||||
|
.map(|kb| kb / 1024.0)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
|
// Get real database pool stats
|
||||||
|
let db_pool_size = state.app_state.db_pool.size();
|
||||||
|
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"status": "healthy",
|
||||||
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
|
"components": components,
|
||||||
|
"metrics": {
|
||||||
|
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
|
||||||
|
"db_connections_active": db_pool_size - db_pool_idle as u32,
|
||||||
|
"db_connections_idle": db_pool_idle,
|
||||||
|
}
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the 100 most recent LLM request log entries, newest first.
///
/// Rows are read from the `llm_requests` table and reshaped into a flat
/// JSON array for the dashboard log view; a query failure is logged and
/// collapsed into a generic error response.
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // Newest first, capped at 100 rows to bound the response size.
    let result = sqlx::query(
        r#"
        SELECT
            id,
            timestamp,
            client_id,
            provider,
            model,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost,
            status,
            error_message,
            duration_ms
        FROM llm_requests
        ORDER BY timestamp DESC
        LIMIT 100
        "#,
    )
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            let logs: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    // NOTE(review): `row.get::<i64, _>` panics if the column
                    // decodes as NULL — assumes total_tokens / cost /
                    // duration_ms are always populated on insert; confirm
                    // against the write path.
                    serde_json::json!({
                        "id": row.get::<i64, _>("id"),
                        "timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
                        "client_id": row.get::<String, _>("client_id"),
                        "provider": row.get::<String, _>("provider"),
                        "model": row.get::<String, _>("model"),
                        "tokens": row.get::<i64, _>("total_tokens"),
                        "cost": row.get::<f64, _>("cost"),
                        "status": row.get::<String, _>("status"),
                        // error_message is nullable, hence the Option.
                        "error": row.get::<Option<String>, _>("error_message"),
                        "duration": row.get::<i64, _>("duration_ms"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(logs)))
        }
        Err(e) => {
            warn!("Failed to fetch system logs: {}", e);
            Json(ApiResponse::error("Failed to fetch system logs".to_string()))
        }
    }
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_system_backup(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
|
||||||
|
let backup_path = format!("data/{}.db", backup_id);
|
||||||
|
|
||||||
|
// Ensure the data directory exists
|
||||||
|
if let Err(e) = std::fs::create_dir_all("data") {
|
||||||
|
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use SQLite VACUUM INTO for a consistent backup
|
||||||
|
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => {
|
||||||
|
// Get backup file size
|
||||||
|
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"message": "Backup completed successfully",
|
||||||
|
"backup_id": backup_id,
|
||||||
|
"backup_path": backup_path,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Database backup failed: {}", e);
|
||||||
|
Json(ApiResponse::error(format!("Backup failed: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let registry = &state.app_state.model_registry;
|
||||||
|
let provider_count = registry.providers.len();
|
||||||
|
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"server": {
|
||||||
|
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||||
|
"version": env!("CARGO_PKG_VERSION"),
|
||||||
|
},
|
||||||
|
"registry": {
|
||||||
|
"provider_count": provider_count,
|
||||||
|
"model_count": model_count,
|
||||||
|
},
|
||||||
|
"database": {
|
||||||
|
"type": "SQLite",
|
||||||
|
}
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_update_settings(
|
||||||
|
State(_state): State<DashboardState>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
Json(ApiResponse::error(
|
||||||
|
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
/// Mask an auth token for display, keeping only the last 4 characters.
///
/// Tokens of 8 or fewer characters are fully hidden as "*****" so short
/// secrets never leak a suffix. Longer tokens are rendered as at most 8
/// mask characters plus the final 4, so the masked form also does not
/// reveal the token's true length.
///
/// Operates on characters rather than bytes: the previous byte-index
/// slice (`&token[token.len() - 4..]`) panicked whenever the cut landed
/// inside a multi-byte UTF-8 sequence.
fn mask_token(token: &str) -> String {
    let char_count = token.chars().count();
    if char_count <= 8 {
        return "*****".to_string();
    }

    let visible_len = 4;
    // Cap the rendered length at 12 so very long tokens don't produce a
    // wall of asterisks.
    let mask_len = char_count.min(12) - visible_len;
    // Char-boundary-safe suffix of the last `visible_len` characters.
    let suffix: String = token.chars().skip(char_count - visible_len).collect();

    format!("{}{}", "*".repeat(mask_len), suffix)
}
|
||||||
330
src/dashboard/usage.rs
Normal file
330
src/dashboard/usage.rs
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
use axum::{extract::State, response::Json};
|
||||||
|
use chrono;
|
||||||
|
use serde_json;
|
||||||
|
use sqlx::Row;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{ApiResponse, DashboardState};
|
||||||
|
|
||||||
|
pub(super) async fn handle_usage_summary(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Total stats
|
||||||
|
let total_stats = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_requests,
|
||||||
|
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||||
|
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||||
|
COUNT(DISTINCT client_id) as active_clients
|
||||||
|
FROM llm_requests
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_one(pool);
|
||||||
|
|
||||||
|
// Today's stats
|
||||||
|
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||||
|
let today_stats = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as today_requests,
|
||||||
|
COALESCE(SUM(total_tokens), 0) as today_tokens,
|
||||||
|
COALESCE(SUM(cost), 0.0) as today_cost
|
||||||
|
FROM llm_requests
|
||||||
|
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(today)
|
||||||
|
.fetch_one(pool);
|
||||||
|
|
||||||
|
// Error stats
|
||||||
|
let error_stats = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
|
||||||
|
FROM llm_requests
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_one(pool);
|
||||||
|
|
||||||
|
// Average response time
|
||||||
|
let avg_response = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
|
||||||
|
FROM llm_requests
|
||||||
|
WHERE status = 'success'
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_one(pool);
|
||||||
|
|
||||||
|
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
|
||||||
|
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
|
||||||
|
let total_requests: i64 = t.get("total_requests");
|
||||||
|
let total_tokens: i64 = t.get("total_tokens");
|
||||||
|
let total_cost: f64 = t.get("total_cost");
|
||||||
|
let active_clients: i64 = t.get("active_clients");
|
||||||
|
|
||||||
|
let today_requests: i64 = d.get("today_requests");
|
||||||
|
let today_cost: f64 = d.get("today_cost");
|
||||||
|
|
||||||
|
let total_count: i64 = e.get("total");
|
||||||
|
let error_count: i64 = e.get("errors");
|
||||||
|
let error_rate = if total_count > 0 {
|
||||||
|
(error_count as f64 / total_count as f64) * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let avg_response_time: f64 = a.get("avg_duration");
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"total_requests": total_requests,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"total_cost": total_cost,
|
||||||
|
"active_clients": active_clients,
|
||||||
|
"today_requests": today_requests,
|
||||||
|
"today_cost": today_cost,
|
||||||
|
"error_rate": error_rate,
|
||||||
|
"avg_response_time": avg_response_time,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_time_series(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let twenty_four_hours_ago = now - chrono::Duration::hours(24);
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
strftime('%H:00', timestamp) as hour,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
SUM(total_tokens) as tokens,
|
||||||
|
SUM(cost) as cost
|
||||||
|
FROM llm_requests
|
||||||
|
WHERE timestamp >= ?
|
||||||
|
GROUP BY hour
|
||||||
|
ORDER BY hour
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(twenty_four_hours_ago)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let mut series = Vec::new();
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let hour: String = row.get("hour");
|
||||||
|
let requests: i64 = row.get("requests");
|
||||||
|
let tokens: i64 = row.get("tokens");
|
||||||
|
let cost: f64 = row.get("cost");
|
||||||
|
|
||||||
|
series.push(serde_json::json!({
|
||||||
|
"time": hour,
|
||||||
|
"requests": requests,
|
||||||
|
"tokens": tokens,
|
||||||
|
"cost": cost,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"series": series,
|
||||||
|
"period": "24h"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch time series data: {}", e);
|
||||||
|
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_clients_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
// Query database for client usage statistics
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
client_id,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
SUM(total_tokens) as tokens,
|
||||||
|
SUM(cost) as cost,
|
||||||
|
MAX(timestamp) as last_request
|
||||||
|
FROM llm_requests
|
||||||
|
GROUP BY client_id
|
||||||
|
ORDER BY requests DESC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let mut client_usage = Vec::new();
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let client_id: String = row.get("client_id");
|
||||||
|
let requests: i64 = row.get("requests");
|
||||||
|
let tokens: i64 = row.get("tokens");
|
||||||
|
let cost: f64 = row.get("cost");
|
||||||
|
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
|
||||||
|
|
||||||
|
client_usage.push(serde_json::json!({
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_name": client_id,
|
||||||
|
"requests": requests,
|
||||||
|
"tokens": tokens,
|
||||||
|
"cost": cost,
|
||||||
|
"last_request": last_request,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(client_usage)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch client usage data: {}", e);
|
||||||
|
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_providers_usage(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
// Query database for provider usage statistics
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
provider,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||||
|
COALESCE(SUM(cost), 0.0) as cost
|
||||||
|
FROM llm_requests
|
||||||
|
GROUP BY provider
|
||||||
|
ORDER BY requests DESC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let mut provider_usage = Vec::new();
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let provider: String = row.get("provider");
|
||||||
|
let requests: i64 = row.get("requests");
|
||||||
|
let tokens: i64 = row.get("tokens");
|
||||||
|
let cost: f64 = row.get("cost");
|
||||||
|
|
||||||
|
provider_usage.push(serde_json::json!({
|
||||||
|
"provider": provider,
|
||||||
|
"requests": requests,
|
||||||
|
"tokens": tokens,
|
||||||
|
"cost": cost,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(provider_usage)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch provider usage data: {}", e);
|
||||||
|
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_detailed_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
strftime('%Y-%m-%d', timestamp) as date,
|
||||||
|
client_id,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
SUM(total_tokens) as tokens,
|
||||||
|
SUM(cost) as cost
|
||||||
|
FROM llm_requests
|
||||||
|
GROUP BY date, client_id, provider, model
|
||||||
|
ORDER BY date DESC
|
||||||
|
LIMIT 100
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(rows) => {
|
||||||
|
let usage: Vec<serde_json::Value> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| {
|
||||||
|
serde_json::json!({
|
||||||
|
"date": row.get::<String, _>("date"),
|
||||||
|
"client": row.get::<String, _>("client_id"),
|
||||||
|
"provider": row.get::<String, _>("provider"),
|
||||||
|
"model": row.get::<String, _>("model"),
|
||||||
|
"requests": row.get::<i64, _>("requests"),
|
||||||
|
"tokens": row.get::<i64, _>("tokens"),
|
||||||
|
"cost": row.get::<f64, _>("cost"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!(usage)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to fetch detailed usage: {}", e);
|
||||||
|
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_analytics_breakdown(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Model breakdown
|
||||||
|
let models =
|
||||||
|
sqlx::query("SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC")
|
||||||
|
.fetch_all(pool);
|
||||||
|
|
||||||
|
// Client breakdown
|
||||||
|
let clients = sqlx::query(
|
||||||
|
"SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC",
|
||||||
|
)
|
||||||
|
.fetch_all(pool);
|
||||||
|
|
||||||
|
match tokio::join!(models, clients) {
|
||||||
|
(Ok(m_rows), Ok(c_rows)) => {
|
||||||
|
let model_breakdown: Vec<serde_json::Value> = m_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let client_breakdown: Vec<serde_json::Value> = c_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
|
"models": model_breakdown,
|
||||||
|
"clients": client_breakdown
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
75
src/dashboard/websocket.rs
Normal file
75
src/dashboard/websocket.rs
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{
|
||||||
|
State,
|
||||||
|
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||||
|
},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use serde_json;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use super::DashboardState;
|
||||||
|
|
||||||
|
// WebSocket handler
|
||||||
|
/// Upgrade an HTTP request to a WebSocket and hand the socket off to the
/// per-connection event loop.
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
|
||||||
|
|
||||||
|
/// Per-connection WebSocket loop for the dashboard.
///
/// Forwards events from the global broadcast bus to the client and
/// dispatches incoming text frames to `handle_websocket_message`. The
/// loop ends when the socket closes, errors, or a send fails.
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
    info!("WebSocket connection established");

    // Subscribe to events from the global bus
    let mut rx = state.app_state.dashboard_tx.subscribe();

    // Send initial connection message. A failure here is deliberately
    // ignored; a dead socket will surface as a send error in the loop.
    let _ = socket
        .send(Message::Text(
            serde_json::json!({
                "type": "connected",
                "message": "Connected to LLM Proxy Dashboard"
            })
            .to_string()
            .into(),
        ))
        .await;

    // Handle incoming messages and broadcast events
    loop {
        tokio::select! {
            // Receive broadcast events. NOTE(review): the `Ok(event)`
            // pattern disables this branch for the rest of the select pass
            // on recv errors (e.g. Lagged when this client falls behind the
            // broadcast buffer), silently dropping those events — confirm
            // that is intended.
            Ok(event) = rx.recv() => {
                // Skip events that fail to serialize rather than killing
                // the connection.
                let Ok(json_str) = serde_json::to_string(&event) else {
                    continue;
                };
                let message = Message::Text(json_str.into());
                if socket.send(message).await.is_err() {
                    // Client went away; tear down the loop.
                    break;
                }
            }

            // Receive WebSocket messages
            result = socket.recv() => {
                match result {
                    Some(Ok(Message::Text(text))) => {
                        handle_websocket_message(&text, &state).await;
                    }
                    // Close frames, socket errors, and non-text frames all
                    // end the connection.
                    _ => break,
                }
            }
        }
    }

    info!("WebSocket connection closed");
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
|
||||||
|
// Parse and handle WebSocket messages
|
||||||
|
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
|
||||||
|
&& let Some("ping") = data.get("type").and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
|
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
|
||||||
|
"type": "pong",
|
||||||
|
"payload": {}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use sqlx::sqlite::{SqlitePool, SqliteConnectOptions};
|
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
@@ -9,17 +9,16 @@ pub type DbPool = SqlitePool;
|
|||||||
|
|
||||||
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
||||||
// Ensure the database directory exists
|
// Ensure the database directory exists
|
||||||
if let Some(parent) = config.path.parent() {
|
if let Some(parent) = config.path.parent()
|
||||||
if !parent.as_os_str().is_empty() {
|
&& !parent.as_os_str().is_empty()
|
||||||
|
{
|
||||||
tokio::fs::create_dir_all(parent).await?;
|
tokio::fs::create_dir_all(parent).await?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let database_path = config.path.to_string_lossy().to_string();
|
let database_path = config.path.to_string_lossy().to_string();
|
||||||
info!("Connecting to database at {}", database_path);
|
info!("Connecting to database at {}", database_path);
|
||||||
|
|
||||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
|
||||||
.create_if_missing(true);
|
|
||||||
|
|
||||||
let pool = SqlitePool::connect_with(options).await?;
|
let pool = SqlitePool::connect_with(options).await?;
|
||||||
|
|
||||||
@@ -91,7 +90,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
|||||||
low_credit_threshold REAL DEFAULT 5.0,
|
low_credit_threshold REAL DEFAULT 5.0,
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -110,7 +109,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
|||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
||||||
)
|
)
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -123,64 +122,57 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
|||||||
username TEXT UNIQUE NOT NULL,
|
username TEXT UNIQUE NOT NULL,
|
||||||
password_hash TEXT NOT NULL,
|
password_hash TEXT NOT NULL,
|
||||||
role TEXT DEFAULT 'admin',
|
role TEXT DEFAULT 'admin',
|
||||||
|
must_change_password BOOLEAN DEFAULT FALSE,
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
"#
|
"#,
|
||||||
)
|
)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Add must_change_password column if it doesn't exist (migration for existing DBs)
|
||||||
|
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
|
||||||
|
.execute(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Insert default admin user if none exists (default password: admin)
|
// Insert default admin user if none exists (default password: admin)
|
||||||
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
|
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
|
||||||
.fetch_one(pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if user_count.0 == 0 {
|
if user_count.0 == 0 {
|
||||||
// 'admin' hashed with default cost (12)
|
// 'admin' hashed with default cost (12)
|
||||||
let default_admin_hash = bcrypt::hash("admin", 12).unwrap();
|
let default_admin_hash =
|
||||||
|
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO users (username, password_hash, role) VALUES ('admin', ?, 'admin')"
|
"INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
|
||||||
)
|
)
|
||||||
.bind(default_admin_hash)
|
.bind(default_admin_hash)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
info!("Created default admin user with password 'admin'");
|
info!("Created default admin user with password 'admin' (must change on first login)");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create indices
|
// Create indices
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)"
|
|
||||||
)
|
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|||||||
72
src/lib.rs
72
src/lib.rs
@@ -6,8 +6,8 @@
|
|||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod database;
|
|
||||||
pub mod dashboard;
|
pub mod dashboard;
|
||||||
|
pub mod database;
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
@@ -19,27 +19,29 @@ pub mod state;
|
|||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
// Re-exports for convenience
|
// Re-exports for convenience
|
||||||
pub use auth::*;
|
pub use auth::{AuthenticatedClient, validate_token};
|
||||||
pub use config::*;
|
pub use config::{
|
||||||
pub use database::*;
|
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
|
||||||
pub use errors::*;
|
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
|
||||||
pub use logging::*;
|
};
|
||||||
pub use models::*;
|
pub use database::{DbPool, init as init_db, test_connection};
|
||||||
pub use providers::*;
|
pub use errors::AppError;
|
||||||
pub use server::*;
|
pub use logging::{LoggingContext, RequestLog, RequestLogger};
|
||||||
pub use state::*;
|
pub use models::{
|
||||||
|
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
||||||
|
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
|
||||||
|
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
|
||||||
|
};
|
||||||
|
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
|
||||||
|
pub use server::router;
|
||||||
|
pub use state::AppState;
|
||||||
|
|
||||||
/// Test utilities for integration testing
|
/// Test utilities for integration testing
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod test_utils {
|
pub mod test_utils {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{
|
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
|
||||||
state::AppState,
|
|
||||||
rate_limiting::RateLimitManager,
|
|
||||||
client::ClientManager,
|
|
||||||
providers::ProviderManager,
|
|
||||||
};
|
|
||||||
use sqlx::sqlite::SqlitePool;
|
use sqlx::sqlite::SqlitePool;
|
||||||
|
|
||||||
/// Create a test application state
|
/// Create a test application state
|
||||||
@@ -53,7 +55,9 @@ pub mod test_utils {
|
|||||||
crate::database::init(&crate::config::DatabaseConfig {
|
crate::database::init(&crate::config::DatabaseConfig {
|
||||||
path: std::path::PathBuf::from(":memory:"),
|
path: std::path::PathBuf::from(":memory:"),
|
||||||
max_connections: 5,
|
max_connections: 5,
|
||||||
}).await.expect("Failed to initialize test database");
|
})
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize test database");
|
||||||
|
|
||||||
let rate_limit_manager = RateLimitManager::new(
|
let rate_limit_manager = RateLimitManager::new(
|
||||||
crate::rate_limiting::RateLimiterConfig::default(),
|
crate::rate_limiting::RateLimiterConfig::default(),
|
||||||
@@ -82,11 +86,35 @@ pub mod test_utils {
|
|||||||
max_connections: 5,
|
max_connections: 5,
|
||||||
},
|
},
|
||||||
providers: crate::config::ProviderConfig {
|
providers: crate::config::ProviderConfig {
|
||||||
openai: crate::config::OpenAIConfig { api_key_env: "OPENAI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true },
|
openai: crate::config::OpenAIConfig {
|
||||||
gemini: crate::config::GeminiConfig { api_key_env: "GEMINI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true },
|
api_key_env: "OPENAI_API_KEY".to_string(),
|
||||||
deepseek: crate::config::DeepSeekConfig { api_key_env: "DEEPSEEK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true },
|
base_url: "".to_string(),
|
||||||
grok: crate::config::GrokConfig { api_key_env: "GROK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true },
|
default_model: "".to_string(),
|
||||||
ollama: crate::config::OllamaConfig { base_url: "".to_string(), enabled: true, models: vec![] },
|
enabled: true,
|
||||||
|
},
|
||||||
|
gemini: crate::config::GeminiConfig {
|
||||||
|
api_key_env: "GEMINI_API_KEY".to_string(),
|
||||||
|
base_url: "".to_string(),
|
||||||
|
default_model: "".to_string(),
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
deepseek: crate::config::DeepSeekConfig {
|
||||||
|
api_key_env: "DEEPSEEK_API_KEY".to_string(),
|
||||||
|
base_url: "".to_string(),
|
||||||
|
default_model: "".to_string(),
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
grok: crate::config::GrokConfig {
|
||||||
|
api_key_env: "GROK_API_KEY".to_string(),
|
||||||
|
base_url: "".to_string(),
|
||||||
|
default_model: "".to_string(),
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
ollama: crate::config::OllamaConfig {
|
||||||
|
base_url: "".to_string(),
|
||||||
|
enabled: true,
|
||||||
|
models: vec![],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
|
model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
|
||||||
pricing: crate::config::PricingConfig {
|
pricing: crate::config::PricingConfig {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::Serialize;
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use serde::Serialize;
|
|
||||||
|
|
||||||
use crate::errors::AppError;
|
use crate::errors::AppError;
|
||||||
|
|
||||||
@@ -77,16 +77,14 @@ impl RequestLogger {
|
|||||||
.bind(log.status)
|
.bind(log.status)
|
||||||
.bind(log.error_message)
|
.bind(log.error_message)
|
||||||
.bind(log.duration_ms as i64)
|
.bind(log.duration_ms as i64)
|
||||||
.bind(None::<String>) // request_body - TODO: store serialized request
|
.bind(None::<String>) // request_body - optional, not stored to save disk space
|
||||||
.bind(None::<String>) // response_body - TODO: store serialized response or error
|
.bind(None::<String>) // response_body - optional, not stored to save disk space
|
||||||
.execute(&mut *tx)
|
.execute(&mut *tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Deduct from provider balance if successful
|
// Deduct from provider balance if successful
|
||||||
if log.cost > 0.0 {
|
if log.cost > 0.0 {
|
||||||
sqlx::query(
|
sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?")
|
||||||
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?"
|
|
||||||
)
|
|
||||||
.bind(log.cost)
|
.bind(log.cost)
|
||||||
.bind(&log.provider)
|
.bind(&log.provider)
|
||||||
.execute(&mut *tx)
|
.execute(&mut *tx)
|
||||||
|
|||||||
27
src/main.rs
27
src/main.rs
@@ -1,16 +1,15 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{Router, routing::get};
|
use axum::{Router, routing::get};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tracing::{info, error};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use llm_proxy::{
|
use llm_proxy::{
|
||||||
config::AppConfig,
|
config::AppConfig,
|
||||||
state::AppState,
|
dashboard, database,
|
||||||
providers::ProviderManager,
|
providers::ProviderManager,
|
||||||
database,
|
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
|
||||||
server,
|
server,
|
||||||
dashboard,
|
state::AppState,
|
||||||
rate_limiting::{RateLimitManager, RateLimiterConfig, CircuitBreakerConfig},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -43,22 +42,28 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create rate limit manager
|
// Create rate limit manager
|
||||||
let rate_limit_manager = RateLimitManager::new(
|
let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());
|
||||||
RateLimiterConfig::default(),
|
|
||||||
CircuitBreakerConfig::default(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Fetch model registry from models.dev
|
// Fetch model registry from models.dev
|
||||||
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
|
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
|
||||||
Ok(registry) => registry,
|
Ok(registry) => registry,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to fetch model registry: {}. Using empty registry.", e);
|
error!("Failed to fetch model registry: {}. Using empty registry.", e);
|
||||||
llm_proxy::models::registry::ModelRegistry { providers: std::collections::HashMap::new() }
|
llm_proxy::models::registry::ModelRegistry {
|
||||||
|
providers: std::collections::HashMap::new(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create application state
|
// Create application state
|
||||||
let state = AppState::new(config.clone(), provider_manager, db_pool, rate_limit_manager, model_registry, config.server.auth_tokens.clone());
|
let state = AppState::new(
|
||||||
|
config.clone(),
|
||||||
|
provider_manager,
|
||||||
|
db_pool,
|
||||||
|
rate_limit_manager,
|
||||||
|
model_registry,
|
||||||
|
config.server.auth_tokens.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
// Create application router
|
// Create application router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
|||||||
@@ -198,9 +198,7 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|msg| {
|
.map(|msg| {
|
||||||
let (content, _images_in_message) = match msg.content {
|
let (content, _images_in_message) = match msg.content {
|
||||||
MessageContent::Text { content } => {
|
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
|
||||||
(vec![ContentPart::Text { text: content }], false)
|
|
||||||
}
|
|
||||||
MessageContent::Parts { content } => {
|
MessageContent::Parts { content } => {
|
||||||
let mut unified_content = Vec::new();
|
let mut unified_content = Vec::new();
|
||||||
let mut has_images_in_msg = false;
|
let mut has_images_in_msg = false;
|
||||||
@@ -213,18 +211,16 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|||||||
ContentPartValue::ImageUrl { image_url } => {
|
ContentPartValue::ImageUrl { image_url } => {
|
||||||
has_images_in_msg = true;
|
has_images_in_msg = true;
|
||||||
has_images = true;
|
has_images = true;
|
||||||
unified_content.push(ContentPart::Image(
|
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
|
||||||
crate::multimodal::ImageInput::from_url(image_url.url)
|
image_url.url,
|
||||||
));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(unified_content, has_images_in_msg)
|
(unified_content, has_images_in_msg)
|
||||||
}
|
}
|
||||||
MessageContent::None => {
|
MessageContent::None => (vec![], false),
|
||||||
(vec![], false)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
UnifiedMessage {
|
UnifiedMessage {
|
||||||
|
|||||||
@@ -7,24 +7,18 @@
|
|||||||
//! 4. Provider-specific image format conversion
|
//! 4. Provider-specific image format conversion
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
/// Supported image formats for multimodal input
|
/// Supported image formats for multimodal input
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ImageInput {
|
pub enum ImageInput {
|
||||||
/// Base64-encoded image data with MIME type
|
/// Base64-encoded image data with MIME type
|
||||||
Base64 {
|
Base64 { data: String, mime_type: String },
|
||||||
data: String,
|
|
||||||
mime_type: String,
|
|
||||||
},
|
|
||||||
/// URL to fetch image from
|
/// URL to fetch image from
|
||||||
Url(String),
|
Url(String),
|
||||||
/// Raw bytes with MIME type
|
/// Raw bytes with MIME type
|
||||||
Bytes {
|
Bytes { data: Vec<u8>, mime_type: String },
|
||||||
data: Vec<u8>,
|
|
||||||
mime_type: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ImageInput {
|
impl ImageInput {
|
||||||
@@ -63,9 +57,7 @@ impl ImageInput {
|
|||||||
Self::Url(url) => {
|
Self::Url(url) => {
|
||||||
// Fetch image from URL
|
// Fetch image from URL
|
||||||
info!("Fetching image from URL: {}", url);
|
info!("Fetching image from URL: {}", url);
|
||||||
let response = reqwest::get(url)
|
let response = reqwest::get(url).await.context("Failed to fetch image from URL")?;
|
||||||
.await
|
|
||||||
.context("Failed to fetch image from URL")?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
||||||
@@ -89,13 +81,15 @@ impl ImageInput {
|
|||||||
/// Get image dimensions (width, height)
|
/// Get image dimensions (width, height)
|
||||||
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
|
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
|
||||||
let bytes = match self {
|
let bytes = match self {
|
||||||
Self::Base64 { data, .. } => {
|
Self::Base64 { data, .. } => general_purpose::STANDARD
|
||||||
general_purpose::STANDARD.decode(data).context("Failed to decode base64")?
|
.decode(data)
|
||||||
}
|
.context("Failed to decode base64")?,
|
||||||
Self::Bytes { data, .. } => data.clone(),
|
Self::Bytes { data, .. } => data.clone(),
|
||||||
Self::Url(_) => {
|
Self::Url(_) => {
|
||||||
let (base64_data, _) = self.to_base64().await?;
|
let (base64_data, _) = self.to_base64().await?;
|
||||||
general_purpose::STANDARD.decode(&base64_data).context("Failed to decode base64")?
|
general_purpose::STANDARD
|
||||||
|
.decode(&base64_data)
|
||||||
|
.context("Failed to decode base64")?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -178,8 +172,10 @@ impl ImageConverter {
|
|||||||
/// Detect if a model supports multimodal input
|
/// Detect if a model supports multimodal input
|
||||||
pub fn model_supports_multimodal(model: &str) -> bool {
|
pub fn model_supports_multimodal(model: &str) -> bool {
|
||||||
// OpenAI vision models
|
// OpenAI vision models
|
||||||
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) ||
|
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
|
||||||
model.starts_with("o1-") || model.starts_with("o3-") {
|
|| model.starts_with("o1-")
|
||||||
|
|| model.starts_with("o3-")
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,8 +204,9 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
|
|||||||
} else if let Some(content_array) = content.as_array() {
|
} else if let Some(content_array) = content.as_array() {
|
||||||
// Array of content parts (text and/or images)
|
// Array of content parts (text and/or images)
|
||||||
for part in content_array {
|
for part in content_array {
|
||||||
if let Some(part_obj) = part.as_object() {
|
if let Some(part_obj) = part.as_object()
|
||||||
if let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) {
|
&& let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
|
||||||
|
{
|
||||||
match part_type {
|
match part_type {
|
||||||
"text" => {
|
"text" => {
|
||||||
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
|
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
|
||||||
@@ -217,8 +214,9 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"image_url" => {
|
"image_url" => {
|
||||||
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) {
|
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
|
||||||
if let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) {
|
&& let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
|
||||||
|
{
|
||||||
if url.starts_with("data:") {
|
if url.starts_with("data:") {
|
||||||
// Parse data URL
|
// Parse data URL
|
||||||
if let Some((mime_type, data)) = parse_data_url(url) {
|
if let Some((mime_type, data)) = parse_data_url(url) {
|
||||||
@@ -232,7 +230,6 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unknown content part type: {}", part_type);
|
warn!("Unknown content part type: {}", part_type);
|
||||||
}
|
}
|
||||||
@@ -240,7 +237,6 @@ pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(parts)
|
Ok(parts)
|
||||||
}
|
}
|
||||||
@@ -278,8 +274,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_supports_multimodal() {
|
async fn test_model_supports_multimodal() {
|
||||||
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
|
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
|
||||||
|
assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
|
||||||
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
|
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
|
||||||
|
assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
|
||||||
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
|
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
|
||||||
assert!(!ImageConverter::model_supports_multimodal("gemini-pro"));
|
assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,10 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use futures::stream::BoxStream;
|
||||||
|
|
||||||
use crate::{
|
use super::helpers;
|
||||||
models::UnifiedRequest,
|
|
||||||
errors::AppError,
|
|
||||||
config::AppConfig,
|
|
||||||
};
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||||
|
|
||||||
pub struct DeepSeekProvider {
|
pub struct DeepSeekProvider {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
@@ -23,7 +19,11 @@ impl DeepSeekProvider {
|
|||||||
Self::new_with_key(config, app_config, api_key)
|
Self::new_with_key(config, app_config, api_key)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(
|
||||||
|
config: &crate::config::DeepSeekConfig,
|
||||||
|
app_config: &AppConfig,
|
||||||
|
api_key: String,
|
||||||
|
) -> Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
config: config.clone(),
|
config: config.clone(),
|
||||||
@@ -47,42 +47,13 @@ impl super::Provider for DeepSeekProvider {
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
request: UnifiedRequest,
|
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||||
) -> Result<ProviderResponse, AppError> {
|
|
||||||
// Build the OpenAI-compatible body
|
|
||||||
let mut body = serde_json::json!({
|
|
||||||
"model": request.model,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": false,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let response = self
|
||||||
body["temperature"] = serde_json::json!(temp);
|
.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
@@ -94,119 +65,52 @@ impl super::Provider for DeepSeekProvider {
|
|||||||
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
|
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
let resp_json: serde_json::Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
helpers::parse_openai_response(&resp_json, request.model)
|
||||||
let message = &choice["message"];
|
|
||||||
|
|
||||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
let usage = &resp_json["usage"];
|
|
||||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
model: request.model,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
if let Some(metadata) = registry.find_model(model) {
|
&self,
|
||||||
if let Some(cost) = &metadata.cost {
|
model: &str,
|
||||||
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
prompt_tokens: u32,
|
||||||
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
completion_tokens: u32,
|
||||||
}
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
}
|
) -> f64 {
|
||||||
|
helpers::calculate_cost_with_registry(
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
model,
|
||||||
.find(|p| model.contains(&p.model))
|
prompt_tokens,
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
completion_tokens,
|
||||||
.unwrap_or((0.14, 0.28));
|
registry,
|
||||||
|
&self.pricing,
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
0.14,
|
||||||
|
0.28,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
let mut body = serde_json::json!({
|
// DeepSeek doesn't support images in streaming, use text-only
|
||||||
"model": request.model,
|
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
||||||
"messages": request.messages.iter().map(|m| {
|
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": true,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let es = reqwest_eventsource::EventSource::new(
|
||||||
body["temperature"] = serde_json::json!(temp);
|
self.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create eventsource stream
|
|
||||||
use reqwest_eventsource::{EventSource, Event};
|
|
||||||
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body))
|
.json(&body),
|
||||||
|
)
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
let model = request.model.clone();
|
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(choice) = chunk["choices"].get(0) {
|
|
||||||
let delta = &choice["delta"];
|
|
||||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
yield ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use async_trait::async_trait;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{
|
|
||||||
models::UnifiedRequest,
|
|
||||||
errors::AppError,
|
|
||||||
config::AppConfig,
|
|
||||||
};
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct GeminiRequest {
|
struct GeminiRequest {
|
||||||
@@ -61,8 +57,6 @@ struct GeminiResponse {
|
|||||||
usage_metadata: Option<GeminiUsageMetadata>,
|
usage_metadata: Option<GeminiUsageMetadata>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
pub struct GeminiProvider {
|
pub struct GeminiProvider {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
config: crate::config::GeminiConfig,
|
config: crate::config::GeminiConfig,
|
||||||
@@ -104,10 +98,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
true // Gemini supports vision
|
true // Gemini supports vision
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<ProviderResponse, AppError> {
|
|
||||||
// Convert UnifiedRequest to Gemini request
|
// Convert UnifiedRequest to Gemini request
|
||||||
let mut contents = Vec::with_capacity(request.messages.len());
|
let mut contents = Vec::with_capacity(request.messages.len());
|
||||||
|
|
||||||
@@ -123,7 +114,9 @@ impl super::Provider for GeminiProvider {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
let (base64_data, mime_type) = image_input
|
||||||
|
.to_base64()
|
||||||
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||||
|
|
||||||
parts.push(GeminiPart {
|
parts.push(GeminiPart {
|
||||||
@@ -143,10 +136,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
_ => "user".to_string(),
|
_ => "user".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
contents.push(GeminiContent {
|
contents.push(GeminiContent { parts, role });
|
||||||
parts,
|
|
||||||
role,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if contents.is_empty() {
|
if contents.is_empty() {
|
||||||
@@ -169,15 +159,13 @@ impl super::Provider for GeminiProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Build URL
|
// Build URL
|
||||||
let url = format!("{}/models/{}:generateContent?key={}",
|
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,);
|
||||||
self.config.base_url,
|
|
||||||
request.model,
|
|
||||||
self.api_key
|
|
||||||
);
|
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
|
.header("x-goog-api-key", &self.api_key)
|
||||||
.json(&gemini_request)
|
.json(&gemini_request)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -187,7 +175,10 @@ impl super::Provider for GeminiProvider {
|
|||||||
let status = response.status();
|
let status = response.status();
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text)));
|
return Err(AppError::ProviderError(format!(
|
||||||
|
"Gemini API error ({}): {}",
|
||||||
|
status, error_text
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let gemini_response: GeminiResponse = response
|
let gemini_response: GeminiResponse = response
|
||||||
@@ -196,16 +187,29 @@ impl super::Provider for GeminiProvider {
|
|||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
||||||
|
|
||||||
// Extract content from first candidate
|
// Extract content from first candidate
|
||||||
let content = gemini_response.candidates
|
let content = gemini_response
|
||||||
|
.candidates
|
||||||
.first()
|
.first()
|
||||||
.and_then(|c| c.content.parts.first())
|
.and_then(|c| c.content.parts.first())
|
||||||
.and_then(|p| p.text.clone())
|
.and_then(|p| p.text.clone())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Extract token usage
|
// Extract token usage
|
||||||
let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0);
|
let prompt_tokens = gemini_response
|
||||||
let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0);
|
.usage_metadata
|
||||||
let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0);
|
.as_ref()
|
||||||
|
.map(|u| u.prompt_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
let completion_tokens = gemini_response
|
||||||
|
.usage_metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.candidates_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
let total_tokens = gemini_response
|
||||||
|
.usage_metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.total_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
@@ -221,20 +225,22 @@ impl super::Provider for GeminiProvider {
|
|||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
if let Some(metadata) = registry.find_model(model) {
|
&self,
|
||||||
if let Some(cost) = &metadata.cost {
|
model: &str,
|
||||||
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
prompt_tokens: u32,
|
||||||
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
completion_tokens: u32,
|
||||||
}
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
}
|
) -> f64 {
|
||||||
|
super::helpers::calculate_cost_with_registry(
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
model,
|
||||||
.find(|p| model.contains(&p.model))
|
prompt_tokens,
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
completion_tokens,
|
||||||
.unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found
|
registry,
|
||||||
|
&self.pricing,
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
0.075,
|
||||||
|
0.30,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
@@ -256,7 +262,9 @@ impl super::Provider for GeminiProvider {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
let (base64_data, mime_type) = image_input
|
||||||
|
.to_base64()
|
||||||
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||||
|
|
||||||
parts.push(GeminiPart {
|
parts.push(GeminiPart {
|
||||||
@@ -276,10 +284,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
_ => "user".to_string(),
|
_ => "user".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
contents.push(GeminiContent {
|
contents.push(GeminiContent { parts, role });
|
||||||
parts,
|
|
||||||
role,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build generation config
|
// Build generation config
|
||||||
@@ -298,17 +303,21 @@ impl super::Provider for GeminiProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Build URL for streaming
|
// Build URL for streaming
|
||||||
let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}",
|
let url = format!(
|
||||||
self.config.base_url,
|
"{}/models/{}:streamGenerateContent?alt=sse",
|
||||||
request.model,
|
self.config.base_url, request.model,
|
||||||
self.api_key
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create eventsource stream
|
// Create eventsource stream
|
||||||
use reqwest_eventsource::{EventSource, Event};
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use reqwest_eventsource::{Event, EventSource};
|
||||||
|
|
||||||
let es = EventSource::new(self.client.post(&url).json(&gemini_request))
|
let es = EventSource::new(
|
||||||
|
self.client
|
||||||
|
.post(&url)
|
||||||
|
.header("x-goog-api-key", &self.api_key)
|
||||||
|
.json(&gemini_request),
|
||||||
|
)
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
let model = request.model.clone();
|
let model = request.model.clone();
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use futures::stream::BoxStream;
|
||||||
|
|
||||||
use crate::{
|
use super::helpers;
|
||||||
models::UnifiedRequest,
|
|
||||||
errors::AppError,
|
|
||||||
config::AppConfig,
|
|
||||||
};
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||||
|
|
||||||
pub struct GrokProvider {
|
pub struct GrokProvider {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
_config: crate::config::GrokConfig,
|
config: crate::config::GrokConfig,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
@@ -26,7 +22,7 @@ impl GrokProvider {
|
|||||||
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
config: config.clone(),
|
||||||
api_key,
|
api_key,
|
||||||
pricing: app_config.pricing.grok.clone(),
|
pricing: app_config.pricing.grok.clone(),
|
||||||
})
|
})
|
||||||
@@ -47,40 +43,13 @@ impl super::Provider for GrokProvider {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
request: UnifiedRequest,
|
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||||
) -> Result<ProviderResponse, AppError> {
|
|
||||||
let mut body = serde_json::json!({
|
|
||||||
"model": request.model,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": false,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let response = self
|
||||||
body["temperature"] = serde_json::json!(temp);
|
.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
@@ -92,125 +61,51 @@ impl super::Provider for GrokProvider {
|
|||||||
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
let resp_json: serde_json::Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
helpers::parse_openai_response(&resp_json, request.model)
|
||||||
let message = &choice["message"];
|
|
||||||
|
|
||||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
let usage = &resp_json["usage"];
|
|
||||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
model: request.model,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
if let Some(metadata) = registry.find_model(model) {
|
&self,
|
||||||
if let Some(cost) = &metadata.cost {
|
model: &str,
|
||||||
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
prompt_tokens: u32,
|
||||||
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
completion_tokens: u32,
|
||||||
}
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
}
|
) -> f64 {
|
||||||
|
helpers::calculate_cost_with_registry(
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
model,
|
||||||
.find(|p| model.contains(&p.model))
|
prompt_tokens,
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
completion_tokens,
|
||||||
.unwrap_or((5.0, 15.0));
|
registry,
|
||||||
|
&self.pricing,
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
5.0,
|
||||||
|
15.0,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
let mut body = serde_json::json!({
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
"model": request.model,
|
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": true,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let es = reqwest_eventsource::EventSource::new(
|
||||||
body["temperature"] = serde_json::json!(temp);
|
self.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create eventsource stream
|
|
||||||
use reqwest_eventsource::{EventSource, Event};
|
|
||||||
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body))
|
.json(&body),
|
||||||
|
)
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
let model = request.model.clone();
|
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(choice) = chunk["choices"].get(0) {
|
|
||||||
let delta = &choice["delta"];
|
|
||||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
yield ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
189
src/providers/helpers.rs
Normal file
189
src/providers/helpers.rs
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::errors::AppError;
|
||||||
|
use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest};
|
||||||
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
|
||||||
|
///
|
||||||
|
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
||||||
|
/// Tokio async context. All image base64 conversions are awaited properly.
|
||||||
|
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for m in messages {
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
for p in &m.content {
|
||||||
|
match p {
|
||||||
|
ContentPart::Text { text } => {
|
||||||
|
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||||
|
}
|
||||||
|
ContentPart::Image(image_input) => {
|
||||||
|
let (base64_data, mime_type) = image_input
|
||||||
|
.to_base64()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
||||||
|
parts.push(serde_json::json!({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": parts
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert messages to OpenAI-compatible JSON, but replace images with a
|
||||||
|
/// text placeholder "[Image]". Useful for providers that don't support
|
||||||
|
/// multimodal in streaming mode or at all.
|
||||||
|
pub async fn messages_to_openai_json_text_only(
|
||||||
|
messages: &[UnifiedMessage],
|
||||||
|
) -> Result<Vec<serde_json::Value>, AppError> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for m in messages {
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
for p in &m.content {
|
||||||
|
match p {
|
||||||
|
ContentPart::Text { text } => {
|
||||||
|
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||||
|
}
|
||||||
|
ContentPart::Image(_) => {
|
||||||
|
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": parts
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
||||||
|
pub fn build_openai_body(
|
||||||
|
request: &UnifiedRequest,
|
||||||
|
messages_json: Vec<serde_json::Value>,
|
||||||
|
stream: bool,
|
||||||
|
) -> serde_json::Value {
|
||||||
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
|
"messages": messages_json,
|
||||||
|
"stream": stream,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
body["temperature"] = serde_json::json!(temp);
|
||||||
|
}
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
||||||
|
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
||||||
|
let choice = resp_json["choices"]
|
||||||
|
.get(0)
|
||||||
|
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||||
|
let message = &choice["message"];
|
||||||
|
|
||||||
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
let usage = &resp_json["usage"];
|
||||||
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
|
||||||
|
Ok(ProviderResponse {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
|
||||||
|
///
|
||||||
|
/// The optional `reasoning_field` allows overriding the field name for
|
||||||
|
/// reasoning content (e.g., "thought" for Ollama).
|
||||||
|
pub fn create_openai_stream(
|
||||||
|
es: reqwest_eventsource::EventSource,
|
||||||
|
model: String,
|
||||||
|
reasoning_field: Option<&'static str>,
|
||||||
|
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
|
||||||
|
use reqwest_eventsource::Event;
|
||||||
|
|
||||||
|
let stream = async_stream::try_stream! {
|
||||||
|
let mut es = es;
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) => {
|
||||||
|
if msg.data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk: Value = serde_json::from_str(&msg.data)
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some(choice) = chunk["choices"].get(0) {
|
||||||
|
let delta = &choice["delta"];
|
||||||
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = delta["reasoning_content"]
|
||||||
|
.as_str()
|
||||||
|
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
yield ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
finish_reason,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
|
Err(e) => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate cost using the model registry first, then falling back to provider pricing config.
|
||||||
|
pub fn calculate_cost_with_registry(
|
||||||
|
model: &str,
|
||||||
|
prompt_tokens: u32,
|
||||||
|
completion_tokens: u32,
|
||||||
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
|
pricing: &[crate::config::ModelPricing],
|
||||||
|
default_prompt_rate: f64,
|
||||||
|
default_completion_rate: f64,
|
||||||
|
) -> f64 {
|
||||||
|
if let Some(metadata) = registry.find_model(model)
|
||||||
|
&& let Some(cost) = &metadata.cost
|
||||||
|
{
|
||||||
|
return (prompt_tokens as f64 * cost.input / 1_000_000.0)
|
||||||
|
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (prompt_rate, completion_rate) = pricing
|
||||||
|
.iter()
|
||||||
|
.find(|p| model.contains(&p.model))
|
||||||
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
|
.unwrap_or((default_prompt_rate, default_completion_rate));
|
||||||
|
|
||||||
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||||
|
}
|
||||||
@@ -1,17 +1,18 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::sync::Arc;
|
use async_trait::async_trait;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use sqlx::Row;
|
use sqlx::Row;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::models::UnifiedRequest;
|
|
||||||
use crate::errors::AppError;
|
use crate::errors::AppError;
|
||||||
|
use crate::models::UnifiedRequest;
|
||||||
|
|
||||||
pub mod openai;
|
|
||||||
pub mod gemini;
|
|
||||||
pub mod deepseek;
|
pub mod deepseek;
|
||||||
|
pub mod gemini;
|
||||||
pub mod grok;
|
pub mod grok;
|
||||||
|
pub mod helpers;
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
|
pub mod openai;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
@@ -25,10 +26,7 @@ pub trait Provider: Send + Sync {
|
|||||||
fn supports_multimodal(&self) -> bool;
|
fn supports_multimodal(&self) -> bool;
|
||||||
|
|
||||||
/// Process a chat completion request
|
/// Process a chat completion request
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<ProviderResponse, AppError>;
|
|
||||||
|
|
||||||
/// Process a streaming chat completion request
|
/// Process a streaming chat completion request
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
@@ -40,7 +38,13 @@ pub trait Provider: Send + Sync {
|
|||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
||||||
|
|
||||||
/// Calculate cost based on token usage and model using the registry
|
/// Calculate cost based on token usage and model using the registry
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
|
fn calculate_cost(
|
||||||
|
&self,
|
||||||
|
model: &str,
|
||||||
|
prompt_tokens: u32,
|
||||||
|
completion_tokens: u32,
|
||||||
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
|
) -> f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ProviderResponse {
|
pub struct ProviderResponse {
|
||||||
@@ -64,11 +68,8 @@ use tokio::sync::RwLock;
|
|||||||
|
|
||||||
use crate::config::AppConfig;
|
use crate::config::AppConfig;
|
||||||
use crate::providers::{
|
use crate::providers::{
|
||||||
|
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
||||||
openai::OpenAIProvider,
|
openai::OpenAIProvider,
|
||||||
gemini::GeminiProvider,
|
|
||||||
deepseek::DeepSeekProvider,
|
|
||||||
grok::GrokProvider,
|
|
||||||
ollama::OllamaProvider,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -76,6 +77,12 @@ pub struct ProviderManager {
|
|||||||
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
|
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for ProviderManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ProviderManager {
|
impl ProviderManager {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -84,7 +91,12 @@ impl ProviderManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize a provider by name using config and database overrides
|
/// Initialize a provider by name using config and database overrides
|
||||||
pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> {
|
pub async fn initialize_provider(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
app_config: &AppConfig,
|
||||||
|
db_pool: &crate::database::DbPool,
|
||||||
|
) -> Result<()> {
|
||||||
// Load override from database
|
// Load override from database
|
||||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
|
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
|
||||||
.bind(name)
|
.bind(name)
|
||||||
@@ -100,11 +112,31 @@ impl ProviderManager {
|
|||||||
} else {
|
} else {
|
||||||
// No database override, use defaults from AppConfig
|
// No database override, use defaults from AppConfig
|
||||||
match name {
|
match name {
|
||||||
"openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None),
|
"openai" => (
|
||||||
"gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None),
|
app_config.providers.openai.enabled,
|
||||||
"deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None),
|
Some(app_config.providers.openai.base_url.clone()),
|
||||||
"grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None),
|
None,
|
||||||
"ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None),
|
),
|
||||||
|
"gemini" => (
|
||||||
|
app_config.providers.gemini.enabled,
|
||||||
|
Some(app_config.providers.gemini.base_url.clone()),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
"deepseek" => (
|
||||||
|
app_config.providers.deepseek.enabled,
|
||||||
|
Some(app_config.providers.deepseek.base_url.clone()),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
"grok" => (
|
||||||
|
app_config.providers.grok.enabled,
|
||||||
|
Some(app_config.providers.grok.base_url.clone()),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
"ollama" => (
|
||||||
|
app_config.providers.ollama.enabled,
|
||||||
|
Some(app_config.providers.ollama.base_url.clone()),
|
||||||
|
None,
|
||||||
|
),
|
||||||
_ => (false, None, None),
|
_ => (false, None, None),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -118,7 +150,9 @@ impl ProviderManager {
|
|||||||
let provider: Arc<dyn Provider> = match name {
|
let provider: Arc<dyn Provider> = match name {
|
||||||
"openai" => {
|
"openai" => {
|
||||||
let mut cfg = app_config.providers.openai.clone();
|
let mut cfg = app_config.providers.openai.clone();
|
||||||
if let Some(url) = base_url { cfg.base_url = url; }
|
if let Some(url) = base_url {
|
||||||
|
cfg.base_url = url;
|
||||||
|
}
|
||||||
// Handle API key override if present
|
// Handle API key override if present
|
||||||
let p = if let Some(key) = api_key {
|
let p = if let Some(key) = api_key {
|
||||||
// We need a way to create a provider with an explicit key
|
// We need a way to create a provider with an explicit key
|
||||||
@@ -128,42 +162,50 @@ impl ProviderManager {
|
|||||||
OpenAIProvider::new(&cfg, app_config)?
|
OpenAIProvider::new(&cfg, app_config)?
|
||||||
};
|
};
|
||||||
Arc::new(p)
|
Arc::new(p)
|
||||||
},
|
}
|
||||||
"ollama" => {
|
"ollama" => {
|
||||||
let mut cfg = app_config.providers.ollama.clone();
|
let mut cfg = app_config.providers.ollama.clone();
|
||||||
if let Some(url) = base_url { cfg.base_url = url; }
|
if let Some(url) = base_url {
|
||||||
|
cfg.base_url = url;
|
||||||
|
}
|
||||||
Arc::new(OllamaProvider::new(&cfg, app_config)?)
|
Arc::new(OllamaProvider::new(&cfg, app_config)?)
|
||||||
},
|
}
|
||||||
"gemini" => {
|
"gemini" => {
|
||||||
let mut cfg = app_config.providers.gemini.clone();
|
let mut cfg = app_config.providers.gemini.clone();
|
||||||
if let Some(url) = base_url { cfg.base_url = url; }
|
if let Some(url) = base_url {
|
||||||
|
cfg.base_url = url;
|
||||||
|
}
|
||||||
let p = if let Some(key) = api_key {
|
let p = if let Some(key) = api_key {
|
||||||
GeminiProvider::new_with_key(&cfg, app_config, key)?
|
GeminiProvider::new_with_key(&cfg, app_config, key)?
|
||||||
} else {
|
} else {
|
||||||
GeminiProvider::new(&cfg, app_config)?
|
GeminiProvider::new(&cfg, app_config)?
|
||||||
};
|
};
|
||||||
Arc::new(p)
|
Arc::new(p)
|
||||||
},
|
}
|
||||||
"deepseek" => {
|
"deepseek" => {
|
||||||
let mut cfg = app_config.providers.deepseek.clone();
|
let mut cfg = app_config.providers.deepseek.clone();
|
||||||
if let Some(url) = base_url { cfg.base_url = url; }
|
if let Some(url) = base_url {
|
||||||
|
cfg.base_url = url;
|
||||||
|
}
|
||||||
let p = if let Some(key) = api_key {
|
let p = if let Some(key) = api_key {
|
||||||
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
|
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
|
||||||
} else {
|
} else {
|
||||||
DeepSeekProvider::new(&cfg, app_config)?
|
DeepSeekProvider::new(&cfg, app_config)?
|
||||||
};
|
};
|
||||||
Arc::new(p)
|
Arc::new(p)
|
||||||
},
|
}
|
||||||
"grok" => {
|
"grok" => {
|
||||||
let mut cfg = app_config.providers.grok.clone();
|
let mut cfg = app_config.providers.grok.clone();
|
||||||
if let Some(url) = base_url { cfg.base_url = url; }
|
if let Some(url) = base_url {
|
||||||
|
cfg.base_url = url;
|
||||||
|
}
|
||||||
let p = if let Some(key) = api_key {
|
let p = if let Some(key) = api_key {
|
||||||
GrokProvider::new_with_key(&cfg, app_config, key)?
|
GrokProvider::new_with_key(&cfg, app_config, key)?
|
||||||
} else {
|
} else {
|
||||||
GrokProvider::new(&cfg, app_config)?
|
GrokProvider::new(&cfg, app_config)?
|
||||||
};
|
};
|
||||||
Arc::new(p)
|
Arc::new(p)
|
||||||
},
|
}
|
||||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -188,16 +230,12 @@ impl ProviderManager {
|
|||||||
|
|
||||||
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
|
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
|
||||||
let providers = self.providers.read().await;
|
let providers = self.providers.read().await;
|
||||||
providers.iter()
|
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
|
||||||
.find(|p| p.supports_model(model))
|
|
||||||
.map(|p| Arc::clone(p))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||||
let providers = self.providers.read().await;
|
let providers = self.providers.read().await;
|
||||||
providers.iter()
|
providers.iter().find(|p| p.name() == name).map(Arc::clone)
|
||||||
.find(|p| p.name() == name)
|
|
||||||
.map(|p| Arc::clone(p))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
|
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
|
||||||
@@ -238,21 +276,29 @@ pub mod placeholder {
|
|||||||
&self,
|
&self,
|
||||||
_request: UnifiedRequest,
|
_request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
|
Err(AppError::ProviderError(
|
||||||
|
"Streaming not supported for placeholder provider".to_string(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
Err(AppError::ProviderError(format!(
|
||||||
_request: UnifiedRequest,
|
"Provider {} not implemented",
|
||||||
) -> Result<ProviderResponse, AppError> {
|
self.name
|
||||||
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
||||||
Ok(0)
|
Ok(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
|
&self,
|
||||||
|
_model: &str,
|
||||||
|
_prompt_tokens: u32,
|
||||||
|
_completion_tokens: u32,
|
||||||
|
_registry: &crate::models::registry::ModelRegistry,
|
||||||
|
) -> f64 {
|
||||||
0.0
|
0.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use futures::stream::BoxStream;
|
||||||
|
|
||||||
use crate::{
|
use super::helpers;
|
||||||
models::UnifiedRequest,
|
|
||||||
errors::AppError,
|
|
||||||
config::AppConfig,
|
|
||||||
};
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||||
|
|
||||||
pub struct OllamaProvider {
|
pub struct OllamaProvider {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
_config: crate::config::OllamaConfig,
|
config: crate::config::OllamaConfig,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,7 +16,7 @@ impl OllamaProvider {
|
|||||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
config: config.clone(),
|
||||||
pricing: app_config.pricing.ollama.clone(),
|
pricing: app_config.pricing.ollama.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -33,49 +29,29 @@ impl super::Provider for OllamaProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
fn supports_model(&self, model: &str) -> bool {
|
||||||
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
fn supports_multimodal(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
// Strip "ollama/" prefix if present for the API call
|
||||||
request: UnifiedRequest,
|
let api_model = request
|
||||||
) -> Result<ProviderResponse, AppError> {
|
.model
|
||||||
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
.strip_prefix("ollama/")
|
||||||
|
.unwrap_or(&request.model)
|
||||||
|
.to_string();
|
||||||
|
let original_model = request.model.clone();
|
||||||
|
request.model = api_model;
|
||||||
|
|
||||||
let mut body = serde_json::json!({
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
"model": model,
|
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": false,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let response = self
|
||||||
body["temperature"] = serde_json::json!(temp);
|
.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -86,120 +62,67 @@ impl super::Provider for OllamaProvider {
|
|||||||
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
|
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
let resp_json: serde_json::Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
// Ollama also supports "thought" as an alias for reasoning_content
|
||||||
let message = &choice["message"];
|
let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
|
||||||
|
if result.reasoning_content.is_none() {
|
||||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
result.reasoning_content = resp_json["choices"]
|
||||||
let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
|
.get(0)
|
||||||
|
.and_then(|c| c["message"]["thought"].as_str())
|
||||||
let usage = &resp_json["usage"];
|
.map(|s| s.to_string());
|
||||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
}
|
||||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
Ok(result)
|
||||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
model: request.model,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
if let Some(metadata) = registry.find_model(model) {
|
&self,
|
||||||
if let Some(cost) = &metadata.cost {
|
model: &str,
|
||||||
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
prompt_tokens: u32,
|
||||||
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
completion_tokens: u32,
|
||||||
}
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
}
|
) -> f64 {
|
||||||
|
helpers::calculate_cost_with_registry(
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
model,
|
||||||
.find(|p| model.contains(&p.model))
|
prompt_tokens,
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
completion_tokens,
|
||||||
.unwrap_or((0.0, 0.0));
|
registry,
|
||||||
|
&self.pricing,
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
0.0,
|
||||||
|
0.0,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
mut request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
let api_model = request
|
||||||
|
.model
|
||||||
|
.strip_prefix("ollama/")
|
||||||
|
.unwrap_or(&request.model)
|
||||||
|
.to_string();
|
||||||
|
let original_model = request.model.clone();
|
||||||
|
request.model = api_model;
|
||||||
|
|
||||||
let mut body = serde_json::json!({
|
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
||||||
"model": model,
|
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": true,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let es = reqwest_eventsource::EventSource::new(
|
||||||
body["temperature"] = serde_json::json!(temp);
|
self.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
.json(&body),
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
)
|
||||||
}
|
|
||||||
|
|
||||||
// Create eventsource stream
|
|
||||||
use reqwest_eventsource::{EventSource, Event};
|
|
||||||
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.json(&body))
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
let model_name = request.model.clone();
|
// Ollama uses "thought" as an alternative field for reasoning content
|
||||||
|
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(choice) = chunk["choices"].get(0) {
|
|
||||||
let delta = &choice["delta"];
|
|
||||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
|
|
||||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
yield ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
finish_reason,
|
|
||||||
model: model_name.clone(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use futures::stream::BoxStream;
|
||||||
|
|
||||||
use crate::{
|
use super::helpers;
|
||||||
models::UnifiedRequest,
|
|
||||||
errors::AppError,
|
|
||||||
config::AppConfig,
|
|
||||||
};
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
_config: crate::config::OpenAIConfig,
|
config: crate::config::OpenAIConfig,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
@@ -26,7 +22,7 @@ impl OpenAIProvider {
|
|||||||
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
config: config.clone(),
|
||||||
api_key,
|
api_key,
|
||||||
pricing: app_config.pricing.openai.clone(),
|
pricing: app_config.pricing.openai.clone(),
|
||||||
})
|
})
|
||||||
@@ -47,40 +43,13 @@ impl super::Provider for OpenAIProvider {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||||
&self,
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
request: UnifiedRequest,
|
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||||
) -> Result<ProviderResponse, AppError> {
|
|
||||||
let mut body = serde_json::json!({
|
|
||||||
"model": request.model,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": false,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let response = self
|
||||||
body["temperature"] = serde_json::json!(temp);
|
.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
@@ -92,125 +61,51 @@ impl super::Provider for OpenAIProvider {
|
|||||||
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
|
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
let resp_json: serde_json::Value = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
helpers::parse_openai_response(&resp_json, request.model)
|
||||||
let message = &choice["message"];
|
|
||||||
|
|
||||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
let usage = &resp_json["usage"];
|
|
||||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
model: request.model,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
fn calculate_cost(
|
||||||
if let Some(metadata) = registry.find_model(model) {
|
&self,
|
||||||
if let Some(cost) = &metadata.cost {
|
model: &str,
|
||||||
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
prompt_tokens: u32,
|
||||||
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
completion_tokens: u32,
|
||||||
}
|
registry: &crate::models::registry::ModelRegistry,
|
||||||
}
|
) -> f64 {
|
||||||
|
helpers::calculate_cost_with_registry(
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
model,
|
||||||
.find(|p| model.contains(&p.model))
|
prompt_tokens,
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
completion_tokens,
|
||||||
.unwrap_or((0.15, 0.60));
|
registry,
|
||||||
|
&self.pricing,
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
0.15,
|
||||||
|
0.60,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
async fn chat_completion_stream(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
let mut body = serde_json::json!({
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||||
"model": request.model,
|
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
serde_json::json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": m.content.iter().map(|p| {
|
|
||||||
match p {
|
|
||||||
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"stream": true,
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
let es = reqwest_eventsource::EventSource::new(
|
||||||
body["temperature"] = serde_json::json!(temp);
|
self.client
|
||||||
}
|
.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create eventsource stream
|
|
||||||
use reqwest_eventsource::{EventSource, Event};
|
|
||||||
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&body))
|
.json(&body),
|
||||||
|
)
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
let model = request.model.clone();
|
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(choice) = chunk["choices"].get(0) {
|
|
||||||
let delta = &choice["delta"];
|
|
||||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
yield ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,12 +5,12 @@
|
|||||||
//! 2. Provider circuit breaking to handle API failures
|
//! 2. Provider circuit breaking to handle API failures
|
||||||
//! 3. Global rate limiting for overall system protection
|
//! 3. Global rate limiting for overall system protection
|
||||||
|
|
||||||
use std::sync::Arc;
|
use anyhow::Result;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use anyhow::Result;
|
|
||||||
|
|
||||||
/// Rate limiter configuration
|
/// Rate limiter configuration
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -177,12 +177,12 @@ impl ProviderCircuitBreaker {
|
|||||||
let now = std::time::Instant::now();
|
let now = std::time::Instant::now();
|
||||||
|
|
||||||
// Check if failure window has expired
|
// Check if failure window has expired
|
||||||
if let Some(last_failure) = self.last_failure_time {
|
if let Some(last_failure) = self.last_failure_time
|
||||||
if now.duration_since(last_failure).as_secs() > self.config.failure_window_secs {
|
&& now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
|
||||||
|
{
|
||||||
// Reset failure count if window expired
|
// Reset failure count if window expired
|
||||||
self.failure_count = 0;
|
self.failure_count = 0;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
self.failure_count += 1;
|
self.failure_count += 1;
|
||||||
self.last_failure_time = Some(now);
|
self.last_failure_time = Some(now);
|
||||||
@@ -246,9 +246,7 @@ impl RateLimitManager {
|
|||||||
|
|
||||||
// Check client-specific rate limit
|
// Check client-specific rate limit
|
||||||
let mut buckets = self.client_buckets.write().await;
|
let mut buckets = self.client_buckets.write().await;
|
||||||
let bucket = buckets
|
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
|
||||||
.entry(client_id.to_string())
|
|
||||||
.or_insert_with(|| {
|
|
||||||
TokenBucket::new(
|
TokenBucket::new(
|
||||||
self.config.burst_size as f64,
|
self.config.burst_size as f64,
|
||||||
self.config.requests_per_minute as f64 / 60.0,
|
self.config.requests_per_minute as f64 / 60.0,
|
||||||
@@ -299,14 +297,13 @@ impl RateLimitManager {
|
|||||||
/// Axum middleware for rate limiting
|
/// Axum middleware for rate limiting
|
||||||
pub mod middleware {
|
pub mod middleware {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::errors::AppError;
|
||||||
|
use crate::state::AppState;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Request, State},
|
extract::{Request, State},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::Response,
|
response::Response,
|
||||||
};
|
};
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::state::AppState;
|
|
||||||
|
|
||||||
|
|
||||||
/// Rate limiting middleware
|
/// Rate limiting middleware
|
||||||
pub async fn rate_limit_middleware(
|
pub async fn rate_limit_middleware(
|
||||||
@@ -319,9 +316,7 @@ pub mod middleware {
|
|||||||
|
|
||||||
// Check rate limits
|
// Check rate limits
|
||||||
if !state.rate_limit_manager.check_client_request(&client_id).await? {
|
if !state.rate_limit_manager.check_client_request(&client_id).await? {
|
||||||
return Err(AppError::RateLimitError(
|
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
|
||||||
"Rate limit exceeded".to_string()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(next.run(request).await)
|
Ok(next.run(request).await)
|
||||||
@@ -330,29 +325,25 @@ pub mod middleware {
|
|||||||
/// Extract client ID from request (helper function)
|
/// Extract client ID from request (helper function)
|
||||||
fn extract_client_id_from_request(request: &Request) -> String {
|
fn extract_client_id_from_request(request: &Request) -> String {
|
||||||
// Try to extract from Authorization header
|
// Try to extract from Authorization header
|
||||||
if let Some(auth_header) = request.headers().get("Authorization") {
|
if let Some(auth_header) = request.headers().get("Authorization")
|
||||||
if let Ok(auth_str) = auth_header.to_str() {
|
&& let Ok(auth_str) = auth_header.to_str()
|
||||||
if auth_str.starts_with("Bearer ") {
|
&& let Some(token) = auth_str.strip_prefix("Bearer ")
|
||||||
let token = &auth_str[7..];
|
{
|
||||||
// Use token hash as client ID (same logic as auth module)
|
// Use token hash as client ID (same logic as auth module)
|
||||||
return format!("client_{}", &token[..8.min(token.len())]);
|
return format!("client_{}", &token[..8.min(token.len())]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to anonymous
|
// Fallback to anonymous
|
||||||
"anonymous".to_string()
|
"anonymous".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Circuit breaker middleware for provider requests
|
/// Circuit breaker middleware for provider requests
|
||||||
pub async fn circuit_breaker_middleware(
|
pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
|
||||||
provider_name: &str,
|
|
||||||
state: &AppState,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
|
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
|
||||||
return Err(AppError::ProviderError(
|
return Err(AppError::ProviderError(format!(
|
||||||
format!("Provider {} is currently unavailable (circuit breaker open)", provider_name)
|
"Provider {} is currently unavailable (circuit breaker open)",
|
||||||
));
|
provider_name
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use sqlx::Row;
|
|
||||||
use uuid::Uuid;
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
|
||||||
routing::post,
|
|
||||||
Json, Router,
|
Json, Router,
|
||||||
response::sse::{Event, Sse},
|
extract::State,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
|
response::sse::{Event, Sse},
|
||||||
|
routing::post,
|
||||||
};
|
};
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
|
use sqlx::Row;
|
||||||
|
use std::sync::Arc;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::AuthenticatedClient,
|
auth::AuthenticatedClient,
|
||||||
errors::AppError,
|
errors::AppError,
|
||||||
models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, ChatStreamDelta, ChatMessage, ChatChoice, Usage},
|
models::{
|
||||||
state::AppState,
|
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
||||||
|
ChatStreamChoice, ChatStreamDelta, Usage,
|
||||||
|
},
|
||||||
rate_limiting,
|
rate_limiting,
|
||||||
|
state::AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn router(state: AppState) -> Router {
|
pub fn router(state: AppState) -> Router {
|
||||||
@@ -85,7 +88,10 @@ async fn chat_completions(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !model_enabled {
|
if !model_enabled {
|
||||||
return Err(AppError::ValidationError(format!("Model {} is currently disabled", model)));
|
return Err(AppError::ValidationError(format!(
|
||||||
|
"Model {} is currently disabled",
|
||||||
|
model
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply mapping if present
|
// Apply mapping if present
|
||||||
@@ -95,7 +101,10 @@ async fn chat_completions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find appropriate provider for the model
|
// Find appropriate provider for the model
|
||||||
let provider = state.provider_manager.get_provider_for_model(&request.model).await
|
let provider = state
|
||||||
|
.provider_manager
|
||||||
|
.get_provider_for_model(&request.model)
|
||||||
|
.await
|
||||||
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
|
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
|
||||||
|
|
||||||
let provider_name = provider.name().to_string();
|
let provider_name = provider.name().to_string();
|
||||||
@@ -104,23 +113,26 @@ async fn chat_completions(
|
|||||||
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
|
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
|
||||||
|
|
||||||
// Convert to unified request format
|
// Convert to unified request format
|
||||||
let mut unified_request = crate::models::UnifiedRequest::try_from(request)
|
let mut unified_request =
|
||||||
.map_err(|e| AppError::ValidationError(e.to_string()))?;
|
crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;
|
||||||
|
|
||||||
// Set client_id from authentication
|
// Set client_id from authentication
|
||||||
unified_request.client_id = client_id.clone();
|
unified_request.client_id = client_id.clone();
|
||||||
|
|
||||||
// Hydrate images if present
|
// Hydrate images if present
|
||||||
if unified_request.has_images {
|
if unified_request.has_images {
|
||||||
unified_request.hydrate_images().await
|
unified_request
|
||||||
|
.hydrate_images()
|
||||||
|
.await
|
||||||
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
|
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let has_images = unified_request.has_images;
|
||||||
|
|
||||||
// Check if streaming is requested
|
// Check if streaming is requested
|
||||||
if unified_request.stream {
|
if unified_request.stream {
|
||||||
// Estimate prompt tokens for logging later
|
// Estimate prompt tokens for logging later
|
||||||
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
|
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
|
||||||
let has_images = unified_request.has_images;
|
|
||||||
|
|
||||||
// Handle streaming response
|
// Handle streaming response
|
||||||
let stream_result = provider.chat_completion_stream(unified_request).await;
|
let stream_result = provider.chat_completion_stream(unified_request).await;
|
||||||
@@ -133,15 +145,17 @@ async fn chat_completions(
|
|||||||
// Wrap with AggregatingStream for token counting and database logging
|
// Wrap with AggregatingStream for token counting and database logging
|
||||||
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
|
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
|
||||||
stream,
|
stream,
|
||||||
client_id.clone(),
|
crate::utils::streaming::StreamConfig {
|
||||||
provider.clone(),
|
client_id: client_id.clone(),
|
||||||
model.clone(),
|
provider: provider.clone(),
|
||||||
|
model: model.clone(),
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
has_images,
|
has_images,
|
||||||
state.request_logger.clone(),
|
logger: state.request_logger.clone(),
|
||||||
state.client_manager.clone(),
|
client_manager: state.client_manager.clone(),
|
||||||
state.model_registry.clone(),
|
model_registry: state.model_registry.clone(),
|
||||||
state.db_pool.clone(),
|
db_pool: state.db_pool.clone(),
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create SSE stream from aggregating stream
|
// Create SSE stream from aggregating stream
|
||||||
@@ -165,7 +179,13 @@ async fn chat_completions(
|
|||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Event::default().json_data(response).unwrap())
|
match Event::default().json_data(response) {
|
||||||
|
Ok(event) => Ok(event),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to serialize SSE event: {}", e);
|
||||||
|
Err(AppError::InternalError("SSE serialization failed".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Error in streaming response: {}", e);
|
warn!("Error in streaming response: {}", e);
|
||||||
@@ -197,7 +217,14 @@ async fn chat_completions(
|
|||||||
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
||||||
|
|
||||||
let duration = start_time.elapsed();
|
let duration = start_time.elapsed();
|
||||||
let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await;
|
let cost = get_model_cost(
|
||||||
|
&response.model,
|
||||||
|
response.prompt_tokens,
|
||||||
|
response.completion_tokens,
|
||||||
|
&provider,
|
||||||
|
&state,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
// Log request to database
|
// Log request to database
|
||||||
state.request_logger.log_request(crate::logging::RequestLog {
|
state.request_logger.log_request(crate::logging::RequestLog {
|
||||||
timestamp: chrono::Utc::now(),
|
timestamp: chrono::Utc::now(),
|
||||||
@@ -208,18 +235,17 @@ async fn chat_completions(
|
|||||||
completion_tokens: response.completion_tokens,
|
completion_tokens: response.completion_tokens,
|
||||||
total_tokens: response.total_tokens,
|
total_tokens: response.total_tokens,
|
||||||
cost,
|
cost,
|
||||||
has_images: false, // TODO: check images
|
has_images,
|
||||||
status: "success".to_string(),
|
status: "success".to_string(),
|
||||||
error_message: None,
|
error_message: None,
|
||||||
duration_ms: duration.as_millis() as u64,
|
duration_ms: duration.as_millis() as u64,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update client usage
|
// Update client usage
|
||||||
let _ = state.client_manager.update_client_usage(
|
let _ = state
|
||||||
&client_id,
|
.client_manager
|
||||||
response.total_tokens as i64,
|
.update_client_usage(&client_id, response.total_tokens as i64, cost)
|
||||||
cost,
|
.await;
|
||||||
).await;
|
|
||||||
|
|
||||||
// Convert ProviderResponse to ChatCompletionResponse
|
// Convert ProviderResponse to ChatCompletionResponse
|
||||||
let chat_response = ChatCompletionResponse {
|
let chat_response = ChatCompletionResponse {
|
||||||
@@ -232,7 +258,7 @@ async fn chat_completions(
|
|||||||
message: ChatMessage {
|
message: ChatMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: crate::models::MessageContent::Text {
|
content: crate::models::MessageContent::Text {
|
||||||
content: response.content
|
content: response.content,
|
||||||
},
|
},
|
||||||
reasoning_content: response.reasoning_content,
|
reasoning_content: response.reasoning_content,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -2,9 +2,8 @@ use std::sync::Arc;
|
|||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
client::ClientManager, database::DbPool, providers::ProviderManager,
|
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
||||||
rate_limiting::RateLimitManager, logging::RequestLogger,
|
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
||||||
models::registry::ModelRegistry, config::AppConfig,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Shared application state
|
/// Shared application state
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
pub mod tokens;
|
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod streaming;
|
pub mod streaming;
|
||||||
|
pub mod tokens;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
use crate::models::registry::ModelRegistry;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use crate::models::registry::ModelRegistry;
|
|
||||||
|
|
||||||
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
|
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,26 @@
|
|||||||
use futures::stream::Stream;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use sqlx::Row;
|
|
||||||
use crate::logging::{RequestLogger, RequestLog};
|
|
||||||
use crate::client::ClientManager;
|
use crate::client::ClientManager;
|
||||||
use crate::providers::{Provider, ProviderStreamChunk};
|
|
||||||
use crate::errors::AppError;
|
use crate::errors::AppError;
|
||||||
|
use crate::logging::{RequestLog, RequestLogger};
|
||||||
|
use crate::providers::{Provider, ProviderStreamChunk};
|
||||||
use crate::utils::tokens::estimate_completion_tokens;
|
use crate::utils::tokens::estimate_completion_tokens;
|
||||||
|
use futures::stream::Stream;
|
||||||
|
use sqlx::Row;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
/// Configuration for creating an AggregatingStream.
|
||||||
|
pub struct StreamConfig {
|
||||||
|
pub client_id: String,
|
||||||
|
pub provider: Arc<dyn Provider>,
|
||||||
|
pub model: String,
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub has_images: bool,
|
||||||
|
pub logger: Arc<RequestLogger>,
|
||||||
|
pub client_manager: Arc<ClientManager>,
|
||||||
|
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||||
|
pub db_pool: crate::database::DbPool,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AggregatingStream<S> {
|
pub struct AggregatingStream<S> {
|
||||||
inner: S,
|
inner: S,
|
||||||
@@ -28,33 +41,22 @@ pub struct AggregatingStream<S> {
|
|||||||
|
|
||||||
impl<S> AggregatingStream<S>
|
impl<S> AggregatingStream<S>
|
||||||
where
|
where
|
||||||
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin
|
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
|
||||||
{
|
{
|
||||||
pub fn new(
|
pub fn new(inner: S, config: StreamConfig) -> Self {
|
||||||
inner: S,
|
|
||||||
client_id: String,
|
|
||||||
provider: Arc<dyn Provider>,
|
|
||||||
model: String,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
has_images: bool,
|
|
||||||
logger: Arc<RequestLogger>,
|
|
||||||
client_manager: Arc<ClientManager>,
|
|
||||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
|
||||||
db_pool: crate::database::DbPool,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
client_id,
|
client_id: config.client_id,
|
||||||
provider,
|
provider: config.provider,
|
||||||
model,
|
model: config.model,
|
||||||
prompt_tokens,
|
prompt_tokens: config.prompt_tokens,
|
||||||
has_images,
|
has_images: config.has_images,
|
||||||
accumulated_content: String::new(),
|
accumulated_content: String::new(),
|
||||||
accumulated_reasoning: String::new(),
|
accumulated_reasoning: String::new(),
|
||||||
logger,
|
logger: config.logger,
|
||||||
client_manager,
|
client_manager: config.client_manager,
|
||||||
model_registry,
|
model_registry: config.model_registry,
|
||||||
db_pool,
|
db_pool: config.db_pool,
|
||||||
start_time: std::time::Instant::now(),
|
start_time: std::time::Instant::now(),
|
||||||
has_logged: false,
|
has_logged: false,
|
||||||
}
|
}
|
||||||
@@ -92,7 +94,8 @@ where
|
|||||||
// Spawn a background task to log the completion
|
// Spawn a background task to log the completion
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Check database for cost overrides
|
// Check database for cost overrides
|
||||||
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
let db_cost =
|
||||||
|
sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
||||||
.bind(&model)
|
.bind(&model)
|
||||||
.fetch_optional(&pool)
|
.fetch_optional(&pool)
|
||||||
.await
|
.await
|
||||||
@@ -128,18 +131,16 @@ where
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Update client usage
|
// Update client usage
|
||||||
let _ = client_manager.update_client_usage(
|
let _ = client_manager
|
||||||
&client_id,
|
.update_client_usage(&client_id, total_tokens as i64, cost)
|
||||||
total_tokens as i64,
|
.await;
|
||||||
cost,
|
|
||||||
).await;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Stream for AggregatingStream<S>
|
impl<S> Stream for AggregatingStream<S>
|
||||||
where
|
where
|
||||||
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin
|
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
|
||||||
{
|
{
|
||||||
type Item = Result<ProviderStreamChunk, AppError>;
|
type Item = Result<ProviderStreamChunk, AppError>;
|
||||||
|
|
||||||
@@ -173,46 +174,81 @@ where
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures::stream::{self, StreamExt};
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use futures::stream::{self, StreamExt};
|
||||||
|
|
||||||
// Simple mock provider for testing
|
// Simple mock provider for testing
|
||||||
struct MockProvider;
|
struct MockProvider;
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl Provider for MockProvider {
|
impl Provider for MockProvider {
|
||||||
fn name(&self) -> &str { "mock" }
|
fn name(&self) -> &str {
|
||||||
fn supports_model(&self, _model: &str) -> bool { true }
|
"mock"
|
||||||
fn supports_multimodal(&self) -> bool { false }
|
}
|
||||||
async fn chat_completion(&self, _req: crate::models::UnifiedRequest) -> Result<crate::providers::ProviderResponse, AppError> { unimplemented!() }
|
fn supports_model(&self, _model: &str) -> bool {
|
||||||
async fn chat_completion_stream(&self, _req: crate::models::UnifiedRequest) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { unimplemented!() }
|
true
|
||||||
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> { Ok(10) }
|
}
|
||||||
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { 0.05 }
|
fn supports_multimodal(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
async fn chat_completion(
|
||||||
|
&self,
|
||||||
|
_req: crate::models::UnifiedRequest,
|
||||||
|
) -> Result<crate::providers::ProviderResponse, AppError> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
async fn chat_completion_stream(
|
||||||
|
&self,
|
||||||
|
_req: crate::models::UnifiedRequest,
|
||||||
|
) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
|
||||||
|
Ok(10)
|
||||||
|
}
|
||||||
|
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
|
||||||
|
0.05
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_aggregating_stream() {
|
async fn test_aggregating_stream() {
|
||||||
let chunks = vec![
|
let chunks = vec![
|
||||||
Ok(ProviderStreamChunk { content: "Hello".to_string(), finish_reason: None, model: "test".to_string() }),
|
Ok(ProviderStreamChunk {
|
||||||
Ok(ProviderStreamChunk { content: " World".to_string(), finish_reason: Some("stop".to_string()), model: "test".to_string() }),
|
content: "Hello".to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
finish_reason: None,
|
||||||
|
model: "test".to_string(),
|
||||||
|
}),
|
||||||
|
Ok(ProviderStreamChunk {
|
||||||
|
content: " World".to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
finish_reason: Some("stop".to_string()),
|
||||||
|
model: "test".to_string(),
|
||||||
|
}),
|
||||||
];
|
];
|
||||||
let inner_stream = stream::iter(chunks);
|
let inner_stream = stream::iter(chunks);
|
||||||
|
|
||||||
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
let logger = Arc::new(RequestLogger::new(pool.clone()));
|
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
|
||||||
|
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
|
||||||
let client_manager = Arc::new(ClientManager::new(pool.clone()));
|
let client_manager = Arc::new(ClientManager::new(pool.clone()));
|
||||||
let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new() });
|
let registry = Arc::new(crate::models::registry::ModelRegistry {
|
||||||
|
providers: std::collections::HashMap::new(),
|
||||||
|
});
|
||||||
|
|
||||||
let mut agg_stream = AggregatingStream::new(
|
let mut agg_stream = AggregatingStream::new(
|
||||||
inner_stream,
|
inner_stream,
|
||||||
"client_1".to_string(),
|
StreamConfig {
|
||||||
Arc::new(MockProvider),
|
client_id: "client_1".to_string(),
|
||||||
"test".to_string(),
|
provider: Arc::new(MockProvider),
|
||||||
10,
|
model: "test".to_string(),
|
||||||
false,
|
prompt_tokens: 10,
|
||||||
|
has_images: false,
|
||||||
logger,
|
logger,
|
||||||
client_manager,
|
client_manager,
|
||||||
registry,
|
model_registry: registry,
|
||||||
pool.clone(),
|
db_pool: pool.clone(),
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
while let Some(item) = agg_stream.next().await {
|
while let Some(item) = agg_stream.next().await {
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
use tiktoken_rs::get_bpe_from_model;
|
|
||||||
use crate::models::UnifiedRequest;
|
use crate::models::UnifiedRequest;
|
||||||
|
use tiktoken_rs::get_bpe_from_model;
|
||||||
|
|
||||||
/// Count tokens for a given model and text
|
/// Count tokens for a given model and text
|
||||||
pub fn count_tokens(model: &str, text: &str) -> u32 {
|
pub fn count_tokens(model: &str, text: &str) -> u32 {
|
||||||
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
|
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
|
||||||
let bpe = get_bpe_from_model(model).unwrap_or_else(|_| {
|
let bpe = get_bpe_from_model(model)
|
||||||
tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding")
|
.unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
|
||||||
});
|
|
||||||
|
|
||||||
bpe.encode_with_special_tokens(text).len() as u32
|
bpe.encode_with_special_tokens(text).len() as u32
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON r
|
|||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "Dashboard should be available at: http://localhost:8080"
|
echo "Dashboard should be available at: http://localhost:8080"
|
||||||
echo "Default login: admin / admin123"
|
echo "Default login: admin / admin"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Press Ctrl+C to stop the server"
|
echo "Press Ctrl+C to stop the server"
|
||||||
|
|
||||||
|
|||||||
@@ -1,188 +0,0 @@
|
|||||||
// Integration tests for LLM Proxy Gateway
|
|
||||||
|
|
||||||
use llm_proxy::config::Config;
|
|
||||||
use llm_proxy::database::Database;
|
|
||||||
use llm_proxy::state::AppState;
|
|
||||||
use llm_proxy::rate_limiting::RateLimitManager;
|
|
||||||
use tempfile::TempDir;
|
|
||||||
use std::fs;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_config_loading() {
|
|
||||||
// Create a temporary config file
|
|
||||||
let temp_dir = TempDir::new().unwrap();
|
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
|
||||||
|
|
||||||
let config_content = r#"
|
|
||||||
[server]
|
|
||||||
port = 8080
|
|
||||||
host = "0.0.0.0"
|
|
||||||
|
|
||||||
[database]
|
|
||||||
path = "./data/test.db"
|
|
||||||
max_connections = 5
|
|
||||||
|
|
||||||
[providers.openai]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://api.openai.com/v1"
|
|
||||||
|
|
||||||
[providers.gemini]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://generativelanguage.googleapis.com/v1"
|
|
||||||
|
|
||||||
[providers.deepseek]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://api.deepseek.com"
|
|
||||||
|
|
||||||
[providers.grok]
|
|
||||||
enabled = false
|
|
||||||
base_url = "https://api.x.ai/v1"
|
|
||||||
|
|
||||||
[model_mapping]
|
|
||||||
"gpt-*" = "openai"
|
|
||||||
"gemini-*" = "gemini"
|
|
||||||
"deepseek-*" = "deepseek"
|
|
||||||
"grok-*" = "grok"
|
|
||||||
|
|
||||||
[pricing]
|
|
||||||
openai = { input = 0.01, output = 0.03 }
|
|
||||||
gemini = { input = 0.0005, output = 0.0015 }
|
|
||||||
deepseek = { input = 0.00014, output = 0.00028 }
|
|
||||||
grok = { input = 0.001, output = 0.003 }
|
|
||||||
"#;
|
|
||||||
|
|
||||||
fs::write(&config_path, config_content).unwrap();
|
|
||||||
|
|
||||||
// Test loading config
|
|
||||||
let config = Config::load_from_path(&config_path);
|
|
||||||
assert!(config.is_ok());
|
|
||||||
|
|
||||||
let config = config.unwrap();
|
|
||||||
assert_eq!(config.server.port, 8080);
|
|
||||||
assert!(config.providers.openai.is_some());
|
|
||||||
assert!(config.providers.grok.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_database_initialization() {
|
|
||||||
// Create a temporary database file
|
|
||||||
let temp_dir = TempDir::new().unwrap();
|
|
||||||
let db_path = temp_dir.path().join("test.db");
|
|
||||||
|
|
||||||
// Test database initialization
|
|
||||||
let database = Database::new(&db_path).await;
|
|
||||||
assert!(database.is_ok());
|
|
||||||
|
|
||||||
let database = database.unwrap();
|
|
||||||
|
|
||||||
// Test connection
|
|
||||||
let test_result = database.test_connection().await;
|
|
||||||
assert!(test_result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_provider_manager() {
|
|
||||||
// Create a provider manager
|
|
||||||
use llm_proxy::providers::{ProviderManager, Provider};
|
|
||||||
use llm_proxy::config::OpenAIConfig;
|
|
||||||
|
|
||||||
let mut manager = ProviderManager::new();
|
|
||||||
assert_eq!(manager.providers.len(), 0);
|
|
||||||
|
|
||||||
// Test adding providers (we can't actually add real providers without API keys)
|
|
||||||
// This test just verifies the manager structure works
|
|
||||||
assert!(manager.get_provider_for_model("gpt-4").is_none());
|
|
||||||
assert!(manager.get_provider("openai").is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_rate_limit_manager() {
|
|
||||||
let manager = RateLimitManager::new(60, 10);
|
|
||||||
|
|
||||||
// Test client rate limiting
|
|
||||||
let allowed = manager.check_request("test-client").await;
|
|
||||||
assert!(allowed); // First request should be allowed
|
|
||||||
|
|
||||||
// Test provider circuit breaker
|
|
||||||
let allowed = manager.check_provider("openai").await;
|
|
||||||
assert!(allowed); // Circuit should be closed initially
|
|
||||||
|
|
||||||
// Record some failures
|
|
||||||
manager.record_provider_failure("openai").await;
|
|
||||||
manager.record_provider_failure("openai").await;
|
|
||||||
manager.record_provider_failure("openai").await;
|
|
||||||
manager.record_provider_failure("openai").await;
|
|
||||||
manager.record_provider_failure("openai").await;
|
|
||||||
|
|
||||||
// After 5 failures, circuit should be open
|
|
||||||
let allowed = manager.check_provider("openai").await;
|
|
||||||
assert!(!allowed); // Circuit should be open
|
|
||||||
|
|
||||||
// Record success to close circuit
|
|
||||||
manager.record_provider_success("openai").await;
|
|
||||||
manager.record_provider_success("openai").await;
|
|
||||||
manager.record_provider_success("openai").await;
|
|
||||||
|
|
||||||
// After 3 successes in half-open state, circuit should be closed
|
|
||||||
let allowed = manager.check_provider("openai").await;
|
|
||||||
assert!(allowed); // Circuit should be closed again
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_app_state_creation() {
|
|
||||||
// Create a temporary database
|
|
||||||
let temp_dir = TempDir::new().unwrap();
|
|
||||||
let db_path = temp_dir.path().join("test.db");
|
|
||||||
|
|
||||||
let database = Database::new(&db_path).await.unwrap();
|
|
||||||
|
|
||||||
// Test AppState creation using test utilities
|
|
||||||
use llm_proxy::test_utils::create_test_state;
|
|
||||||
let state = create_test_state().await;
|
|
||||||
|
|
||||||
// Verify state components are initialized
|
|
||||||
assert!(state.database.test_connection().await.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_multimodal_image_converter() {
|
|
||||||
use llm_proxy::multimodal::{ImageConverter, ImageInput};
|
|
||||||
|
|
||||||
// Test model detection
|
|
||||||
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
|
|
||||||
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
|
|
||||||
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
|
|
||||||
assert!(!ImageConverter::model_supports_multimodal("gemini-pro"));
|
|
||||||
|
|
||||||
// Test data URL parsing (utility function)
|
|
||||||
let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ=";
|
|
||||||
let parts: Vec<&str> = test_url[5..].split(";base64,").collect();
|
|
||||||
assert_eq!(parts.len(), 2);
|
|
||||||
assert_eq!(parts[0], "image/jpeg");
|
|
||||||
assert_eq!(parts[1], "SGVsbG8gV29ybGQ=");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_error_conversions() {
|
|
||||||
use llm_proxy::errors::AppError;
|
|
||||||
use anyhow::anyhow;
|
|
||||||
|
|
||||||
// Test anyhow error conversion
|
|
||||||
let anyhow_error = anyhow!("Test error");
|
|
||||||
let app_error: AppError = anyhow_error.into();
|
|
||||||
|
|
||||||
match app_error {
|
|
||||||
AppError::InternalError(msg) => assert_eq!(msg, "Test error"),
|
|
||||||
_ => panic!("Expected InternalError"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test sqlx error conversion
|
|
||||||
use sqlx::Error as SqlxError;
|
|
||||||
let sqlx_error = SqlxError::PoolClosed;
|
|
||||||
let app_error: AppError = sqlx_error.into();
|
|
||||||
|
|
||||||
match app_error {
|
|
||||||
AppError::DatabaseError(msg) => assert!(msg.contains("pool closed")),
|
|
||||||
_ => panic!("Expected DatabaseError"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user