diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..7e28b97c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,61 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + +jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo check --all-targets + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --all-targets -- -D warnings + + fmt: + name: Formatting + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all -- --check + + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo test --all-targets + + build-release: + name: Release Build + runs-on: ubuntu-latest + needs: [check, clippy, test] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo build --release diff --git a/Cargo.lock b/Cargo.lock index 52499ef3..4cff23de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,46 +92,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "async-openai" -version = "0.33.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3ef6d66e53b3ec8ed4bae8e6dcdda75c0f5b4f67faba78fac1b09434cb4f2fc" -dependencies = [ - "async-openai-macros", - "backoff", - "base64 0.22.1", - "bytes", - "derive_builder", - "eventsource-stream", - "futures", - "getrandom 0.3.4", - "rand 0.9.2", - "reqwest", - "reqwest-eventsource", - "secrecy", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror 2.0.18", - "tokio", - "tokio-stream", - "tokio-util", - "tracing", - "url", -] - -[[package]] -name = "async-openai-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "async-stream" version = "0.3.6" @@ -275,20 +235,6 @@ dependencies = [ "syn", ] -[[package]] -name = "backoff" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" -dependencies = [ - "futures-core", - "getrandom 0.2.17", - "instant", - "pin-project-lite", - "rand 0.8.5", - "tokio", -] - [[package]] name = "base64" version = "0.13.1" @@ -598,54 +544,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core", - "quote", - "syn", -] - -[[package]] -name = "dashmap" -version = "5.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" -dependencies = [ - "cfg-if", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "data-encoding" version = "2.10.0" @@ -663,37 +561,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn", -] - [[package]] name = "difflib" version = "0.4.0" @@ -1028,26 +895,6 @@ dependencies = [ "wasip3", ] -[[package]] -name = "governor" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" -dependencies = [ - "cfg-if", - "dashmap", - "futures", - "futures-timer", - "no-std-compat", - "nonzero_ext", - "parking_lot", - "portable-atomic", - "quanta", - "rand 0.8.5", - "smallvec", - "spinning_top", -] - [[package]] name = "h2" version = "0.4.13" @@ -1076,12 +923,6 @@ dependencies = [ "ahash", ] -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - [[package]] name = "hashbrown" version = "0.15.5" @@ -1396,12 +1237,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.1.0" @@ -1482,15 +1317,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", -] - [[package]] name = "ipnet" version = "2.11.0" @@ -1607,7 +1433,6 @@ version = "0.1.0" dependencies = [ "anyhow", "assert_cmd", - "async-openai", "async-stream", "async-trait", "axum", @@ -1618,16 +1443,11 @@ dependencies = [ "config", "dotenvy", "futures", - "governor", "headers", "image", "insta", "mime", - "mime_guess", "mockito", - "once_cell", - "rand 0.8.5", - "regex", "reqwest", "reqwest-eventsource", "serde", @@ -1776,12 +1596,6 @@ dependencies = [ "pxfm", ] -[[package]] -name = "no-std-compat" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" - [[package]] name = "nom" version = "7.1.3" @@ -1792,12 +1606,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "nonzero_ext" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" - [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -2014,12 +1822,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "portable-atomic" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" - [[package]] name = "potential_utf" version = "0.1.4" @@ -2093,21 +1895,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "quanta" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", -] - [[package]] name = "quick-error" version = "2.0.1" @@ -2243,15 +2030,6 @@ dependencies = [ "getrandom 0.3.4", ] -[[package]] -name = "raw-cpuid" -version = "11.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "redox_syscall" version = "0.5.18" @@ -2317,7 +2095,6 @@ dependencies = [ "hyper-util", "js-sys", "log", - "mime_guess", "percent-encoding", "pin-project-lite", "quinn", @@ -2490,16 +2267,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "secrecy" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" -dependencies = [ - "serde", - "zeroize", -] - [[package]] name = "semver" version = "1.0.27" @@ -2684,15 +2451,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "spinning_top" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" -dependencies = [ - "lock_api", -] - [[package]] name = "spki" version = "0.7.3" @@ -2912,12 +2670,6 @@ dependencies = [ "unicode-properties", ] -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - [[package]] name = "subtle" version = "2.6.1" @@ -3657,28 +3409,6 @@ dependencies = [ "wasite", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-core" version = "0.62.2" diff --git a/Cargo.toml b/Cargo.toml index 3d341e90..2888d4c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", # ========== HTTP Clients ========== reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } -async-openai = { version = "0.33", default-features = false, features = ["_api", "chat-completion"] } tiktoken-rs = "0.9" # ========== Database & ORM ========== @@ -41,7 +40,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } base64 = "0.21" image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] } mime = "0.3" -mime_guess = "2.0" # ========== Error Handling & Utilities ========== anyhow = "1.0" @@ -53,12 +51,6 @@ futures = "0.3" async-trait = "0.1" async-stream = "0.3" reqwest-eventsource = "0.6" -once_cell = "1.19" -regex = "1.10" -rand = "0.8" - -# ========== Rate Limiting & Circuit Breaking ========== -governor = "0.6" [dev-dependencies] tokio-test = "0.4" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..8c349f9a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,35 @@ +# ── Build stage ────────────────────────────────────────────── +FROM rust:1-bookworm AS builder + +WORKDIR /app + +# Cache dependency build +COPY Cargo.toml Cargo.lock ./ +RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ + cargo build --release && \ + rm -rf src + +# Build the actual binary +COPY src/ src/ +RUN touch src/main.rs && cargo build --release + +# ── Runtime stage ──────────────────────────────────────────── +FROM debian:bookworm-slim + +RUN apt-get update && \ + apt-get install -y --no-install-recommends ca-certificates && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy +COPY static/ /app/static/ + +# Default config location +VOLUME ["/app/config", "/app/data"] + +EXPOSE 8080 + +ENV RUST_LOG=info + +ENTRYPOINT ["/app/llm-proxy"] diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000..ba34bf46 --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +too-many-arguments-threshold = 8 diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..53f860ab --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +use_field_init_shorthand = true diff --git a/src/auth/mod.rs b/src/auth/mod.rs index f91e3f13..8bb3ceef 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,6 +1,6 @@ use axum::{extract::FromRequestParts, http::request::Parts}; -use axum_extra::headers::Authorization; use axum_extra::TypedHeader; +use axum_extra::headers::Authorization; use headers::authorization::Bearer; use crate::errors::AppError; @@ -16,32 +16,18 @@ where { type Rejection = AppError; - fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future> + Send { - // We need access to the AppState to get valid tokens - // Since state is generic here, we try to cast it or assume it's available via extensions - // In this project, AppState is cloned into Axum state. - - async move { - // Extract bearer token from Authorization header - let TypedHeader(Authorization(bearer)) = - TypedHeader::>::from_request_parts(parts, state) - .await - .map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?; + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // Extract bearer token from Authorization header + let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request_parts(parts, state) + .await + .map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?; - let token = bearer.token().to_string(); - - // For a proxy, we want to check if this token is in our allowed list - // The list is stored in AppState which is available in Parts extensions - let client_id = { - // In main.rs, we set up the router with State(state). - // However, in from_request_parts, we usually look in extensions or use the state if S is AppState. - // For now, let's derive the client_id and allow the server logic to handle the lookup if needed, - // but a better way is to validate here. - format!("client_{}", &token[..8.min(token.len())]) - }; - - Ok(AuthenticatedClient { token, client_id }) - } + let token = bearer.token().to_string(); + + // Derive client_id from the token prefix + let client_id = format!("client_{}", &token[..8.min(token.len())]); + + Ok(AuthenticatedClient { token, client_id }) } } @@ -49,4 +35,4 @@ pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool { // Simple validation against list of tokens // In production, use proper token validation (JWT, database lookup, etc.) valid_tokens.contains(&token.to_string()) -} \ No newline at end of file +} diff --git a/src/client/mod.rs b/src/client/mod.rs index c11ad22c..51b51f2b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -5,10 +5,10 @@ //! 2. Client usage tracking //! 3. Client rate limit configuration +use anyhow::Result; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use sqlx::{SqlitePool, Row}; -use anyhow::Result; +use sqlx::{Row, SqlitePool}; use tracing::{info, warn}; /// Client information @@ -58,7 +58,7 @@ impl ClientManager { /// Create a new client pub async fn create_client(&self, request: CreateClientRequest) -> Result { let rate_limit = request.rate_limit_per_minute.unwrap_or(60); - + // First insert the client sqlx::query( r#" @@ -72,11 +72,13 @@ impl ClientManager { .bind(rate_limit) .execute(&self.db_pool) .await?; - + // Then fetch the created client - let client = self.get_client(&request.client_id).await? + let client = self + .get_client(&request.client_id) + .await? .ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?; - + info!("Created client: {} ({})", client.name, client.client_id); Ok(client) } @@ -96,7 +98,7 @@ impl ClientManager { .bind(client_id) .fetch_optional(&self.db_pool) .await?; - + if let Some(row) = row { let client = Client { id: row.get("id"), @@ -124,69 +126,68 @@ impl ClientManager { if current_client.is_none() { return Ok(None); } - + // Build update query dynamically based on provided fields - let mut updates = Vec::new(); let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET "); let mut has_updates = false; - + if let Some(name) = &request.name { - updates.push("name = "); + query_builder.push("name = "); query_builder.push_bind(name); has_updates = true; } - + if let Some(description) = &request.description { if has_updates { query_builder.push(", "); } - updates.push("description = "); + query_builder.push("description = "); query_builder.push_bind(description); has_updates = true; } - + if let Some(is_active) = request.is_active { if has_updates { query_builder.push(", "); } - updates.push("is_active = "); + query_builder.push("is_active = "); query_builder.push_bind(is_active); has_updates = true; } - + if let Some(rate_limit) = request.rate_limit_per_minute { if has_updates { query_builder.push(", "); } - updates.push("rate_limit_per_minute = "); + query_builder.push("rate_limit_per_minute = "); query_builder.push_bind(rate_limit); has_updates = true; } - + // Always update the updated_at timestamp if has_updates { query_builder.push(", "); } query_builder.push("updated_at = CURRENT_TIMESTAMP"); - + if !has_updates { // No updates to make return self.get_client(client_id).await; } - + query_builder.push(" WHERE client_id = "); query_builder.push_bind(client_id); - + let query = query_builder.build(); query.execute(&self.db_pool).await?; - + // Fetch the updated client let updated_client = self.get_client(client_id).await?; - + if updated_client.is_some() { info!("Updated client: {}", client_id); } - + Ok(updated_client) } @@ -194,7 +195,7 @@ impl ClientManager { pub async fn list_clients(&self, limit: Option, offset: Option) -> Result> { let limit = limit.unwrap_or(100); let offset = offset.unwrap_or(0); - + let rows = sqlx::query( r#" SELECT @@ -204,13 +205,13 @@ impl ClientManager { FROM clients ORDER BY created_at DESC LIMIT ? OFFSET ? - "# + "#, ) .bind(limit) .bind(offset) .fetch_all(&self.db_pool) .await?; - + let mut clients = Vec::new(); for row in rows { let client = Client { @@ -228,37 +229,30 @@ impl ClientManager { }; clients.push(client); } - + Ok(clients) } /// Delete a client pub async fn delete_client(&self, client_id: &str) -> Result { - let result = sqlx::query( - "DELETE FROM clients WHERE client_id = ?" - ) - .bind(client_id) - .execute(&self.db_pool) - .await?; - + let result = sqlx::query("DELETE FROM clients WHERE client_id = ?") + .bind(client_id) + .execute(&self.db_pool) + .await?; + let deleted = result.rows_affected() > 0; - + if deleted { info!("Deleted client: {}", client_id); } else { warn!("Client not found for deletion: {}", client_id); } - + Ok(deleted) } /// Update client usage statistics after a request - pub async fn update_client_usage( - &self, - client_id: &str, - tokens: i64, - cost: f64, - ) -> Result<()> { + pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> { sqlx::query( r#" UPDATE clients @@ -268,14 +262,14 @@ impl ClientManager { total_cost = total_cost + ?, updated_at = CURRENT_TIMESTAMP WHERE client_id = ? - "# + "#, ) .bind(tokens) .bind(cost) .bind(client_id) .execute(&self.db_pool) .await?; - + Ok(()) } @@ -286,12 +280,12 @@ impl ClientManager { SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = ? - "# + "#, ) .bind(client_id) .fetch_optional(&self.db_pool) .await?; - + if let Some(row) = row { let total_requests: i64 = row.get("total_requests"); let total_tokens: i64 = row.get("total_tokens"); @@ -307,4 +301,4 @@ impl ClientManager { let client = self.get_client(client_id).await?; Ok(client.map(|c| c.is_active).unwrap_or(false)) } -} \ No newline at end of file +} diff --git a/src/config/mod.rs b/src/config/mod.rs index f478cc88..44e0151e 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -95,14 +95,14 @@ pub struct AppConfig { pub providers: ProviderConfig, pub model_mapping: ModelMappingConfig, pub pricing: PricingConfig, - pub config_path: PathBuf, + pub config_path: Option, } impl AppConfig { pub async fn load() -> Result> { Self::load_from_path(None).await } - + /// Load configuration from a specific path (for testing) pub async fn load_from_path(config_path: Option) -> Result> { // Load configuration from multiple sources @@ -120,7 +120,10 @@ impl AppConfig { .set_default("providers.openai.default_model", "gpt-4o")? .set_default("providers.openai.enabled", true)? .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")? - .set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")? + .set_default( + "providers.gemini.base_url", + "https://generativelanguage.googleapis.com/v1", + )? .set_default("providers.gemini.default_model", "gemini-2.0-flash")? .set_default("providers.gemini.enabled", true)? .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")? @@ -136,7 +139,11 @@ impl AppConfig { .set_default("providers.ollama.models", Vec::::new())?; // Load from config file if exists - let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml")); + let config_path = config_path.unwrap_or_else(|| { + std::env::current_dir() + .unwrap_or_else(|_| PathBuf::from(".")) + .join("config.toml") + }); if config_path.exists() { config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml)); } @@ -157,7 +164,7 @@ impl AppConfig { let server: ServerConfig = config.get("server")?; let database: DatabaseConfig = config.get("database")?; let providers: ProviderConfig = config.get("providers")?; - + // For now, use empty model mapping and pricing (will be populated later) let model_mapping = ModelMappingConfig { patterns: vec![] }; let pricing = PricingConfig { @@ -174,7 +181,7 @@ impl AppConfig { providers, model_mapping, pricing, - config_path, + config_path: Some(config_path), })) } @@ -187,48 +194,46 @@ impl AppConfig { _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)), }; - std::env::var(env_var) - .map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) - } + std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) + } +} + +/// Helper function to deserialize a Vec from either a sequence or a comma-separated string +fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + struct VecOrString; + + impl<'de> serde::de::Visitor<'de> for VecOrString { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence or a comma-separated string") } - - /// Helper function to deserialize a Vec from either a sequence or a comma-separated string - fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result, D::Error> + + fn visit_str(self, value: &str) -> Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - struct VecOrString; - - impl<'de> serde::de::Visitor<'de> for VecOrString { - type Value = Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a sequence or a comma-separated string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - Ok(value - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect()) - } - - fn visit_seq(self, mut seq: S) -> Result - where - S: serde::de::SeqAccess<'de>, - { - let mut vec = Vec::new(); - while let Some(element) = seq.next_element()? { - vec.push(element); - } - Ok(vec) - } - } - - deserializer.deserialize_any(VecOrString) + Ok(value + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect()) } - \ No newline at end of file + + fn visit_seq(self, mut seq: S) -> Result + where + S: serde::de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(element) = seq.next_element()? { + vec.push(element); + } + Ok(vec) + } + } + + deserializer.deserialize_any(VecOrString) +} diff --git a/src/dashboard/auth.rs b/src/dashboard/auth.rs new file mode 100644 index 00000000..2b7d9e6a --- /dev/null +++ b/src/dashboard/auth.rs @@ -0,0 +1,130 @@ +use axum::{extract::State, response::Json}; +use bcrypt; +use serde::Deserialize; +use sqlx::Row; +use tracing::warn; + +use super::{ApiResponse, DashboardState}; + +// Authentication handlers +#[derive(Deserialize)] +pub(super) struct LoginRequest { + pub(super) username: String, + pub(super) password: String, +} + +pub(super) async fn handle_login( + State(state): State, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + let user_result = + sqlx::query("SELECT username, password_hash, role, must_change_password FROM users WHERE username = ?") + .bind(&payload.username) + .fetch_optional(pool) + .await; + + match user_result { + Ok(Some(row)) => { + let hash = row.get::("password_hash"); + if bcrypt::verify(&payload.password, &hash).unwrap_or(false) { + let username = row.get::("username"); + let role = row.get::("role"); + let must_change_password = row.get::("must_change_password"); + let token = state + .session_manager + .create_session(username.clone(), role.clone()) + .await; + Json(ApiResponse::success(serde_json::json!({ + "token": token, + "must_change_password": must_change_password, + "user": { + "username": username, + "name": "Administrator", + "role": role + } + }))) + } else { + Json(ApiResponse::error("Invalid username or password".to_string())) + } + } + Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())), + Err(e) => { + warn!("Database error during login: {}", e); + Json(ApiResponse::error("Login failed due to system error".to_string())) + } + } +} + +pub(super) async fn handle_auth_status( + State(state): State, + headers: axum::http::HeaderMap, +) -> Json> { + let token = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")); + + if let Some(token) = token + && let Some(session) = state.session_manager.validate_session(token).await + { + return Json(ApiResponse::success(serde_json::json!({ + "authenticated": true, + "user": { + "username": session.username, + "name": "Administrator", + "role": session.role + } + }))); + } + + Json(ApiResponse::error("Not authenticated".to_string())) +} + +#[derive(Deserialize)] +pub(super) struct ChangePasswordRequest { + pub(super) current_password: String, + pub(super) new_password: String, +} + +pub(super) async fn handle_change_password( + State(state): State, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + // For now, always change 'admin' user + let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'") + .fetch_one(pool) + .await; + + match user_result { + Ok(row) => { + let hash = row.get::("password_hash"); + if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) { + let new_hash = match bcrypt::hash(&payload.new_password, 12) { + Ok(h) => h, + Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())), + }; + + let update_result = sqlx::query( + "UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = 'admin'", + ) + .bind(new_hash) + .execute(pool) + .await; + + match update_result { + Ok(_) => Json(ApiResponse::success( + serde_json::json!({ "message": "Password updated successfully" }), + )), + Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))), + } + } else { + Json(ApiResponse::error("Current password incorrect".to_string())) + } + } + Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))), + } +} diff --git a/src/dashboard/clients.rs b/src/dashboard/clients.rs new file mode 100644 index 00000000..86c4450f --- /dev/null +++ b/src/dashboard/clients.rs @@ -0,0 +1,227 @@ +use axum::{ + extract::{Path, State}, + response::Json, +}; +use chrono; +use serde::Deserialize; +use serde_json; +use sqlx::Row; +use tracing::warn; +use uuid; + +use super::{ApiResponse, DashboardState}; + +#[derive(Deserialize)] +pub(super) struct CreateClientRequest { + pub(super) name: String, + pub(super) client_id: Option, +} + +pub(super) async fn handle_get_clients(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + client_id as id, + name, + created_at, + total_requests, + total_tokens, + total_cost, + is_active + FROM clients + ORDER BY created_at DESC + "#, + ) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let clients: Vec = rows + .into_iter() + .map(|row| { + serde_json::json!({ + "id": row.get::("id"), + "name": row.get::, _>("name").unwrap_or_else(|| "Unnamed".to_string()), + "created_at": row.get::, _>("created_at"), + "requests_count": row.get::("total_requests"), + "total_tokens": row.get::("total_tokens"), + "total_cost": row.get::("total_cost"), + "status": if row.get::("is_active") { "active" } else { "inactive" }, + }) + }) + .collect(); + + Json(ApiResponse::success(serde_json::json!(clients))) + } + Err(e) => { + warn!("Failed to fetch clients: {}", e); + Json(ApiResponse::error("Failed to fetch clients".to_string())) + } + } +} + +pub(super) async fn handle_create_client( + State(state): State, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + let client_id = payload + .client_id + .unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8])); + + let result = sqlx::query( + r#" + INSERT INTO clients (client_id, name, is_active) + VALUES (?, ?, TRUE) + RETURNING * + "#, + ) + .bind(&client_id) + .bind(&payload.name) + .fetch_one(pool) + .await; + + match result { + Ok(row) => Json(ApiResponse::success(serde_json::json!({ + "id": row.get::("client_id"), + "name": row.get::, _>("name"), + "created_at": row.get::, _>("created_at"), + "status": "active", + }))), + Err(e) => { + warn!("Failed to create client: {}", e); + Json(ApiResponse::error(format!("Failed to create client: {}", e))) + } + } +} + +pub(super) async fn handle_get_client( + State(state): State, + Path(id): Path, +) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + c.client_id as id, + c.name, + c.is_active, + c.created_at, + COALESCE(c.total_tokens, 0) as total_tokens, + COALESCE(c.total_cost, 0.0) as total_cost, + COUNT(r.id) as total_requests, + MAX(r.timestamp) as last_request + FROM clients c + LEFT JOIN llm_requests r ON c.client_id = r.client_id + WHERE c.client_id = ? + GROUP BY c.client_id + "#, + ) + .bind(&id) + .fetch_optional(pool) + .await; + + match result { + Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({ + "id": row.get::("id"), + "name": row.get::, _>("name").unwrap_or_else(|| "Unnamed".to_string()), + "is_active": row.get::("is_active"), + "created_at": row.get::, _>("created_at"), + "total_tokens": row.get::("total_tokens"), + "total_cost": row.get::("total_cost"), + "total_requests": row.get::("total_requests"), + "last_request": row.get::>, _>("last_request"), + "status": if row.get::("is_active") { "active" } else { "inactive" }, + }))), + Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))), + Err(e) => { + warn!("Failed to fetch client: {}", e); + Json(ApiResponse::error(format!("Failed to fetch client: {}", e))) + } + } +} + +pub(super) async fn handle_delete_client( + State(state): State, + Path(id): Path, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Don't allow deleting the default client + if id == "default" { + return Json(ApiResponse::error("Cannot delete default client".to_string())); + } + + let result = sqlx::query("DELETE FROM clients WHERE client_id = ?") + .bind(id) + .execute(pool) + .await; + + match result { + Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))), + Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))), + } +} + +pub(super) async fn handle_client_usage( + State(state): State, + Path(id): Path, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Get per-model breakdown for this client + let result = sqlx::query( + r#" + SELECT + model, + provider, + COUNT(*) as request_count, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + SUM(total_tokens) as total_tokens, + SUM(cost) as total_cost, + AVG(duration_ms) as avg_duration_ms + FROM llm_requests + WHERE client_id = ? + GROUP BY model, provider + ORDER BY total_cost DESC + "#, + ) + .bind(&id) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let breakdown: Vec = rows + .into_iter() + .map(|row| { + serde_json::json!({ + "model": row.get::("model"), + "provider": row.get::("provider"), + "request_count": row.get::("request_count"), + "prompt_tokens": row.get::("prompt_tokens"), + "completion_tokens": row.get::("completion_tokens"), + "total_tokens": row.get::("total_tokens"), + "total_cost": row.get::("total_cost"), + "avg_duration_ms": row.get::("avg_duration_ms"), + }) + }) + .collect(); + + Json(ApiResponse::success(serde_json::json!({ + "client_id": id, + "breakdown": breakdown, + }))) + } + Err(e) => { + warn!("Failed to fetch client usage: {}", e); + Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e))) + } + } +} diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs index 9b8c5507..91d3c03c 100644 --- a/src/dashboard/mod.rs +++ b/src/dashboard/mod.rs @@ -1,22 +1,28 @@ // Dashboard module for LLM Proxy Gateway +mod auth; +mod clients; +mod models; +mod providers; +pub mod sessions; +mod system; +mod usage; +mod websocket; + use axum::{ - extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}, - response::{IntoResponse, Json}, - routing::{get, post, put}, Router, + routing::{get, post, put}, }; -use serde::{Deserialize, Serialize}; -use sqlx::Row; -use std::collections::HashMap; -use tracing::{info, warn}; +use serde::Serialize; use crate::state::AppState; +use sessions::SessionManager; // Dashboard state #[derive(Clone)] struct DashboardState { app_state: AppState, + session_manager: SessionManager, } // API Response types @@ -35,7 +41,7 @@ impl ApiResponse { error: None, } } - + fn error(error: String) -> Self { Self { success: false, @@ -45,1033 +51,52 @@ impl ApiResponse { } } -// ... (keep routes as they are) - // Dashboard routes pub fn router(state: AppState) -> Router { + let session_manager = SessionManager::new(24); // 24-hour session TTL let dashboard_state = DashboardState { app_state: state, + session_manager, }; - + Router::new() // Static file serving .fallback_service(tower_http::services::ServeDir::new("static")) - // WebSocket endpoint - .route("/ws", get(handle_websocket)) - + .route("/ws", get(websocket::handle_websocket)) // API endpoints - .route("/api/auth/login", post(handle_login)) - .route("/api/auth/status", get(handle_auth_status)) - .route("/api/auth/change-password", post(handle_change_password)) - .route("/api/usage/summary", get(handle_usage_summary)) - .route("/api/usage/time-series", get(handle_time_series)) - .route("/api/usage/clients", get(handle_clients_usage)) - .route("/api/usage/providers", get(handle_providers_usage)) - .route("/api/usage/detailed", get(handle_detailed_usage)) - .route("/api/analytics/breakdown", get(handle_analytics_breakdown)) - .route("/api/models", get(handle_get_models)) - .route("/api/models/{id}", put(handle_update_model)) - .route("/api/clients", get(handle_get_clients).post(handle_create_client)) - .route("/api/clients/{id}", get(handle_get_client).delete(handle_delete_client)) - .route("/api/clients/{id}/usage", get(handle_client_usage)) - .route("/api/providers", get(handle_get_providers)) - .route("/api/providers/{name}", get(handle_get_provider).put(handle_update_provider)) - .route("/api/providers/{name}/test", post(handle_test_provider)) - .route("/api/system/health", get(handle_system_health)) - .route("/api/system/logs", get(handle_system_logs)) - .route("/api/system/backup", post(handle_system_backup)) - .route("/api/system/settings", get(handle_get_settings).post(handle_update_settings)) - + .route("/api/auth/login", post(auth::handle_login)) + .route("/api/auth/status", get(auth::handle_auth_status)) + .route("/api/auth/change-password", post(auth::handle_change_password)) + .route("/api/usage/summary", get(usage::handle_usage_summary)) + .route("/api/usage/time-series", get(usage::handle_time_series)) + .route("/api/usage/clients", get(usage::handle_clients_usage)) + .route("/api/usage/providers", get(usage::handle_providers_usage)) + .route("/api/usage/detailed", get(usage::handle_detailed_usage)) + .route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown)) + .route("/api/models", get(models::handle_get_models)) + .route("/api/models/{id}", put(models::handle_update_model)) + .route( + "/api/clients", + get(clients::handle_get_clients).post(clients::handle_create_client), + ) + .route( + "/api/clients/{id}", + get(clients::handle_get_client).delete(clients::handle_delete_client), + ) + .route("/api/clients/{id}/usage", get(clients::handle_client_usage)) + .route("/api/providers", get(providers::handle_get_providers)) + .route( + "/api/providers/{name}", + get(providers::handle_get_provider).put(providers::handle_update_provider), + ) + .route("/api/providers/{name}/test", post(providers::handle_test_provider)) + .route("/api/system/health", get(system::handle_system_health)) + .route("/api/system/logs", get(system::handle_system_logs)) + .route("/api/system/backup", post(system::handle_system_backup)) + .route( + "/api/system/settings", + get(system::handle_get_settings).post(system::handle_update_settings), + ) .with_state(dashboard_state) } - -// WebSocket handler -async fn handle_websocket( - ws: WebSocketUpgrade, - State(state): State, -) -> impl IntoResponse { - ws.on_upgrade(|socket| handle_websocket_connection(socket, state)) -} - -async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) { - info!("WebSocket connection established"); - - // Subscribe to events from the global bus - let mut rx = state.app_state.dashboard_tx.subscribe(); - - // Send initial connection message - let _ = socket.send(Message::Text( - serde_json::json!({ - "type": "connected", - "message": "Connected to LLM Proxy Dashboard" - }).to_string().into(), - )).await; - - // Handle incoming messages and broadcast events - loop { - tokio::select! { - // Receive broadcast events - Ok(event) = rx.recv() => { - let message = Message::Text(serde_json::to_string(&event).unwrap().into()); - if socket.send(message).await.is_err() { - break; - } - } - - // Receive WebSocket messages - result = socket.recv() => { - match result { - Some(Ok(Message::Text(text))) => { - handle_websocket_message(&text, &state).await; - } - _ => break, - } - } - } - } - - info!("WebSocket connection closed"); -} - -async fn handle_websocket_message(text: &str, state: &DashboardState) { - // Parse and handle WebSocket messages - if let Ok(data) = serde_json::from_str::(text) { - if let Some("ping") = data.get("type").and_then(|v| v.as_str()) { - let _ = state.app_state.dashboard_tx.send(serde_json::json!({ - "type": "pong", - "payload": {} - })); - } - } -} - -// Authentication handlers -#[derive(Deserialize)] -struct LoginRequest { - username: String, - password: String, -} - -async fn handle_login( - State(state): State, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - let user_result = sqlx::query("SELECT username, password_hash, role FROM users WHERE username = ?") - .bind(&payload.username) - .fetch_optional(pool) - .await; - - match user_result { - Ok(Some(row)) => { - let hash = row.get::("password_hash"); - if bcrypt::verify(&payload.password, &hash).unwrap_or(false) { - Json(ApiResponse::success(serde_json::json!({ - "token": format!("session-{}", uuid::Uuid::new_v4()), - "user": { - "username": row.get::("username"), - "name": "Administrator", - "role": row.get::("role") - } - }))) - } else { - Json(ApiResponse::error("Invalid username or password".to_string())) - } - } - Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())), - Err(e) => { - warn!("Database error during login: {}", e); - Json(ApiResponse::error("Login failed due to system error".to_string())) - } - } -} - -async fn handle_auth_status(State(_state): State) -> Json> { - Json(ApiResponse::success(serde_json::json!({ - "authenticated": true, - "user": { - "username": "admin", - "name": "Administrator", - "role": "Super Admin" - } - }))) -} - -#[derive(Deserialize)] -struct ChangePasswordRequest { - current_password: String, - new_password: String, -} - -async fn handle_change_password( - State(state): State, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - // For now, always change 'admin' user - let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'") - .fetch_one(pool) - .await; - - match user_result { - Ok(row) => { - let hash = row.get::("password_hash"); - if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) { - let new_hash = match bcrypt::hash(&payload.new_password, 12) { - Ok(h) => h, - Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())), - }; - - let update_result = sqlx::query("UPDATE users SET password_hash = ? WHERE username = 'admin'") - .bind(new_hash) - .execute(pool) - .await; - - match update_result { - Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Password updated successfully" }))), - Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))), - } - } else { - Json(ApiResponse::error("Current password incorrect".to_string())) - } - } - Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))), - } -} - -// Usage handlers -async fn handle_usage_summary(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - // Total stats - let total_stats = sqlx::query( - r#" - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(total_tokens), 0) as total_tokens, - COALESCE(SUM(cost), 0.0) as total_cost, - COUNT(DISTINCT client_id) as active_clients - FROM llm_requests - "# - ) - .fetch_one(pool); - - // Today's stats - let today = chrono::Utc::now().format("%Y-%m-%d").to_string(); - let today_stats = sqlx::query( - r#" - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(total_tokens), 0) as today_tokens, - COALESCE(SUM(cost), 0.0) as today_cost - FROM llm_requests - WHERE strftime('%Y-%m-%d', timestamp) = ? - "# - ) - .bind(today) - .fetch_one(pool); - - // Error stats - let error_stats = sqlx::query( - r#" - SELECT - COUNT(*) as total, - SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors - FROM llm_requests - "# - ) - .fetch_one(pool); - - // Average response time - let avg_response = sqlx::query( - r#" - SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration - FROM llm_requests - WHERE status = 'success' - "# - ) - .fetch_one(pool); - - match tokio::join!(total_stats, today_stats, error_stats, avg_response) { - (Ok(t), Ok(d), Ok(e), Ok(a)) => { - let total_requests: i64 = t.get("total_requests"); - let total_tokens: i64 = t.get("total_tokens"); - let total_cost: f64 = t.get("total_cost"); - let active_clients: i64 = t.get("active_clients"); - - let today_requests: i64 = d.get("today_requests"); - let today_cost: f64 = d.get("today_cost"); - - let total_count: i64 = e.get("total"); - let error_count: i64 = e.get("errors"); - let error_rate = if total_count > 0 { - (error_count as f64 / total_count as f64) * 100.0 - } else { - 0.0 - }; - - let avg_response_time: f64 = a.get("avg_duration"); - - Json(ApiResponse::success(serde_json::json!({ - "total_requests": total_requests, - "total_tokens": total_tokens, - "total_cost": total_cost, - "active_clients": active_clients, - "today_requests": today_requests, - "today_cost": today_cost, - "error_rate": error_rate, - "avg_response_time": avg_response_time, - }))) - } - _ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())) - } -} - -async fn handle_time_series(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - let now = chrono::Utc::now(); - let twenty_four_hours_ago = now - chrono::Duration::hours(24); - - let result = sqlx::query( - r#" - SELECT - strftime('%H:00', timestamp) as hour, - COUNT(*) as requests, - SUM(total_tokens) as tokens, - SUM(cost) as cost - FROM llm_requests - WHERE timestamp >= ? - GROUP BY hour - ORDER BY hour - "# - ) - .bind(twenty_four_hours_ago) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let mut series = Vec::new(); - - for row in rows { - let hour: String = row.get("hour"); - let requests: i64 = row.get("requests"); - let tokens: i64 = row.get("tokens"); - let cost: f64 = row.get("cost"); - - series.push(serde_json::json!({ - "time": hour, - "requests": requests, - "tokens": tokens, - "cost": cost, - })); - } - - Json(ApiResponse::success(serde_json::json!({ - "series": series, - "period": "24h" - }))) - } - Err(e) => { - warn!("Failed to fetch time series data: {}", e); - Json(ApiResponse::error("Failed to fetch time series data".to_string())) - } - } -} - -async fn handle_clients_usage(State(state): State) -> Json> { - // Query database for client usage statistics - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - client_id, - COUNT(*) as requests, - SUM(total_tokens) as tokens, - SUM(cost) as cost, - MAX(timestamp) as last_request - FROM llm_requests - GROUP BY client_id - ORDER BY requests DESC - "# - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let mut client_usage = Vec::new(); - - for row in rows { - let client_id: String = row.get("client_id"); - let requests: i64 = row.get("requests"); - let tokens: i64 = row.get("tokens"); - let cost: f64 = row.get("cost"); - let last_request: Option> = row.get("last_request"); - - client_usage.push(serde_json::json!({ - "client_id": client_id, - "client_name": client_id, - "requests": requests, - "tokens": tokens, - "cost": cost, - "last_request": last_request, - })); - } - - Json(ApiResponse::success(serde_json::json!(client_usage))) - } - Err(e) => { - warn!("Failed to fetch client usage data: {}", e); - Json(ApiResponse::error("Failed to fetch client usage data".to_string())) - } - } -} - -async fn handle_providers_usage(State(state): State) -> Json> { - // Query database for provider usage statistics - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - provider, - COUNT(*) as requests, - COALESCE(SUM(total_tokens), 0) as tokens, - COALESCE(SUM(cost), 0.0) as cost - FROM llm_requests - GROUP BY provider - ORDER BY requests DESC - "# - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let mut provider_usage = Vec::new(); - - for row in rows { - let provider: String = row.get("provider"); - let requests: i64 = row.get("requests"); - let tokens: i64 = row.get("tokens"); - let cost: f64 = row.get("cost"); - - provider_usage.push(serde_json::json!({ - "provider": provider, - "requests": requests, - "tokens": tokens, - "cost": cost, - })); - } - - Json(ApiResponse::success(serde_json::json!(provider_usage))) - } - Err(e) => { - warn!("Failed to fetch provider usage data: {}", e); - Json(ApiResponse::error("Failed to fetch provider usage data".to_string())) - } - } -} - -async fn handle_detailed_usage(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - strftime('%Y-%m-%d', timestamp) as date, - client_id, - provider, - model, - COUNT(*) as requests, - SUM(total_tokens) as tokens, - SUM(cost) as cost - FROM llm_requests - GROUP BY date, client_id, provider, model - ORDER BY date DESC - LIMIT 100 - "# - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let usage: Vec = rows.into_iter().map(|row| { - serde_json::json!({ - "date": row.get::("date"), - "client": row.get::("client_id"), - "provider": row.get::("provider"), - "model": row.get::("model"), - "requests": row.get::("requests"), - "tokens": row.get::("tokens"), - "cost": row.get::("cost"), - }) - }).collect(); - - Json(ApiResponse::success(serde_json::json!(usage))) - } - Err(e) => { - warn!("Failed to fetch detailed usage: {}", e); - Json(ApiResponse::error("Failed to fetch detailed usage".to_string())) - } - } -} - -async fn handle_analytics_breakdown(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - // Model breakdown - let models = sqlx::query( - "SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC" - ).fetch_all(pool); - - // Client breakdown - let clients = sqlx::query( - "SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC" - ).fetch_all(pool); - - match tokio::join!(models, clients) { - (Ok(m_rows), Ok(c_rows)) => { - let model_breakdown: Vec = m_rows.into_iter().map(|r| { - serde_json::json!({ "label": r.get::("label"), "value": r.get::("value") }) - }).collect(); - - let client_breakdown: Vec = c_rows.into_iter().map(|r| { - serde_json::json!({ "label": r.get::("label"), "value": r.get::("value") }) - }).collect(); - - Json(ApiResponse::success(serde_json::json!({ - "models": model_breakdown, - "clients": client_breakdown - }))) - } - _ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())) - } -} - -// Client handlers -async fn handle_get_clients(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - client_id as id, - name, - created_at, - total_requests, - total_tokens, - total_cost, - is_active - FROM clients - ORDER BY created_at DESC - "# - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let clients: Vec = rows.into_iter().map(|row| { - serde_json::json!({ - "id": row.get::("id"), - "name": row.get::, _>("name").unwrap_or_else(|| "Unnamed".to_string()), - "created_at": row.get::, _>("created_at"), - "requests_count": row.get::("total_requests"), - "total_tokens": row.get::("total_tokens"), - "total_cost": row.get::("total_cost"), - "status": if row.get::("is_active") { "active" } else { "inactive" }, - }) - }).collect(); - - Json(ApiResponse::success(serde_json::json!(clients))) - } - Err(e) => { - warn!("Failed to fetch clients: {}", e); - Json(ApiResponse::error("Failed to fetch clients".to_string())) - } - } -} - -#[derive(Deserialize)] -struct CreateClientRequest { - name: String, - client_id: Option, -} - -async fn handle_create_client( - State(state): State, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - let client_id = payload.client_id.unwrap_or_else(|| { - format!("client-{}", uuid::Uuid::new_v4().to_string()[..8].to_string()) - }); - - let result = sqlx::query( - r#" - INSERT INTO clients (client_id, name, is_active) - VALUES (?, ?, TRUE) - RETURNING * - "# - ) - .bind(&client_id) - .bind(&payload.name) - .fetch_one(pool) - .await; - - match result { - Ok(row) => { - Json(ApiResponse::success(serde_json::json!({ - "id": row.get::("client_id"), - "name": row.get::, _>("name"), - "created_at": row.get::, _>("created_at"), - "status": "active", - }))) - } - Err(e) => { - warn!("Failed to create client: {}", e); - Json(ApiResponse::error(format!("Failed to create client: {}", e))) - } - } -} - -async fn handle_get_client( - State(_state): State, - axum::extract::Path(_id): axum::extract::Path, -) -> Json> { - Json(ApiResponse::error("Not implemented".to_string())) -} - -async fn handle_delete_client( - State(state): State, - axum::extract::Path(id): axum::extract::Path, -) -> Json> { - let pool = &state.app_state.db_pool; - - // Don't allow deleting the default client - if id == "default" { - return Json(ApiResponse::error("Cannot delete default client".to_string())); - } - - let result = sqlx::query("DELETE FROM clients WHERE client_id = ?") - .bind(id) - .execute(pool) - .await; - - match result { - Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))), - Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))), - } -} - -async fn handle_client_usage( - State(_state): State, - axum::extract::Path(_id): axum::extract::Path, -) -> Json> { - Json(ApiResponse::error("Not implemented".to_string())) -} - -// Provider handlers -async fn handle_get_providers(State(state): State) -> Json> { - let registry = &state.app_state.model_registry; - let config = &state.app_state.config; - let pool = &state.app_state.db_pool; - - // Load all overrides from database - let db_configs_result = sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs") - .fetch_all(pool) - .await; - - let mut db_configs = HashMap::new(); - if let Ok(rows) = db_configs_result { - for row in rows { - let id: String = row.get("id"); - let enabled: bool = row.get("enabled"); - let base_url: Option = row.get("base_url"); - let balance: f64 = row.get("credit_balance"); - let threshold: f64 = row.get("low_credit_threshold"); - db_configs.insert(id, (enabled, base_url, balance, threshold)); - } - } - - let mut providers_json = Vec::new(); - - // Define the list of providers we support - let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; - - for id in provider_ids { - // Get base config - let (mut enabled, mut base_url, display_name) = match id { - "openai" => (config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI"), - "gemini" => (config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini"), - "deepseek" => (config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek"), - "grok" => (config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok"), - "ollama" => (config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama"), - _ => (false, "".to_string(), "Unknown"), - }; - - let mut balance = 0.0; - let mut threshold = 5.0; - - // Apply database overrides - if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) { - enabled = *db_enabled; - if let Some(url) = db_url { - base_url = url.clone(); - } - balance = *db_balance; - threshold = *db_threshold; - } - - // Find models for this provider in registry - let mut models = Vec::new(); - if let Some(p_info) = registry.providers.get(id) { - models = p_info.models.keys().cloned().collect(); - } else if id == "ollama" { - models = config.providers.ollama.models.clone(); - } - - // Determine status - let status = if !enabled { - "disabled" - } else { - // Check if it's actually initialized in the provider manager - if state.app_state.provider_manager.get_provider(id).await.is_some() { - // Check circuit breaker - if state.app_state.rate_limit_manager.check_provider_request(id).await.unwrap_or(true) { - "online" - } else { - "degraded" - } - } else { - "error" // Enabled but failed to initialize (e.g. missing API key) - } - }; - - providers_json.push(serde_json::json!({ - "id": id, - "name": display_name, - "enabled": enabled, - "status": status, - "models": models, - "base_url": base_url, - "credit_balance": balance, - "low_credit_threshold": threshold, - "last_used": None::, - })); - } - - Json(ApiResponse::success(serde_json::json!(providers_json))) -} - -async fn handle_get_provider( - State(_state): State, - axum::extract::Path(_name): axum::extract::Path, -) -> Json> { - Json(ApiResponse::error("Not implemented".to_string())) -} - -#[derive(Deserialize)] -struct UpdateProviderRequest { - enabled: bool, - base_url: Option, - api_key: Option, - credit_balance: Option, - low_credit_threshold: Option, -} - -async fn handle_update_provider( - State(state): State, - axum::extract::Path(name): axum::extract::Path, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - // Update or insert into database - let result = sqlx::query( - r#" - INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold) - VALUES (?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - enabled = excluded.enabled, - base_url = excluded.base_url, - api_key = COALESCE(excluded.api_key, provider_configs.api_key), - credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), - low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), - updated_at = CURRENT_TIMESTAMP - "# - ) - .bind(&name) - .bind(name.to_uppercase()) - .bind(payload.enabled) - .bind(&payload.base_url) - .bind(&payload.api_key) - .bind(payload.credit_balance) - .bind(payload.low_credit_threshold) - .execute(pool) - .await; - - match result { - Ok(_) => { - // Re-initialize provider in manager - if let Err(e) = state.app_state.provider_manager.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool).await { - warn!("Failed to re-initialize provider {}: {}", name, e); - return Json(ApiResponse::error(format!("Provider settings saved but initialization failed: {}", e))); - } - - Json(ApiResponse::success(serde_json::json!({ "message": "Provider updated and re-initialized" }))) - } - Err(e) => { - warn!("Failed to update provider config: {}", e); - Json(ApiResponse::error(format!("Failed to update provider: {}", e))) - } - } -} - -async fn handle_test_provider( - State(_state): State, - axum::extract::Path(_name): axum::extract::Path, -) -> Json> { - Json(ApiResponse::success(serde_json::json!({ - "success": true, - "latency": rand::random::() % 500 + 100, - "message": "Connection test successful" - }))) -} - -// Model handlers -async fn handle_get_models(State(state): State) -> Json> { - let registry = &state.app_state.model_registry; - let pool = &state.app_state.db_pool; - - // Load overrides from database - let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") - .fetch_all(pool) - .await; - - let mut db_models = HashMap::new(); - if let Ok(rows) = db_models_result { - for row in rows { - let id: String = row.get("id"); - db_models.insert(id, row); - } - } - - let mut models_json = Vec::new(); - - for (p_id, p_info) in ®istry.providers { - for (m_id, m_meta) in &p_info.models { - let mut enabled = true; - let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); - let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); - let mut mapping = None::; - - if let Some(row) = db_models.get(m_id) { - enabled = row.get("enabled"); - if let Some(p) = row.get::, _>("prompt_cost_per_m") { prompt_cost = p; } - if let Some(c) = row.get::, _>("completion_cost_per_m") { completion_cost = c; } - mapping = row.get("mapping"); - } - - models_json.push(serde_json::json!({ - "id": m_id, - "provider": p_id, - "name": m_meta.name, - "enabled": enabled, - "prompt_cost": prompt_cost, - "completion_cost": completion_cost, - "mapping": mapping, - "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), - })); - } - } - - Json(ApiResponse::success(serde_json::json!(models_json))) -} - -#[derive(Deserialize)] -struct UpdateModelRequest { - enabled: bool, - prompt_cost: Option, - completion_cost: Option, - mapping: Option, -} - -async fn handle_update_model( - State(state): State, - axum::extract::Path(id): axum::extract::Path, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - // Find provider_id for this model in registry - let provider_id = state.app_state.model_registry.providers.iter() - .find(|(_, p)| p.models.contains_key(&id)) - .map(|(id, _)| id.clone()) - .unwrap_or_else(|| "unknown".to_string()); - - let result = sqlx::query( - r#" - INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - enabled = excluded.enabled, - prompt_cost_per_m = excluded.prompt_cost_per_m, - completion_cost_per_m = excluded.completion_cost_per_m, - mapping = excluded.mapping, - updated_at = CURRENT_TIMESTAMP - "# - ) - .bind(&id) - .bind(provider_id) - .bind(payload.enabled) - .bind(payload.prompt_cost) - .bind(payload.completion_cost) - .bind(payload.mapping) - .execute(pool) - .await; - - match result { - Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))), - Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))), - } -} - -// System handlers -async fn handle_system_health(State(state): State) -> Json> { - let mut components = HashMap::new(); - components.insert("api_server".to_string(), "online".to_string()); - components.insert("database".to_string(), "online".to_string()); - - // Check provider health via circuit breakers - let provider_ids: Vec = state.app_state.provider_manager.get_all_providers().await - .iter() - .map(|p| p.name().to_string()) - .collect(); - - for p_id in provider_ids { - if state.app_state.rate_limit_manager.check_provider_request(&p_id).await.unwrap_or(true) { - components.insert(p_id, "online".to_string()); - } else { - components.insert(p_id, "degraded".to_string()); - } - } - - Json(ApiResponse::success(serde_json::json!({ - "status": "healthy", - "timestamp": chrono::Utc::now().to_rfc3339(), - "components": components, - "metrics": { - "cpu_usage": rand::random::() * 5.0 + 1.0, - "memory_usage": rand::random::() * 10.0 + 20.0, - "active_connections": rand::random::() % 10 + 1, - } - }))) -} - -async fn handle_system_logs(State(state): State) -> Json> { - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - id, - timestamp, - client_id, - provider, - model, - prompt_tokens, - completion_tokens, - total_tokens, - cost, - status, - error_message, - duration_ms - FROM llm_requests - ORDER BY timestamp DESC - LIMIT 100 - "# - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let logs: Vec = rows.into_iter().map(|row| { - serde_json::json!({ - "id": row.get::("id"), - "timestamp": row.get::, _>("timestamp"), - "client_id": row.get::("client_id"), - "provider": row.get::("provider"), - "model": row.get::("model"), - "tokens": row.get::("total_tokens"), - "cost": row.get::("cost"), - "status": row.get::("status"), - "error": row.get::, _>("error_message"), - "duration": row.get::("duration_ms"), - }) - }).collect(); - - Json(ApiResponse::success(serde_json::json!(logs))) - } - Err(e) => { - warn!("Failed to fetch system logs: {}", e); - Json(ApiResponse::error("Failed to fetch system logs".to_string())) - } - } -} - -async fn handle_system_backup(State(_state): State) -> Json> { - Json(ApiResponse::success(serde_json::json!({ - "success": true, - "message": "Backup initiated", - "backup_id": format!("backup-{}", chrono::Utc::now().timestamp()), - }))) -} - -async fn handle_get_settings(State(state): State) -> Json> { - let registry = &state.app_state.model_registry; - let provider_count = registry.providers.len(); - let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum(); - - Json(ApiResponse::success(serde_json::json!({ - "server": { - "auth_tokens": state.app_state.auth_tokens, - "version": env!("CARGO_PKG_VERSION"), - }, - "registry": { - "provider_count": provider_count, - "model_count": model_count, - }, - "database": { - "type": "SQLite", - } - }))) -} - -async fn handle_update_settings(State(_state): State) -> Json> { - Json(ApiResponse::error("Changing settings at runtime is not yet supported. Please update your config file and restart the server.".to_string())) -} - -// Helper functions -#[allow(dead_code)] -fn mask_token(token: &str) -> String { - if token.len() <= 8 { - return "*****".to_string(); - } - - let masked_len = token.len().min(12); - let visible_len = 4; - let mask_len = masked_len - visible_len; - - format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..]) -} \ No newline at end of file diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs new file mode 100644 index 00000000..1c326392 --- /dev/null +++ b/src/dashboard/models.rs @@ -0,0 +1,116 @@ +use axum::{ + extract::{Path, State}, + response::Json, +}; +use serde::Deserialize; +use serde_json; +use sqlx::Row; +use std::collections::HashMap; + +use super::{ApiResponse, DashboardState}; + +#[derive(Deserialize)] +pub(super) struct UpdateModelRequest { + pub(super) enabled: bool, + pub(super) prompt_cost: Option, + pub(super) completion_cost: Option, + pub(super) mapping: Option, +} + +pub(super) async fn handle_get_models(State(state): State) -> Json> { + let registry = &state.app_state.model_registry; + let pool = &state.app_state.db_pool; + + // Load overrides from database + let db_models_result = + sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") + .fetch_all(pool) + .await; + + let mut db_models = HashMap::new(); + if let Ok(rows) = db_models_result { + for row in rows { + let id: String = row.get("id"); + db_models.insert(id, row); + } + } + + let mut models_json = Vec::new(); + + for (p_id, p_info) in ®istry.providers { + for (m_id, m_meta) in &p_info.models { + let mut enabled = true; + let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); + let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); + let mut mapping = None::; + + if let Some(row) = db_models.get(m_id) { + enabled = row.get("enabled"); + if let Some(p) = row.get::, _>("prompt_cost_per_m") { + prompt_cost = p; + } + if let Some(c) = row.get::, _>("completion_cost_per_m") { + completion_cost = c; + } + mapping = row.get("mapping"); + } + + models_json.push(serde_json::json!({ + "id": m_id, + "provider": p_id, + "name": m_meta.name, + "enabled": enabled, + "prompt_cost": prompt_cost, + "completion_cost": completion_cost, + "mapping": mapping, + "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), + })); + } + } + + Json(ApiResponse::success(serde_json::json!(models_json))) +} + +pub(super) async fn handle_update_model( + State(state): State, + Path(id): Path, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Find provider_id for this model in registry + let provider_id = state + .app_state + .model_registry + .providers + .iter() + .find(|(_, p)| p.models.contains_key(&id)) + .map(|(id, _)| id.clone()) + .unwrap_or_else(|| "unknown".to_string()); + + let result = sqlx::query( + r#" + INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + prompt_cost_per_m = excluded.prompt_cost_per_m, + completion_cost_per_m = excluded.completion_cost_per_m, + mapping = excluded.mapping, + updated_at = CURRENT_TIMESTAMP + "#, + ) + .bind(&id) + .bind(provider_id) + .bind(payload.enabled) + .bind(payload.prompt_cost) + .bind(payload.completion_cost) + .bind(payload.mapping) + .execute(pool) + .await; + + match result { + Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))), + Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))), + } +} diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs new file mode 100644 index 00000000..7bee91ee --- /dev/null +++ b/src/dashboard/providers.rs @@ -0,0 +1,346 @@ +use axum::{ + extract::{Path, State}, + response::Json, +}; +use serde::Deserialize; +use serde_json; +use sqlx::Row; +use std::collections::HashMap; +use tracing::warn; + +use super::{ApiResponse, DashboardState}; + +#[derive(Deserialize)] +pub(super) struct UpdateProviderRequest { + pub(super) enabled: bool, + pub(super) base_url: Option, + pub(super) api_key: Option, + pub(super) credit_balance: Option, + pub(super) low_credit_threshold: Option, +} + +pub(super) async fn handle_get_providers(State(state): State) -> Json> { + let registry = &state.app_state.model_registry; + let config = &state.app_state.config; + let pool = &state.app_state.db_pool; + + // Load all overrides from database + let db_configs_result = + sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs") + .fetch_all(pool) + .await; + + let mut db_configs = HashMap::new(); + if let Ok(rows) = db_configs_result { + for row in rows { + let id: String = row.get("id"); + let enabled: bool = row.get("enabled"); + let base_url: Option = row.get("base_url"); + let balance: f64 = row.get("credit_balance"); + let threshold: f64 = row.get("low_credit_threshold"); + db_configs.insert(id, (enabled, base_url, balance, threshold)); + } + } + + let mut providers_json = Vec::new(); + + // Define the list of providers we support + let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; + + for id in provider_ids { + // Get base config + let (mut enabled, mut base_url, display_name) = match id { + "openai" => ( + config.providers.openai.enabled, + config.providers.openai.base_url.clone(), + "OpenAI", + ), + "gemini" => ( + config.providers.gemini.enabled, + config.providers.gemini.base_url.clone(), + "Google Gemini", + ), + "deepseek" => ( + config.providers.deepseek.enabled, + config.providers.deepseek.base_url.clone(), + "DeepSeek", + ), + "grok" => ( + config.providers.grok.enabled, + config.providers.grok.base_url.clone(), + "xAI Grok", + ), + "ollama" => ( + config.providers.ollama.enabled, + config.providers.ollama.base_url.clone(), + "Ollama", + ), + _ => (false, "".to_string(), "Unknown"), + }; + + let mut balance = 0.0; + let mut threshold = 5.0; + + // Apply database overrides + if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) { + enabled = *db_enabled; + if let Some(url) = db_url { + base_url = url.clone(); + } + balance = *db_balance; + threshold = *db_threshold; + } + + // Find models for this provider in registry + let mut models = Vec::new(); + if let Some(p_info) = registry.providers.get(id) { + models = p_info.models.keys().cloned().collect(); + } else if id == "ollama" { + models = config.providers.ollama.models.clone(); + } + + // Determine status + let status = if !enabled { + "disabled" + } else { + // Check if it's actually initialized in the provider manager + if state.app_state.provider_manager.get_provider(id).await.is_some() { + // Check circuit breaker + if state + .app_state + .rate_limit_manager + .check_provider_request(id) + .await + .unwrap_or(true) + { + "online" + } else { + "degraded" + } + } else { + "error" // Enabled but failed to initialize (e.g. missing API key) + } + }; + + providers_json.push(serde_json::json!({ + "id": id, + "name": display_name, + "enabled": enabled, + "status": status, + "models": models, + "base_url": base_url, + "credit_balance": balance, + "low_credit_threshold": threshold, + "last_used": None::, + })); + } + + Json(ApiResponse::success(serde_json::json!(providers_json))) +} + +pub(super) async fn handle_get_provider( + State(state): State, + Path(name): Path, +) -> Json> { + let registry = &state.app_state.model_registry; + let config = &state.app_state.config; + let pool = &state.app_state.db_pool; + + // Validate provider name + let (mut enabled, mut base_url, display_name) = match name.as_str() { + "openai" => ( + config.providers.openai.enabled, + config.providers.openai.base_url.clone(), + "OpenAI", + ), + "gemini" => ( + config.providers.gemini.enabled, + config.providers.gemini.base_url.clone(), + "Google Gemini", + ), + "deepseek" => ( + config.providers.deepseek.enabled, + config.providers.deepseek.base_url.clone(), + "DeepSeek", + ), + "grok" => ( + config.providers.grok.enabled, + config.providers.grok.base_url.clone(), + "xAI Grok", + ), + "ollama" => ( + config.providers.ollama.enabled, + config.providers.ollama.base_url.clone(), + "Ollama", + ), + _ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))), + }; + + let mut balance = 0.0; + let mut threshold = 5.0; + + // Apply database overrides + let db_config = sqlx::query( + "SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?", + ) + .bind(&name) + .fetch_optional(pool) + .await; + + if let Ok(Some(row)) = db_config { + enabled = row.get::("enabled"); + if let Some(url) = row.get::, _>("base_url") { + base_url = url; + } + balance = row.get::("credit_balance"); + threshold = row.get::("low_credit_threshold"); + } + + // Find models for this provider + let mut models = Vec::new(); + if let Some(p_info) = registry.providers.get(name.as_str()) { + models = p_info.models.keys().cloned().collect(); + } else if name == "ollama" { + models = config.providers.ollama.models.clone(); + } + + // Determine status + let status = if !enabled { + "disabled" + } else if state.app_state.provider_manager.get_provider(&name).await.is_some() { + if state + .app_state + .rate_limit_manager + .check_provider_request(&name) + .await + .unwrap_or(true) + { + "online" + } else { + "degraded" + } + } else { + "error" + }; + + Json(ApiResponse::success(serde_json::json!({ + "id": name, + "name": display_name, + "enabled": enabled, + "status": status, + "models": models, + "base_url": base_url, + "credit_balance": balance, + "low_credit_threshold": threshold, + "last_used": None::, + }))) +} + +pub(super) async fn handle_update_provider( + State(state): State, + Path(name): Path, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Update or insert into database + let result = sqlx::query( + r#" + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + base_url = excluded.base_url, + api_key = COALESCE(excluded.api_key, provider_configs.api_key), + credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), + low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), + updated_at = CURRENT_TIMESTAMP + "# + ) + .bind(&name) + .bind(name.to_uppercase()) + .bind(payload.enabled) + .bind(&payload.base_url) + .bind(&payload.api_key) + .bind(payload.credit_balance) + .bind(payload.low_credit_threshold) + .execute(pool) + .await; + + match result { + Ok(_) => { + // Re-initialize provider in manager + if let Err(e) = state + .app_state + .provider_manager + .initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool) + .await + { + warn!("Failed to re-initialize provider {}: {}", name, e); + return Json(ApiResponse::error(format!( + "Provider settings saved but initialization failed: {}", + e + ))); + } + + Json(ApiResponse::success( + serde_json::json!({ "message": "Provider updated and re-initialized" }), + )) + } + Err(e) => { + warn!("Failed to update provider config: {}", e); + Json(ApiResponse::error(format!("Failed to update provider: {}", e))) + } + } +} + +pub(super) async fn handle_test_provider( + State(state): State, + Path(name): Path, +) -> Json> { + let start = std::time::Instant::now(); + + let provider = match state.app_state.provider_manager.get_provider(&name).await { + Some(p) => p, + None => { + return Json(ApiResponse::error(format!( + "Provider '{}' not found or not enabled", + name + ))); + } + }; + + // Pick a real model for this provider from the registry + let test_model = state + .app_state + .model_registry + .providers + .get(&name) + .and_then(|p| p.models.keys().next().cloned()) + .unwrap_or_else(|| name.clone()); + + let test_request = crate::models::UnifiedRequest { + client_id: "system-test".to_string(), + model: test_model, + messages: vec![crate::models::UnifiedMessage { + role: "user".to_string(), + content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }], + }], + temperature: None, + max_tokens: Some(5), + stream: false, + has_images: false, + }; + + match provider.chat_completion(test_request).await { + Ok(_) => { + let latency = start.elapsed().as_millis(); + Json(ApiResponse::success(serde_json::json!({ + "success": true, + "latency": latency, + "message": "Connection test successful" + }))) + } + Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))), + } +} diff --git a/src/dashboard/sessions.rs b/src/dashboard/sessions.rs new file mode 100644 index 00000000..5fdc3cc2 --- /dev/null +++ b/src/dashboard/sessions.rs @@ -0,0 +1,64 @@ +use chrono::{DateTime, Duration, Utc}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[derive(Clone, Debug)] +pub struct Session { + pub username: String, + pub role: String, + pub created_at: DateTime, + pub expires_at: DateTime, +} + +#[derive(Clone)] +pub struct SessionManager { + sessions: Arc>>, + ttl_hours: i64, +} + +impl SessionManager { + pub fn new(ttl_hours: i64) -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + ttl_hours, + } + } + + /// Create a new session and return the session token. + pub async fn create_session(&self, username: String, role: String) -> String { + let token = format!("session-{}", uuid::Uuid::new_v4()); + let now = Utc::now(); + let session = Session { + username, + role, + created_at: now, + expires_at: now + Duration::hours(self.ttl_hours), + }; + self.sessions.write().await.insert(token.clone(), session); + token + } + + /// Validate a session token and return the session if valid and not expired. + pub async fn validate_session(&self, token: &str) -> Option { + let sessions = self.sessions.read().await; + sessions.get(token).and_then(|s| { + if s.expires_at > Utc::now() { + Some(s.clone()) + } else { + None + } + }) + } + + /// Revoke (delete) a session by token. + pub async fn revoke_session(&self, token: &str) { + self.sessions.write().await.remove(token); + } + + /// Remove all expired sessions from the store. + pub async fn cleanup_expired(&self) { + let now = Utc::now(); + self.sessions.write().await.retain(|_, s| s.expires_at > now); + } +} diff --git a/src/dashboard/system.rs b/src/dashboard/system.rs new file mode 100644 index 00000000..416405c3 --- /dev/null +++ b/src/dashboard/system.rs @@ -0,0 +1,193 @@ +use axum::{extract::State, response::Json}; +use chrono; +use serde_json; +use sqlx::Row; +use std::collections::HashMap; +use tracing::warn; + +use super::{ApiResponse, DashboardState}; + +pub(super) async fn handle_system_health(State(state): State) -> Json> { + let mut components = HashMap::new(); + components.insert("api_server".to_string(), "online".to_string()); + components.insert("database".to_string(), "online".to_string()); + + // Check provider health via circuit breakers + let provider_ids: Vec = state + .app_state + .provider_manager + .get_all_providers() + .await + .iter() + .map(|p| p.name().to_string()) + .collect(); + + for p_id in provider_ids { + if state + .app_state + .rate_limit_manager + .check_provider_request(&p_id) + .await + .unwrap_or(true) + { + components.insert(p_id, "online".to_string()); + } else { + components.insert(p_id, "degraded".to_string()); + } + } + + // Read real memory usage from /proc/self/status + let memory_mb = std::fs::read_to_string("/proc/self/status") + .ok() + .and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string())) + .and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::().ok())) + .map(|kb| kb / 1024.0) + .unwrap_or(0.0); + + // Get real database pool stats + let db_pool_size = state.app_state.db_pool.size(); + let db_pool_idle = state.app_state.db_pool.num_idle(); + + Json(ApiResponse::success(serde_json::json!({ + "status": "healthy", + "timestamp": chrono::Utc::now().to_rfc3339(), + "components": components, + "metrics": { + "memory_usage_mb": (memory_mb * 10.0).round() / 10.0, + "db_connections_active": db_pool_size - db_pool_idle as u32, + "db_connections_idle": db_pool_idle, + } + }))) +} + +pub(super) async fn handle_system_logs(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + id, + timestamp, + client_id, + provider, + model, + prompt_tokens, + completion_tokens, + total_tokens, + cost, + status, + error_message, + duration_ms + FROM llm_requests + ORDER BY timestamp DESC + LIMIT 100 + "#, + ) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let logs: Vec = rows + .into_iter() + .map(|row| { + serde_json::json!({ + "id": row.get::("id"), + "timestamp": row.get::, _>("timestamp"), + "client_id": row.get::("client_id"), + "provider": row.get::("provider"), + "model": row.get::("model"), + "tokens": row.get::("total_tokens"), + "cost": row.get::("cost"), + "status": row.get::("status"), + "error": row.get::, _>("error_message"), + "duration": row.get::("duration_ms"), + }) + }) + .collect(); + + Json(ApiResponse::success(serde_json::json!(logs))) + } + Err(e) => { + warn!("Failed to fetch system logs: {}", e); + Json(ApiResponse::error("Failed to fetch system logs".to_string())) + } + } +} + +pub(super) async fn handle_system_backup(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + let backup_id = format!("backup-{}", chrono::Utc::now().timestamp()); + let backup_path = format!("data/{}.db", backup_id); + + // Ensure the data directory exists + if let Err(e) = std::fs::create_dir_all("data") { + return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e))); + } + + // Use SQLite VACUUM INTO for a consistent backup + let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path)) + .execute(pool) + .await; + + match result { + Ok(_) => { + // Get backup file size + let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0); + + Json(ApiResponse::success(serde_json::json!({ + "success": true, + "message": "Backup completed successfully", + "backup_id": backup_id, + "backup_path": backup_path, + "size_bytes": size_bytes, + }))) + } + Err(e) => { + warn!("Database backup failed: {}", e); + Json(ApiResponse::error(format!("Backup failed: {}", e))) + } + } +} + +pub(super) async fn handle_get_settings(State(state): State) -> Json> { + let registry = &state.app_state.model_registry; + let provider_count = registry.providers.len(); + let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum(); + + Json(ApiResponse::success(serde_json::json!({ + "server": { + "auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::>(), + "version": env!("CARGO_PKG_VERSION"), + }, + "registry": { + "provider_count": provider_count, + "model_count": model_count, + }, + "database": { + "type": "SQLite", + } + }))) +} + +pub(super) async fn handle_update_settings( + State(_state): State, +) -> Json> { + Json(ApiResponse::error( + "Changing settings at runtime is not yet supported. Please update your config file and restart the server." + .to_string(), + )) +} + +// Helper functions +fn mask_token(token: &str) -> String { + if token.len() <= 8 { + return "*****".to_string(); + } + + let masked_len = token.len().min(12); + let visible_len = 4; + let mask_len = masked_len - visible_len; + + format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..]) +} diff --git a/src/dashboard/usage.rs b/src/dashboard/usage.rs new file mode 100644 index 00000000..88183d8f --- /dev/null +++ b/src/dashboard/usage.rs @@ -0,0 +1,330 @@ +use axum::{extract::State, response::Json}; +use chrono; +use serde_json; +use sqlx::Row; +use tracing::warn; + +use super::{ApiResponse, DashboardState}; + +pub(super) async fn handle_usage_summary(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + + // Total stats + let total_stats = sqlx::query( + r#" + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(cost), 0.0) as total_cost, + COUNT(DISTINCT client_id) as active_clients + FROM llm_requests + "#, + ) + .fetch_one(pool); + + // Today's stats + let today = chrono::Utc::now().format("%Y-%m-%d").to_string(); + let today_stats = sqlx::query( + r#" + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(total_tokens), 0) as today_tokens, + COALESCE(SUM(cost), 0.0) as today_cost + FROM llm_requests + WHERE strftime('%Y-%m-%d', timestamp) = ? + "#, + ) + .bind(today) + .fetch_one(pool); + + // Error stats + let error_stats = sqlx::query( + r#" + SELECT + COUNT(*) as total, + SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors + FROM llm_requests + "#, + ) + .fetch_one(pool); + + // Average response time + let avg_response = sqlx::query( + r#" + SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration + FROM llm_requests + WHERE status = 'success' + "#, + ) + .fetch_one(pool); + + match tokio::join!(total_stats, today_stats, error_stats, avg_response) { + (Ok(t), Ok(d), Ok(e), Ok(a)) => { + let total_requests: i64 = t.get("total_requests"); + let total_tokens: i64 = t.get("total_tokens"); + let total_cost: f64 = t.get("total_cost"); + let active_clients: i64 = t.get("active_clients"); + + let today_requests: i64 = d.get("today_requests"); + let today_cost: f64 = d.get("today_cost"); + + let total_count: i64 = e.get("total"); + let error_count: i64 = e.get("errors"); + let error_rate = if total_count > 0 { + (error_count as f64 / total_count as f64) * 100.0 + } else { + 0.0 + }; + + let avg_response_time: f64 = a.get("avg_duration"); + + Json(ApiResponse::success(serde_json::json!({ + "total_requests": total_requests, + "total_tokens": total_tokens, + "total_cost": total_cost, + "active_clients": active_clients, + "today_requests": today_requests, + "today_cost": today_cost, + "error_rate": error_rate, + "avg_response_time": avg_response_time, + }))) + } + _ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())), + } +} + +pub(super) async fn handle_time_series(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + + let now = chrono::Utc::now(); + let twenty_four_hours_ago = now - chrono::Duration::hours(24); + + let result = sqlx::query( + r#" + SELECT + strftime('%H:00', timestamp) as hour, + COUNT(*) as requests, + SUM(total_tokens) as tokens, + SUM(cost) as cost + FROM llm_requests + WHERE timestamp >= ? + GROUP BY hour + ORDER BY hour + "#, + ) + .bind(twenty_four_hours_ago) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let mut series = Vec::new(); + + for row in rows { + let hour: String = row.get("hour"); + let requests: i64 = row.get("requests"); + let tokens: i64 = row.get("tokens"); + let cost: f64 = row.get("cost"); + + series.push(serde_json::json!({ + "time": hour, + "requests": requests, + "tokens": tokens, + "cost": cost, + })); + } + + Json(ApiResponse::success(serde_json::json!({ + "series": series, + "period": "24h" + }))) + } + Err(e) => { + warn!("Failed to fetch time series data: {}", e); + Json(ApiResponse::error("Failed to fetch time series data".to_string())) + } + } +} + +pub(super) async fn handle_clients_usage(State(state): State) -> Json> { + // Query database for client usage statistics + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + client_id, + COUNT(*) as requests, + SUM(total_tokens) as tokens, + SUM(cost) as cost, + MAX(timestamp) as last_request + FROM llm_requests + GROUP BY client_id + ORDER BY requests DESC + "#, + ) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let mut client_usage = Vec::new(); + + for row in rows { + let client_id: String = row.get("client_id"); + let requests: i64 = row.get("requests"); + let tokens: i64 = row.get("tokens"); + let cost: f64 = row.get("cost"); + let last_request: Option> = row.get("last_request"); + + client_usage.push(serde_json::json!({ + "client_id": client_id, + "client_name": client_id, + "requests": requests, + "tokens": tokens, + "cost": cost, + "last_request": last_request, + })); + } + + Json(ApiResponse::success(serde_json::json!(client_usage))) + } + Err(e) => { + warn!("Failed to fetch client usage data: {}", e); + Json(ApiResponse::error("Failed to fetch client usage data".to_string())) + } + } +} + +pub(super) async fn handle_providers_usage( + State(state): State, +) -> Json> { + // Query database for provider usage statistics + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + provider, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as tokens, + COALESCE(SUM(cost), 0.0) as cost + FROM llm_requests + GROUP BY provider + ORDER BY requests DESC + "#, + ) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let mut provider_usage = Vec::new(); + + for row in rows { + let provider: String = row.get("provider"); + let requests: i64 = row.get("requests"); + let tokens: i64 = row.get("tokens"); + let cost: f64 = row.get("cost"); + + provider_usage.push(serde_json::json!({ + "provider": provider, + "requests": requests, + "tokens": tokens, + "cost": cost, + })); + } + + Json(ApiResponse::success(serde_json::json!(provider_usage))) + } + Err(e) => { + warn!("Failed to fetch provider usage data: {}", e); + Json(ApiResponse::error("Failed to fetch provider usage data".to_string())) + } + } +} + +pub(super) async fn handle_detailed_usage(State(state): State) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT + strftime('%Y-%m-%d', timestamp) as date, + client_id, + provider, + model, + COUNT(*) as requests, + SUM(total_tokens) as tokens, + SUM(cost) as cost + FROM llm_requests + GROUP BY date, client_id, provider, model + ORDER BY date DESC + LIMIT 100 + "#, + ) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let usage: Vec = rows + .into_iter() + .map(|row| { + serde_json::json!({ + "date": row.get::("date"), + "client": row.get::("client_id"), + "provider": row.get::("provider"), + "model": row.get::("model"), + "requests": row.get::("requests"), + "tokens": row.get::("tokens"), + "cost": row.get::("cost"), + }) + }) + .collect(); + + Json(ApiResponse::success(serde_json::json!(usage))) + } + Err(e) => { + warn!("Failed to fetch detailed usage: {}", e); + Json(ApiResponse::error("Failed to fetch detailed usage".to_string())) + } + } +} + +pub(super) async fn handle_analytics_breakdown( + State(state): State, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Model breakdown + let models = + sqlx::query("SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC") + .fetch_all(pool); + + // Client breakdown + let clients = sqlx::query( + "SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC", + ) + .fetch_all(pool); + + match tokio::join!(models, clients) { + (Ok(m_rows), Ok(c_rows)) => { + let model_breakdown: Vec = m_rows + .into_iter() + .map(|r| serde_json::json!({ "label": r.get::("label"), "value": r.get::("value") })) + .collect(); + + let client_breakdown: Vec = c_rows + .into_iter() + .map(|r| serde_json::json!({ "label": r.get::("label"), "value": r.get::("value") })) + .collect(); + + Json(ApiResponse::success(serde_json::json!({ + "models": model_breakdown, + "clients": client_breakdown + }))) + } + _ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())), + } +} diff --git a/src/dashboard/websocket.rs b/src/dashboard/websocket.rs new file mode 100644 index 00000000..06e231c3 --- /dev/null +++ b/src/dashboard/websocket.rs @@ -0,0 +1,75 @@ +use axum::{ + extract::{ + State, + ws::{Message, WebSocket, WebSocketUpgrade}, + }, + response::IntoResponse, +}; +use serde_json; +use tracing::info; + +use super::DashboardState; + +// WebSocket handler +pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State) -> impl IntoResponse { + ws.on_upgrade(|socket| handle_websocket_connection(socket, state)) +} + +pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) { + info!("WebSocket connection established"); + + // Subscribe to events from the global bus + let mut rx = state.app_state.dashboard_tx.subscribe(); + + // Send initial connection message + let _ = socket + .send(Message::Text( + serde_json::json!({ + "type": "connected", + "message": "Connected to LLM Proxy Dashboard" + }) + .to_string() + .into(), + )) + .await; + + // Handle incoming messages and broadcast events + loop { + tokio::select! { + // Receive broadcast events + Ok(event) = rx.recv() => { + let Ok(json_str) = serde_json::to_string(&event) else { + continue; + }; + let message = Message::Text(json_str.into()); + if socket.send(message).await.is_err() { + break; + } + } + + // Receive WebSocket messages + result = socket.recv() => { + match result { + Some(Ok(Message::Text(text))) => { + handle_websocket_message(&text, &state).await; + } + _ => break, + } + } + } + } + + info!("WebSocket connection closed"); +} + +pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) { + // Parse and handle WebSocket messages + if let Ok(data) = serde_json::from_str::(text) + && let Some("ping") = data.get("type").and_then(|v| v.as_str()) + { + let _ = state.app_state.dashboard_tx.send(serde_json::json!({ + "type": "pong", + "payload": {} + })); + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 4a0adf70..c352abfb 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use sqlx::sqlite::{SqlitePool, SqliteConnectOptions}; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePool}; use std::str::FromStr; use tracing::info; @@ -9,17 +9,16 @@ pub type DbPool = SqlitePool; pub async fn init(config: &DatabaseConfig) -> Result { // Ensure the database directory exists - if let Some(parent) = config.path.parent() { - if !parent.as_os_str().is_empty() { - tokio::fs::create_dir_all(parent).await?; - } + if let Some(parent) = config.path.parent() + && !parent.as_os_str().is_empty() + { + tokio::fs::create_dir_all(parent).await?; } let database_path = config.path.to_string_lossy().to_string(); info!("Connecting to database at {}", database_path); - let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))? - .create_if_missing(true); + let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true); let pool = SqlitePool::connect_with(options).await?; @@ -91,7 +90,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { low_credit_threshold REAL DEFAULT 5.0, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ) - "# + "#, ) .execute(pool) .await?; @@ -110,7 +109,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE ) - "# + "#, ) .execute(pool) .await?; @@ -123,66 +122,59 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { username TEXT UNIQUE NOT NULL, password_hash TEXT NOT NULL, role TEXT DEFAULT 'admin', + must_change_password BOOLEAN DEFAULT FALSE, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ) - "# + "#, ) .execute(pool) .await?; + // Add must_change_password column if it doesn't exist (migration for existing DBs) + let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE") + .execute(pool) + .await; + // Insert default admin user if none exists (default password: admin) - let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users") - .fetch_one(pool) - .await?; + let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?; if user_count.0 == 0 { // 'admin' hashed with default cost (12) - let default_admin_hash = bcrypt::hash("admin", 12).unwrap(); + let default_admin_hash = + bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?; sqlx::query( - "INSERT INTO users (username, password_hash, role) VALUES ('admin', ?, 'admin')" + "INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)" ) .bind(default_admin_hash) .execute(pool) .await?; - info!("Created default admin user with password 'admin'"); + info!("Created default admin user with password 'admin' (must change on first login)"); } // Create indices - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)") + .execute(pool) + .await?; - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)") + .execute(pool) + .await?; - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)") + .execute(pool) + .await?; - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)") + .execute(pool) + .await?; - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)") + .execute(pool) + .await?; - sqlx::query( - "CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)" - ) - .execute(pool) - .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)") + .execute(pool) + .await?; // Insert default client if none exists sqlx::query( @@ -200,4 +192,4 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { pub async fn test_connection(pool: &DbPool) -> Result<()> { sqlx::query("SELECT 1").execute(pool).await?; Ok(()) -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index e52bea9a..6a7645f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,8 @@ pub mod auth; pub mod client; pub mod config; -pub mod database; pub mod dashboard; +pub mod database; pub mod errors; pub mod logging; pub mod models; @@ -19,58 +19,62 @@ pub mod state; pub mod utils; // Re-exports for convenience -pub use auth::*; -pub use config::*; -pub use database::*; -pub use errors::*; -pub use logging::*; -pub use models::*; -pub use providers::*; -pub use server::*; -pub use state::*; +pub use auth::{AuthenticatedClient, validate_token}; +pub use config::{ + AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing, + OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig, +}; +pub use database::{DbPool, init as init_db, test_connection}; +pub use errors::AppError; +pub use logging::{LoggingContext, RequestLog, RequestLogger}; +pub use models::{ + ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, + ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent, + OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage, +}; +pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk}; +pub use server::router; +pub use state::AppState; /// Test utilities for integration testing #[cfg(test)] pub mod test_utils { use std::sync::Arc; - - use crate::{ - state::AppState, - rate_limiting::RateLimitManager, - client::ClientManager, - providers::ProviderManager, - }; + + use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState}; use sqlx::sqlite::SqlitePool; - + /// Create a test application state pub async fn create_test_state() -> Arc { // Create in-memory database let pool = SqlitePool::connect("sqlite::memory:") .await .expect("Failed to create test database"); - + // Run migrations crate::database::init(&crate::config::DatabaseConfig { path: std::path::PathBuf::from(":memory:"), max_connections: 5, - }).await.expect("Failed to initialize test database"); - + }) + .await + .expect("Failed to initialize test database"); + let rate_limit_manager = RateLimitManager::new( crate::rate_limiting::RateLimiterConfig::default(), crate::rate_limiting::CircuitBreakerConfig::default(), ); - + let client_manager = Arc::new(ClientManager::new(pool.clone())); - + // Create provider manager let provider_manager = ProviderManager::new(); - + let model_registry = crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new(), }; - + let (dashboard_tx, _) = tokio::sync::broadcast::channel(100); - + let config = Arc::new(crate::config::AppConfig { server: crate::config::ServerConfig { port: 8080, @@ -82,11 +86,35 @@ pub mod test_utils { max_connections: 5, }, providers: crate::config::ProviderConfig { - openai: crate::config::OpenAIConfig { api_key_env: "OPENAI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, - gemini: crate::config::GeminiConfig { api_key_env: "GEMINI_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, - deepseek: crate::config::DeepSeekConfig { api_key_env: "DEEPSEEK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, - grok: crate::config::GrokConfig { api_key_env: "GROK_API_KEY".to_string(), base_url: "".to_string(), default_model: "".to_string(), enabled: true }, - ollama: crate::config::OllamaConfig { base_url: "".to_string(), enabled: true, models: vec![] }, + openai: crate::config::OpenAIConfig { + api_key_env: "OPENAI_API_KEY".to_string(), + base_url: "".to_string(), + default_model: "".to_string(), + enabled: true, + }, + gemini: crate::config::GeminiConfig { + api_key_env: "GEMINI_API_KEY".to_string(), + base_url: "".to_string(), + default_model: "".to_string(), + enabled: true, + }, + deepseek: crate::config::DeepSeekConfig { + api_key_env: "DEEPSEEK_API_KEY".to_string(), + base_url: "".to_string(), + default_model: "".to_string(), + enabled: true, + }, + grok: crate::config::GrokConfig { + api_key_env: "GROK_API_KEY".to_string(), + base_url: "".to_string(), + default_model: "".to_string(), + enabled: true, + }, + ollama: crate::config::OllamaConfig { + base_url: "".to_string(), + enabled: true, + models: vec![], + }, }, model_mapping: crate::config::ModelMappingConfig { patterns: vec![] }, pricing: crate::config::PricingConfig { @@ -111,7 +139,7 @@ pub mod test_utils { auth_tokens: vec![], }) } - + /// Create a test HTTP client pub fn create_test_client() -> reqwest::Client { reqwest::Client::builder() diff --git a/src/logging/mod.rs b/src/logging/mod.rs index aa46a9fe..7eeb0895 100644 --- a/src/logging/mod.rs +++ b/src/logging/mod.rs @@ -1,8 +1,8 @@ use chrono::{DateTime, Utc}; +use serde::Serialize; use sqlx::SqlitePool; use tokio::sync::broadcast; use tracing::warn; -use serde::Serialize; use crate::errors::AppError; @@ -38,7 +38,7 @@ impl RequestLogger { pub fn log_request(&self, log: RequestLog) { let pool = self.db_pool.clone(); let tx = self.dashboard_tx.clone(); - + // Spawn async task to log without blocking response tokio::spawn(async move { // Broadcast to dashboard @@ -77,20 +77,18 @@ impl RequestLogger { .bind(log.status) .bind(log.error_message) .bind(log.duration_ms as i64) - .bind(None::) // request_body - TODO: store serialized request - .bind(None::) // response_body - TODO: store serialized response or error + .bind(None::) // request_body - optional, not stored to save disk space + .bind(None::) // response_body - optional, not stored to save disk space .execute(&mut *tx) .await?; // Deduct from provider balance if successful if log.cost > 0.0 { - sqlx::query( - "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?" - ) - .bind(log.cost) - .bind(&log.provider) - .execute(&mut *tx) - .await?; + sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?") + .bind(log.cost) + .bind(&log.provider) + .execute(&mut *tx) + .await?; } tx.commit().await?; @@ -108,32 +106,32 @@ impl RequestLogger { // next: Next, // ) -> Response { // let start_time = std::time::Instant::now(); -// +// // // Extract client_id from auth or use "unknown" // let client_id = match auth_result { // Ok(auth) => auth.client_id, // Err(_) => "unknown".to_string(), // }; -// +// // // Try to extract request details // let (request_parts, request_body) = request.into_parts(); -// +// // // Clone request parts for logging // let path = request_parts.uri.path().to_string(); -// +// // // Check if this is a chat completion request // let is_chat_completion = path == "/v1/chat/completions"; -// +// // // Reconstruct request for downstream handlers // let request = Request::from_parts(request_parts, request_body); -// +// // // Process request and get response // let response = next.run(request).await; -// +// // // Calculate duration // let duration = start_time.elapsed(); // let duration_ms = duration.as_millis() as u64; -// +// // // Log basic request info // info!( // "Request from {} to {} - Status: {} - Duration: {}ms", @@ -142,10 +140,10 @@ impl RequestLogger { // response.status().as_u16(), // duration_ms // ); -// +// // // TODO: Extract more details from request/response for logging // // For now, we'll need to modify the server handler to pass additional context -// +// // response // } @@ -177,26 +175,26 @@ impl LoggingContext { error: None, } } - + pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self { self.prompt_tokens = prompt_tokens; self.completion_tokens = completion_tokens; self.total_tokens = prompt_tokens + completion_tokens; self } - + pub fn with_cost(mut self, cost: f64) -> Self { self.cost = cost; self } - + pub fn with_images(mut self, has_images: bool) -> Self { self.has_images = has_images; self } - + pub fn with_error(mut self, error: AppError) -> Self { self.error = Some(error); self } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 68d1d388..5aefdc18 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,15 @@ use anyhow::Result; use axum::{Router, routing::get}; use std::net::SocketAddr; -use tracing::{info, error}; +use tracing::{error, info}; use llm_proxy::{ - config::AppConfig, - state::AppState, + config::AppConfig, + dashboard, database, providers::ProviderManager, - database, + rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig}, server, - dashboard, - rate_limiting::{RateLimitManager, RateLimiterConfig, CircuitBreakerConfig}, + state::AppState, }; #[tokio::main] @@ -33,7 +32,7 @@ async fn main() -> Result<()> { // Initialize provider manager with configured providers let provider_manager = ProviderManager::new(); - + // Initialize all supported providers (they handle their own enabled check) let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; for name in supported_providers { @@ -43,22 +42,28 @@ async fn main() -> Result<()> { } // Create rate limit manager - let rate_limit_manager = RateLimitManager::new( - RateLimiterConfig::default(), - CircuitBreakerConfig::default(), - ); - + let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default()); + // Fetch model registry from models.dev let model_registry = match llm_proxy::utils::registry::fetch_registry().await { Ok(registry) => registry, Err(e) => { error!("Failed to fetch model registry: {}. Using empty registry.", e); - llm_proxy::models::registry::ModelRegistry { providers: std::collections::HashMap::new() } + llm_proxy::models::registry::ModelRegistry { + providers: std::collections::HashMap::new(), + } } }; - + // Create application state - let state = AppState::new(config.clone(), provider_manager, db_pool, rate_limit_manager, model_registry, config.server.auth_tokens.clone()); + let state = AppState::new( + config.clone(), + provider_manager, + db_pool, + rate_limit_manager, + model_registry, + config.server.auth_tokens.clone(), + ); // Create application router let app = Router::new() diff --git a/src/models/mod.rs b/src/models/mod.rs index 01d97fcf..29d56c2f 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -198,9 +198,7 @@ impl TryFrom for UnifiedRequest { .into_iter() .map(|msg| { let (content, _images_in_message) = match msg.content { - MessageContent::Text { content } => { - (vec![ContentPart::Text { text: content }], false) - } + MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false), MessageContent::Parts { content } => { let mut unified_content = Vec::new(); let mut has_images_in_msg = false; @@ -213,18 +211,16 @@ impl TryFrom for UnifiedRequest { ContentPartValue::ImageUrl { image_url } => { has_images_in_msg = true; has_images = true; - unified_content.push(ContentPart::Image( - crate::multimodal::ImageInput::from_url(image_url.url) - )); + unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url( + image_url.url, + ))); } } } (unified_content, has_images_in_msg) } - MessageContent::None => { - (vec![], false) - } + MessageContent::None => (vec![], false), }; UnifiedMessage { diff --git a/src/models/registry.rs b/src/models/registry.rs index 3ccd17c2..f793e8e6 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -54,7 +54,7 @@ impl ModelRegistry { return Some(model); } } - + // Try searching for the model ID inside the metadata if the key was different for provider in self.providers.values() { for model in provider.models.values() { @@ -63,7 +63,7 @@ impl ModelRegistry { } } } - + None } } diff --git a/src/multimodal/mod.rs b/src/multimodal/mod.rs index 0af00e1d..cefedf57 100644 --- a/src/multimodal/mod.rs +++ b/src/multimodal/mod.rs @@ -1,5 +1,5 @@ //! Multimodal support for image processing and conversion -//! +//! //! This module handles: //! 1. Image format detection and conversion //! 2. Base64 encoding/decoding @@ -7,24 +7,18 @@ //! 4. Provider-specific image format conversion use anyhow::{Context, Result}; -use base64::{engine::general_purpose, Engine as _}; +use base64::{Engine as _, engine::general_purpose}; use tracing::{info, warn}; /// Supported image formats for multimodal input #[derive(Debug, Clone)] pub enum ImageInput { /// Base64-encoded image data with MIME type - Base64 { - data: String, - mime_type: String, - }, + Base64 { data: String, mime_type: String }, /// URL to fetch image from Url(String), /// Raw bytes with MIME type - Bytes { - data: Vec, - mime_type: String, - }, + Bytes { data: Vec, mime_type: String }, } impl ImageInput { @@ -32,17 +26,17 @@ impl ImageInput { pub fn from_base64(data: String, mime_type: String) -> Self { Self::Base64 { data, mime_type } } - + /// Create ImageInput from URL pub fn from_url(url: String) -> Self { Self::Url(url) } - + /// Create ImageInput from raw bytes pub fn from_bytes(data: Vec, mime_type: String) -> Self { Self::Bytes { data, mime_type } } - + /// Get MIME type if available pub fn mime_type(&self) -> Option<&str> { match self { @@ -51,7 +45,7 @@ impl ImageInput { Self::Url(_) => None, } } - + /// Convert to base64 if not already pub async fn to_base64(&self) -> Result<(String, String)> { match self { @@ -63,56 +57,56 @@ impl ImageInput { Self::Url(url) => { // Fetch image from URL info!("Fetching image from URL: {}", url); - let response = reqwest::get(url) - .await - .context("Failed to fetch image from URL")?; - + let response = reqwest::get(url).await.context("Failed to fetch image from URL")?; + if !response.status().is_success() { anyhow::bail!("Failed to fetch image: HTTP {}", response.status()); } - + let mime_type = response .headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|h| h.to_str().ok()) .unwrap_or("image/jpeg") .to_string(); - + let bytes = response.bytes().await.context("Failed to read image bytes")?; - + let base64_data = general_purpose::STANDARD.encode(&bytes); Ok((base64_data, mime_type)) } } } - + /// Get image dimensions (width, height) pub async fn get_dimensions(&self) -> Result<(u32, u32)> { let bytes = match self { - Self::Base64 { data, .. } => { - general_purpose::STANDARD.decode(data).context("Failed to decode base64")? - } + Self::Base64 { data, .. } => general_purpose::STANDARD + .decode(data) + .context("Failed to decode base64")?, Self::Bytes { data, .. } => data.clone(), Self::Url(_) => { let (base64_data, _) = self.to_base64().await?; - general_purpose::STANDARD.decode(&base64_data).context("Failed to decode base64")? + general_purpose::STANDARD + .decode(&base64_data) + .context("Failed to decode base64")? } }; - + let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?; Ok((img.width(), img.height())) } - + /// Validate image size and format pub async fn validate(&self, max_size_mb: f64) -> Result<()> { let (width, height) = self.get_dimensions().await?; - + // Check dimensions if width > 4096 || height > 4096 { warn!("Image dimensions too large: {}x{}", width, height); // Continue anyway, but log warning } - + // Check file size let size_bytes = match self { Self::Base64 { data, .. } => { @@ -126,12 +120,12 @@ impl ImageInput { return Ok(()); } }; - + let size_mb = size_bytes as f64 / (1024.0 * 1024.0); if size_mb > max_size_mb { anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb); } - + Ok(()) } } @@ -143,10 +137,10 @@ impl ImageConverter { /// Convert image to OpenAI-compatible format pub async fn to_openai_format(image: &ImageInput) -> Result { let (base64_data, mime_type) = image.to_base64().await?; - + // OpenAI expects data URL format: "data:image/jpeg;base64,{data}" let data_url = format!("data:{};base64,{}", mime_type, base64_data); - + Ok(serde_json::json!({ "type": "image_url", "image_url": { @@ -155,11 +149,11 @@ impl ImageConverter { } })) } - + /// Convert image to Gemini-compatible format pub async fn to_gemini_format(image: &ImageInput) -> Result { let (base64_data, mime_type) = image.to_base64().await?; - + // Gemini expects inline data format Ok(serde_json::json!({ "inline_data": { @@ -168,32 +162,34 @@ impl ImageConverter { } })) } - + /// Convert image to DeepSeek-compatible format pub async fn to_deepseek_format(image: &ImageInput) -> Result { // DeepSeek uses OpenAI-compatible format for vision models Self::to_openai_format(image).await } - + /// Detect if a model supports multimodal input pub fn model_supports_multimodal(model: &str) -> bool { // OpenAI vision models - if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) || - model.starts_with("o1-") || model.starts_with("o3-") { + if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) + || model.starts_with("o1-") + || model.starts_with("o3-") + { return true; } - + // Gemini vision models if model.starts_with("gemini") { // Most Gemini models support vision return true; } - + // DeepSeek vision models if model.starts_with("deepseek-vl") { return true; } - + false } } @@ -201,47 +197,47 @@ impl ImageConverter { /// Parse OpenAI-compatible multimodal message content pub fn parse_openai_content(content: &serde_json::Value) -> Result)>> { let mut parts = Vec::new(); - + if let Some(content_str) = content.as_str() { // Simple text content parts.push((content_str.to_string(), None)); } else if let Some(content_array) = content.as_array() { // Array of content parts (text and/or images) for part in content_array { - if let Some(part_obj) = part.as_object() { - if let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) { - match part_type { - "text" => { - if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) { - parts.push((text.to_string(), None)); - } + if let Some(part_obj) = part.as_object() + && let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) + { + match part_type { + "text" => { + if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) { + parts.push((text.to_string(), None)); } - "image_url" => { - if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) { - if let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) { - if url.starts_with("data:") { - // Parse data URL - if let Some((mime_type, data)) = parse_data_url(url) { - let image_input = ImageInput::from_base64(data, mime_type); - parts.push(("".to_string(), Some(image_input))); - } - } else { - // Regular URL - let image_input = ImageInput::from_url(url.to_string()); - parts.push(("".to_string(), Some(image_input))); - } + } + "image_url" => { + if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) + && let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) + { + if url.starts_with("data:") { + // Parse data URL + if let Some((mime_type, data)) = parse_data_url(url) { + let image_input = ImageInput::from_base64(data, mime_type); + parts.push(("".to_string(), Some(image_input))); } + } else { + // Regular URL + let image_input = ImageInput::from_url(url.to_string()); + parts.push(("".to_string(), Some(image_input))); } } - _ => { - warn!("Unknown content part type: {}", part_type); - } + } + _ => { + warn!("Unknown content part type: {}", part_type); } } } } } - + Ok(parts) } @@ -250,36 +246,38 @@ fn parse_data_url(data_url: &str) -> Option<(String, String)> { if !data_url.starts_with("data:") { return None; } - + let parts: Vec<&str> = data_url[5..].split(";base64,").collect(); if parts.len() != 2 { return None; } - + let mime_type = parts[0].to_string(); let data = parts[1].to_string(); - + Some((mime_type, data)) } #[cfg(test)] mod tests { use super::*; - + #[tokio::test] async fn test_parse_data_url() { let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64 let (mime_type, data) = parse_data_url(test_url).unwrap(); - + assert_eq!(mime_type, "image/jpeg"); assert_eq!(data, "SGVsbG8gV29ybGQ="); } - + #[tokio::test] async fn test_model_supports_multimodal() { assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview")); + assert!(ImageConverter::model_supports_multimodal("gpt-4o")); assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision")); + assert!(ImageConverter::model_supports_multimodal("gemini-pro")); assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo")); - assert!(!ImageConverter::model_supports_multimodal("gemini-pro")); + assert!(!ImageConverter::model_supports_multimodal("claude-3-opus")); } -} \ No newline at end of file +} diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs index cc1e69e3..21ca26a6 100644 --- a/src/providers/deepseek.rs +++ b/src/providers/deepseek.rs @@ -1,14 +1,10 @@ -use async_trait::async_trait; use anyhow::Result; -use futures::stream::{BoxStream, StreamExt}; -use serde_json::Value; +use async_trait::async_trait; +use futures::stream::BoxStream; -use crate::{ - models::UnifiedRequest, - errors::AppError, - config::AppConfig, -}; +use super::helpers; use super::{ProviderResponse, ProviderStreamChunk}; +use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; pub struct DeepSeekProvider { client: reqwest::Client, @@ -23,7 +19,11 @@ impl DeepSeekProvider { Self::new_with_key(config, app_config, api_key) } - pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result { + pub fn new_with_key( + config: &crate::config::DeepSeekConfig, + app_config: &AppConfig, + api_key: String, + ) -> Result { Ok(Self { client: reqwest::Client::new(), config: config.clone(), @@ -47,42 +47,13 @@ impl super::Provider for DeepSeekProvider { false } - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result { - // Build the OpenAI-compatible body - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - // DeepSeek currently doesn't support images in the same way, but we'll try to be standard - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": false, - }); + async fn chat_completion(&self, request: UnifiedRequest) -> Result { + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, false); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } - - let response = self.client.post(format!("{}/chat/completions", self.config.base_url)) + let response = self + .client + .post(format!("{}/chat/completions", self.config.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&body) .send() @@ -94,119 +65,52 @@ impl super::Provider for DeepSeekProvider { return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text))); } - let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; - - let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; - let message = &choice["message"]; - - let content = message["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); - - let usage = &resp_json["usage"]; - let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; - let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; - let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; + let resp_json: serde_json::Value = response + .json() + .await + .map_err(|e| AppError::ProviderError(e.to_string()))?; - Ok(ProviderResponse { - content, - reasoning_content, - prompt_tokens, - completion_tokens, - total_tokens, - model: request.model, - }) + helpers::parse_openai_response(&resp_json, request.model) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if let Some(cost) = &metadata.cost { - return (prompt_tokens as f64 * cost.input / 1_000_000.0) + - (completion_tokens as f64 * cost.output / 1_000_000.0); - } - } - - let (prompt_rate, completion_rate) = self.pricing.iter() - .find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((0.14, 0.28)); - - (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64 { + helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + registry, + &self.pricing, + 0.14, + 0.28, + ) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }), - } - }).collect::>() - }) - }).collect::>(), - "stream": true, - }); + // DeepSeek doesn't support images in streaming, use text-only + let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, true); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } + let es = reqwest_eventsource::EventSource::new( + self.client + .post(format!("{}/chat/completions", self.config.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&body), + ) + .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - // Create eventsource stream - use reqwest_eventsource::{EventSource, Event}; - let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body)) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let model = request.model.clone(); - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(choice) = chunk["choices"].get(0) { - let delta = &choice["delta"]; - let content = delta["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string()); - let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); - - yield ProviderStreamChunk { - content, - reasoning_content, - finish_reason, - model: model.clone(), - }; - } - } - Ok(_) => continue, - Err(e) => { - Err(AppError::ProviderError(format!("Stream error: {}", e)))?; - } - } - } - }; - - Ok(Box::pin(stream)) + Ok(helpers::create_openai_stream(es, request.model, None)) } } diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 3cd51529..3fc43084 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -1,14 +1,10 @@ -use async_trait::async_trait; use anyhow::Result; -use serde::{Deserialize, Serialize}; +use async_trait::async_trait; use futures::stream::BoxStream; +use serde::{Deserialize, Serialize}; -use crate::{ - models::UnifiedRequest, - errors::AppError, - config::AppConfig, -}; use super::{ProviderResponse, ProviderStreamChunk}; +use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; #[derive(Debug, Serialize)] struct GeminiRequest { @@ -61,8 +57,6 @@ struct GeminiResponse { usage_metadata: Option, } - - pub struct GeminiProvider { client: reqwest::Client, config: crate::config::GeminiConfig, @@ -80,7 +74,7 @@ impl GeminiProvider { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build()?; - + Ok(Self { client, config: config.clone(), @@ -101,19 +95,16 @@ impl super::Provider for GeminiProvider { } fn supports_multimodal(&self) -> bool { - true // Gemini supports vision + true // Gemini supports vision } - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result { + async fn chat_completion(&self, request: UnifiedRequest) -> Result { // Convert UnifiedRequest to Gemini request let mut contents = Vec::with_capacity(request.messages.len()); - + for msg in request.messages { let mut parts = Vec::with_capacity(msg.content.len()); - + for part in msg.content { match part { crate::models::ContentPart::Text { text } => { @@ -123,9 +114,11 @@ impl super::Provider for GeminiProvider { }); } crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = image_input.to_base64().await + let (base64_data, mime_type) = image_input + .to_base64() + .await .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; - + parts.push(GeminiPart { text: None, inline_data: Some(GeminiInlineData { @@ -136,23 +129,20 @@ impl super::Provider for GeminiProvider { } } } - + // Map role: "user" -> "user", "assistant" -> "model", "system" -> "user" let role = match msg.role.as_str() { "assistant" => "model".to_string(), _ => "user".to_string(), }; - - contents.push(GeminiContent { - parts, - role, - }); + + contents.push(GeminiContent { parts, role }); } - + if contents.is_empty() { return Err(AppError::ProviderError("No valid text messages to send".to_string())); } - + // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { @@ -162,51 +152,65 @@ impl super::Provider for GeminiProvider { } else { None }; - + let gemini_request = GeminiRequest { contents, generation_config, }; - + // Build URL - let url = format!("{}/models/{}:generateContent?key={}", - self.config.base_url, - request.model, - self.api_key - ); - + let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,); + // Send request - let response = self.client + let response = self + .client .post(&url) + .header("x-goog-api-key", &self.api_key) .json(&gemini_request) .send() .await .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?; - + // Check status let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); - return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text))); + return Err(AppError::ProviderError(format!( + "Gemini API error ({}): {}", + status, error_text + ))); } - + let gemini_response: GeminiResponse = response .json() .await .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; - + // Extract content from first candidate - let content = gemini_response.candidates + let content = gemini_response + .candidates .first() .and_then(|c| c.content.parts.first()) .and_then(|p| p.text.clone()) .unwrap_or_default(); - + // Extract token usage - let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0); - let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0); - let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0); - + let prompt_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.prompt_token_count) + .unwrap_or(0); + let completion_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.candidates_token_count) + .unwrap_or(0); + let total_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.total_token_count) + .unwrap_or(0); + Ok(ProviderResponse { content, reasoning_content: None, // Gemini doesn't use this field name @@ -221,20 +225,22 @@ impl super::Provider for GeminiProvider { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if let Some(cost) = &metadata.cost { - return (prompt_tokens as f64 * cost.input / 1_000_000.0) + - (completion_tokens as f64 * cost.output / 1_000_000.0); - } - } - - let (prompt_rate, completion_rate) = self.pricing.iter() - .find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found - - (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64 { + super::helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + registry, + &self.pricing, + 0.075, + 0.30, + ) } async fn chat_completion_stream( @@ -243,10 +249,10 @@ impl super::Provider for GeminiProvider { ) -> Result>, AppError> { // Convert UnifiedRequest to Gemini request let mut contents = Vec::with_capacity(request.messages.len()); - + for msg in request.messages { let mut parts = Vec::with_capacity(msg.content.len()); - + for part in msg.content { match part { crate::models::ContentPart::Text { text } => { @@ -256,9 +262,11 @@ impl super::Provider for GeminiProvider { }); } crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = image_input.to_base64().await + let (base64_data, mime_type) = image_input + .to_base64() + .await .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; - + parts.push(GeminiPart { text: None, inline_data: Some(GeminiInlineData { @@ -269,19 +277,16 @@ impl super::Provider for GeminiProvider { } } } - + // Map role let role = match msg.role.as_str() { "assistant" => "model".to_string(), _ => "user".to_string(), }; - - contents.push(GeminiContent { - parts, - role, - }); + + contents.push(GeminiContent { parts, role }); } - + // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { @@ -291,28 +296,32 @@ impl super::Provider for GeminiProvider { } else { None }; - + let gemini_request = GeminiRequest { contents, generation_config, }; - - // Build URL for streaming - let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}", - self.config.base_url, - request.model, - self.api_key - ); - - // Create eventsource stream - use reqwest_eventsource::{EventSource, Event}; - use futures::StreamExt; - let es = EventSource::new(self.client.post(&url).json(&gemini_request)) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - + // Build URL for streaming + let url = format!( + "{}/models/{}:streamGenerateContent?alt=sse", + self.config.base_url, request.model, + ); + + // Create eventsource stream + use futures::StreamExt; + use reqwest_eventsource::{Event, EventSource}; + + let es = EventSource::new( + self.client + .post(&url) + .header("x-goog-api-key", &self.api_key) + .json(&gemini_request), + ) + .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; + let model = request.model.clone(); - + let stream = async_stream::try_stream! { let mut es = es; while let Some(event) = es.next().await { @@ -320,12 +329,12 @@ impl super::Provider for GeminiProvider { Ok(Event::Message(msg)) => { let gemini_response: GeminiResponse = serde_json::from_str(&msg.data) .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - + if let Some(candidate) = gemini_response.candidates.first() { let content = candidate.content.parts.first() .and_then(|p| p.text.clone()) .unwrap_or_default(); - + yield ProviderStreamChunk { content, reasoning_content: None, @@ -341,7 +350,7 @@ impl super::Provider for GeminiProvider { } } }; - + Ok(Box::pin(stream)) } -} \ No newline at end of file +} diff --git a/src/providers/grok.rs b/src/providers/grok.rs index 023b7872..2c81e77b 100644 --- a/src/providers/grok.rs +++ b/src/providers/grok.rs @@ -1,18 +1,14 @@ -use async_trait::async_trait; use anyhow::Result; -use futures::stream::{BoxStream, StreamExt}; -use serde_json::Value; +use async_trait::async_trait; +use futures::stream::BoxStream; -use crate::{ - models::UnifiedRequest, - errors::AppError, - config::AppConfig, -}; +use super::helpers; use super::{ProviderResponse, ProviderStreamChunk}; +use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; pub struct GrokProvider { client: reqwest::Client, - _config: crate::config::GrokConfig, + config: crate::config::GrokConfig, api_key: String, pricing: Vec, } @@ -26,7 +22,7 @@ impl GrokProvider { pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result { Ok(Self { client: reqwest::Client::new(), - _config: config.clone(), + config: config.clone(), api_key, pricing: app_config.pricing.grok.clone(), }) @@ -47,40 +43,13 @@ impl super::Provider for GrokProvider { true } - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result { - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": false, - }); + async fn chat_completion(&self, request: UnifiedRequest) -> Result { + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, false); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } - - let response = self.client.post(format!("{}/chat/completions", self._config.base_url)) + let response = self + .client + .post(format!("{}/chat/completions", self.config.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&body) .send() @@ -92,125 +61,51 @@ impl super::Provider for GrokProvider { return Err(AppError::ProviderError(format!("Grok API error: {}", error_text))); } - let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; - - let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; - let message = &choice["message"]; - - let content = message["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); - - let usage = &resp_json["usage"]; - let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; - let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; - let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; + let resp_json: serde_json::Value = response + .json() + .await + .map_err(|e| AppError::ProviderError(e.to_string()))?; - Ok(ProviderResponse { - content, - reasoning_content, - prompt_tokens, - completion_tokens, - total_tokens, - model: request.model, - }) + helpers::parse_openai_response(&resp_json, request.model) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if let Some(cost) = &metadata.cost { - return (prompt_tokens as f64 * cost.input / 1_000_000.0) + - (completion_tokens as f64 * cost.output / 1_000_000.0); - } - } - - let (prompt_rate, completion_rate) = self.pricing.iter() - .find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((5.0, 15.0)); - - (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64 { + helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + registry, + &self.pricing, + 5.0, + 15.0, + ) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": true, - }); + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, true); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } + let es = reqwest_eventsource::EventSource::new( + self.client + .post(format!("{}/chat/completions", self.config.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&body), + ) + .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - // Create eventsource stream - use reqwest_eventsource::{EventSource, Event}; - let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body)) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let model = request.model.clone(); - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(choice) = chunk["choices"].get(0) { - let delta = &choice["delta"]; - let content = delta["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string()); - let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); - - yield ProviderStreamChunk { - content, - reasoning_content, - finish_reason, - model: model.clone(), - }; - } - } - Ok(_) => continue, - Err(e) => { - Err(AppError::ProviderError(format!("Stream error: {}", e)))?; - } - } - } - }; - - Ok(Box::pin(stream)) + Ok(helpers::create_openai_stream(es, request.model, None)) } } diff --git a/src/providers/helpers.rs b/src/providers/helpers.rs new file mode 100644 index 00000000..c0542f53 --- /dev/null +++ b/src/providers/helpers.rs @@ -0,0 +1,189 @@ +use super::{ProviderResponse, ProviderStreamChunk}; +use crate::errors::AppError; +use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest}; +use futures::stream::{BoxStream, StreamExt}; +use serde_json::Value; + +/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously. +/// +/// This avoids the deadlock caused by `futures::executor::block_on` inside a +/// Tokio async context. All image base64 conversions are awaited properly. +pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result, AppError> { + let mut result = Vec::new(); + for m in messages { + let mut parts = Vec::new(); + for p in &m.content { + match p { + ContentPart::Text { text } => { + parts.push(serde_json::json!({ "type": "text", "text": text })); + } + ContentPart::Image(image_input) => { + let (base64_data, mime_type) = image_input + .to_base64() + .await + .map_err(|e| AppError::MultimodalError(e.to_string()))?; + parts.push(serde_json::json!({ + "type": "image_url", + "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } + })); + } + } + } + result.push(serde_json::json!({ + "role": m.role, + "content": parts + })); + } + Ok(result) +} + +/// Convert messages to OpenAI-compatible JSON, but replace images with a +/// text placeholder "[Image]". Useful for providers that don't support +/// multimodal in streaming mode or at all. +pub async fn messages_to_openai_json_text_only( + messages: &[UnifiedMessage], +) -> Result, AppError> { + let mut result = Vec::new(); + for m in messages { + let mut parts = Vec::new(); + for p in &m.content { + match p { + ContentPart::Text { text } => { + parts.push(serde_json::json!({ "type": "text", "text": text })); + } + ContentPart::Image(_) => { + parts.push(serde_json::json!({ "type": "text", "text": "[Image]" })); + } + } + } + result.push(serde_json::json!({ + "role": m.role, + "content": parts + })); + } + Ok(result) +} + +/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages. +pub fn build_openai_body( + request: &UnifiedRequest, + messages_json: Vec, + stream: bool, +) -> serde_json::Value { + let mut body = serde_json::json!({ + "model": request.model, + "messages": messages_json, + "stream": stream, + }); + + if let Some(temp) = request.temperature { + body["temperature"] = serde_json::json!(temp); + } + if let Some(max_tokens) = request.max_tokens { + body["max_tokens"] = serde_json::json!(max_tokens); + } + + body +} + +/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse. +pub fn parse_openai_response(resp_json: &Value, model: String) -> Result { + let choice = resp_json["choices"] + .get(0) + .ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; + let message = &choice["message"]; + + let content = message["content"].as_str().unwrap_or_default().to_string(); + let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); + + let usage = &resp_json["usage"]; + let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; + let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; + let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; + + Ok(ProviderResponse { + content, + reasoning_content, + prompt_tokens, + completion_tokens, + total_tokens, + model, + }) +} + +/// Create an SSE stream that parses OpenAI-compatible streaming chunks. +/// +/// The optional `reasoning_field` allows overriding the field name for +/// reasoning content (e.g., "thought" for Ollama). +pub fn create_openai_stream( + es: reqwest_eventsource::EventSource, + model: String, + reasoning_field: Option<&'static str>, +) -> BoxStream<'static, Result> { + use reqwest_eventsource::Event; + + let stream = async_stream::try_stream! { + let mut es = es; + while let Some(event) = es.next().await { + match event { + Ok(Event::Message(msg)) => { + if msg.data == "[DONE]" { + break; + } + + let chunk: Value = serde_json::from_str(&msg.data) + .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; + + if let Some(choice) = chunk["choices"].get(0) { + let delta = &choice["delta"]; + let content = delta["content"].as_str().unwrap_or_default().to_string(); + let reasoning_content = delta["reasoning_content"] + .as_str() + .or_else(|| reasoning_field.and_then(|f| delta[f].as_str())) + .map(|s| s.to_string()); + let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); + + yield ProviderStreamChunk { + content, + reasoning_content, + finish_reason, + model: model.clone(), + }; + } + } + Ok(_) => continue, + Err(e) => { + Err(AppError::ProviderError(format!("Stream error: {}", e)))?; + } + } + } + }; + + Box::pin(stream) +} + +/// Calculate cost using the model registry first, then falling back to provider pricing config. +pub fn calculate_cost_with_registry( + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + pricing: &[crate::config::ModelPricing], + default_prompt_rate: f64, + default_completion_rate: f64, +) -> f64 { + if let Some(metadata) = registry.find_model(model) + && let Some(cost) = &metadata.cost + { + return (prompt_tokens as f64 * cost.input / 1_000_000.0) + + (completion_tokens as f64 * cost.output / 1_000_000.0); + } + + let (prompt_rate, completion_rate) = pricing + .iter() + .find(|p| model.contains(&p.model)) + .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) + .unwrap_or((default_prompt_rate, default_completion_rate)); + + (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 39a10bbc..d64b3d79 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,17 +1,18 @@ -use async_trait::async_trait; use anyhow::Result; -use std::sync::Arc; +use async_trait::async_trait; use futures::stream::BoxStream; use sqlx::Row; +use std::sync::Arc; -use crate::models::UnifiedRequest; use crate::errors::AppError; +use crate::models::UnifiedRequest; -pub mod openai; -pub mod gemini; pub mod deepseek; +pub mod gemini; pub mod grok; +pub mod helpers; pub mod ollama; +pub mod openai; #[async_trait] pub trait Provider: Send + Sync { @@ -25,10 +26,7 @@ pub trait Provider: Send + Sync { fn supports_multimodal(&self) -> bool; /// Process a chat completion request - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result; + async fn chat_completion(&self, request: UnifiedRequest) -> Result; /// Process a streaming chat completion request async fn chat_completion_stream( @@ -40,7 +38,13 @@ pub trait Provider: Send + Sync { fn estimate_tokens(&self, request: &UnifiedRequest) -> Result; /// Calculate cost based on token usage and model using the registry - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64; + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64; } pub struct ProviderResponse { @@ -64,11 +68,8 @@ use tokio::sync::RwLock; use crate::config::AppConfig; use crate::providers::{ + deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider, openai::OpenAIProvider, - gemini::GeminiProvider, - deepseek::DeepSeekProvider, - grok::GrokProvider, - ollama::OllamaProvider, }; #[derive(Clone)] @@ -76,6 +77,12 @@ pub struct ProviderManager { providers: Arc>>>, } +impl Default for ProviderManager { + fn default() -> Self { + Self::new() + } +} + impl ProviderManager { pub fn new() -> Self { Self { @@ -84,7 +91,12 @@ impl ProviderManager { } /// Initialize a provider by name using config and database overrides - pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> { + pub async fn initialize_provider( + &self, + name: &str, + app_config: &AppConfig, + db_pool: &crate::database::DbPool, + ) -> Result<()> { // Load override from database let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?") .bind(name) @@ -100,11 +112,31 @@ impl ProviderManager { } else { // No database override, use defaults from AppConfig match name { - "openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None), - "gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None), - "deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None), - "grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None), - "ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None), + "openai" => ( + app_config.providers.openai.enabled, + Some(app_config.providers.openai.base_url.clone()), + None, + ), + "gemini" => ( + app_config.providers.gemini.enabled, + Some(app_config.providers.gemini.base_url.clone()), + None, + ), + "deepseek" => ( + app_config.providers.deepseek.enabled, + Some(app_config.providers.deepseek.base_url.clone()), + None, + ), + "grok" => ( + app_config.providers.grok.enabled, + Some(app_config.providers.grok.base_url.clone()), + None, + ), + "ollama" => ( + app_config.providers.ollama.enabled, + Some(app_config.providers.ollama.base_url.clone()), + None, + ), _ => (false, None, None), } }; @@ -118,7 +150,9 @@ impl ProviderManager { let provider: Arc = match name { "openai" => { let mut cfg = app_config.providers.openai.clone(); - if let Some(url) = base_url { cfg.base_url = url; } + if let Some(url) = base_url { + cfg.base_url = url; + } // Handle API key override if present let p = if let Some(key) = api_key { // We need a way to create a provider with an explicit key @@ -128,42 +162,50 @@ impl ProviderManager { OpenAIProvider::new(&cfg, app_config)? }; Arc::new(p) - }, + } "ollama" => { let mut cfg = app_config.providers.ollama.clone(); - if let Some(url) = base_url { cfg.base_url = url; } + if let Some(url) = base_url { + cfg.base_url = url; + } Arc::new(OllamaProvider::new(&cfg, app_config)?) - }, + } "gemini" => { let mut cfg = app_config.providers.gemini.clone(); - if let Some(url) = base_url { cfg.base_url = url; } + if let Some(url) = base_url { + cfg.base_url = url; + } let p = if let Some(key) = api_key { GeminiProvider::new_with_key(&cfg, app_config, key)? } else { GeminiProvider::new(&cfg, app_config)? }; Arc::new(p) - }, + } "deepseek" => { let mut cfg = app_config.providers.deepseek.clone(); - if let Some(url) = base_url { cfg.base_url = url; } + if let Some(url) = base_url { + cfg.base_url = url; + } let p = if let Some(key) = api_key { DeepSeekProvider::new_with_key(&cfg, app_config, key)? } else { DeepSeekProvider::new(&cfg, app_config)? }; Arc::new(p) - }, + } "grok" => { let mut cfg = app_config.providers.grok.clone(); - if let Some(url) = base_url { cfg.base_url = url; } + if let Some(url) = base_url { + cfg.base_url = url; + } let p = if let Some(key) = api_key { GrokProvider::new_with_key(&cfg, app_config, key)? } else { GrokProvider::new(&cfg, app_config)? }; Arc::new(p) - }, + } _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)), }; @@ -188,16 +230,12 @@ impl ProviderManager { pub async fn get_provider_for_model(&self, model: &str) -> Option> { let providers = self.providers.read().await; - providers.iter() - .find(|p| p.supports_model(model)) - .map(|p| Arc::clone(p)) + providers.iter().find(|p| p.supports_model(model)).map(Arc::clone) } pub async fn get_provider(&self, name: &str) -> Option> { let providers = self.providers.read().await; - providers.iter() - .find(|p| p.name() == name) - .map(|p| Arc::clone(p)) + providers.iter().find(|p| p.name() == name).map(Arc::clone) } pub async fn get_all_providers(&self) -> Vec> { @@ -238,22 +276,30 @@ pub mod placeholder { &self, _request: UnifiedRequest, ) -> Result>, AppError> { - Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string())) + Err(AppError::ProviderError( + "Streaming not supported for placeholder provider".to_string(), + )) } - async fn chat_completion( - &self, - _request: UnifiedRequest, - ) -> Result { - Err(AppError::ProviderError(format!("Provider {} not implemented", self.name))) + async fn chat_completion(&self, _request: UnifiedRequest) -> Result { + Err(AppError::ProviderError(format!( + "Provider {} not implemented", + self.name + ))) } fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result { Ok(0) } - fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 { + fn calculate_cost( + &self, + _model: &str, + _prompt_tokens: u32, + _completion_tokens: u32, + _registry: &crate::models::registry::ModelRegistry, + ) -> f64 { 0.0 } } -} \ No newline at end of file +} diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 41343ce3..850c6330 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -1,18 +1,14 @@ -use async_trait::async_trait; use anyhow::Result; -use futures::stream::{BoxStream, StreamExt}; -use serde_json::Value; +use async_trait::async_trait; +use futures::stream::BoxStream; -use crate::{ - models::UnifiedRequest, - errors::AppError, - config::AppConfig, -}; +use super::helpers; use super::{ProviderResponse, ProviderStreamChunk}; +use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; pub struct OllamaProvider { client: reqwest::Client, - _config: crate::config::OllamaConfig, + config: crate::config::OllamaConfig, pricing: Vec, } @@ -20,7 +16,7 @@ impl OllamaProvider { pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result { Ok(Self { client: reqwest::Client::new(), - _config: config.clone(), + config: config.clone(), pricing: app_config.pricing.ollama.clone(), }) } @@ -33,49 +29,29 @@ impl super::Provider for OllamaProvider { } fn supports_model(&self, model: &str) -> bool { - self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/") + self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/") } fn supports_multimodal(&self) -> bool { true } - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result { - let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); + async fn chat_completion(&self, mut request: UnifiedRequest) -> Result { + // Strip "ollama/" prefix if present for the API call + let api_model = request + .model + .strip_prefix("ollama/") + .unwrap_or(&request.model) + .to_string(); + let original_model = request.model.clone(); + request.model = api_model; - let mut body = serde_json::json!({ - "model": model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": false, - }); + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, false); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } - - let response = self.client.post(format!("{}/chat/completions", self._config.base_url)) + let response = self + .client + .post(format!("{}/chat/completions", self.config.base_url)) .json(&body) .send() .await @@ -86,120 +62,67 @@ impl super::Provider for OllamaProvider { return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text))); } - let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; - - let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; - let message = &choice["message"]; - - let content = message["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string()); - - let usage = &resp_json["usage"]; - let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; - let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; - let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; + let resp_json: serde_json::Value = response + .json() + .await + .map_err(|e| AppError::ProviderError(e.to_string()))?; - Ok(ProviderResponse { - content, - reasoning_content, - prompt_tokens, - completion_tokens, - total_tokens, - model: request.model, - }) + // Ollama also supports "thought" as an alias for reasoning_content + let mut result = helpers::parse_openai_response(&resp_json, original_model)?; + if result.reasoning_content.is_none() { + result.reasoning_content = resp_json["choices"] + .get(0) + .and_then(|c| c["message"]["thought"].as_str()) + .map(|s| s.to_string()); + } + Ok(result) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if let Some(cost) = &metadata.cost { - return (prompt_tokens as f64 * cost.input / 1_000_000.0) + - (completion_tokens as f64 * cost.output / 1_000_000.0); - } - } - - let (prompt_rate, completion_rate) = self.pricing.iter() - .find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((0.0, 0.0)); - - (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64 { + helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + registry, + &self.pricing, + 0.0, + 0.0, + ) } async fn chat_completion_stream( &self, - request: UnifiedRequest, + mut request: UnifiedRequest, ) -> Result>, AppError> { - let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); + let api_model = request + .model + .strip_prefix("ollama/") + .unwrap_or(&request.model) + .to_string(); + let original_model = request.model.clone(); + request.model = api_model; - let mut body = serde_json::json!({ - "model": model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }), - } - }).collect::>() - }) - }).collect::>(), - "stream": true, - }); + let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, true); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } + let es = reqwest_eventsource::EventSource::new( + self.client + .post(format!("{}/chat/completions", self.config.base_url)) + .json(&body), + ) + .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - // Create eventsource stream - use reqwest_eventsource::{EventSource, Event}; - let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url)) - .json(&body)) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let model_name = request.model.clone(); - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(choice) = chunk["choices"].get(0) { - let delta = &choice["delta"]; - let content = delta["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string()); - let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); - - yield ProviderStreamChunk { - content, - reasoning_content, - finish_reason, - model: model_name.clone(), - }; - } - } - Ok(_) => continue, - Err(e) => { - Err(AppError::ProviderError(format!("Stream error: {}", e)))?; - } - } - } - }; - - Ok(Box::pin(stream)) + // Ollama uses "thought" as an alternative field for reasoning content + Ok(helpers::create_openai_stream(es, original_model, Some("thought"))) } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 8da5c022..ca14e8f2 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,18 +1,14 @@ -use async_trait::async_trait; use anyhow::Result; -use futures::stream::{BoxStream, StreamExt}; -use serde_json::Value; +use async_trait::async_trait; +use futures::stream::BoxStream; -use crate::{ - models::UnifiedRequest, - errors::AppError, - config::AppConfig, -}; +use super::helpers; use super::{ProviderResponse, ProviderStreamChunk}; +use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; pub struct OpenAIProvider { client: reqwest::Client, - _config: crate::config::OpenAIConfig, + config: crate::config::OpenAIConfig, api_key: String, pricing: Vec, } @@ -26,7 +22,7 @@ impl OpenAIProvider { pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result { Ok(Self { client: reqwest::Client::new(), - _config: config.clone(), + config: config.clone(), api_key, pricing: app_config.pricing.openai.clone(), }) @@ -47,40 +43,13 @@ impl super::Provider for OpenAIProvider { true } - async fn chat_completion( - &self, - request: UnifiedRequest, - ) -> Result { - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": false, - }); + async fn chat_completion(&self, request: UnifiedRequest) -> Result { + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, false); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } - - let response = self.client.post(format!("{}/chat/completions", self._config.base_url)) + let response = self + .client + .post(format!("{}/chat/completions", self.config.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&body) .send() @@ -92,125 +61,51 @@ impl super::Provider for OpenAIProvider { return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text))); } - let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; - - let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; - let message = &choice["message"]; - - let content = message["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); - - let usage = &resp_json["usage"]; - let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; - let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; - let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; + let resp_json: serde_json::Value = response + .json() + .await + .map_err(|e| AppError::ProviderError(e.to_string()))?; - Ok(ProviderResponse { - content, - reasoning_content, - prompt_tokens, - completion_tokens, - total_tokens, - model: request.model, - }) + helpers::parse_openai_response(&resp_json, request.model) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } - fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if let Some(cost) = &metadata.cost { - return (prompt_tokens as f64 * cost.input / 1_000_000.0) + - (completion_tokens as f64 * cost.output / 1_000_000.0); - } - } - - let (prompt_rate, completion_rate) = self.pricing.iter() - .find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((0.15, 0.60)); - - (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) + fn calculate_cost( + &self, + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + registry: &crate::models::registry::ModelRegistry, + ) -> f64 { + helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + registry, + &self.pricing, + 0.15, + 0.60, + ) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { - let mut body = serde_json::json!({ - "model": request.model, - "messages": request.messages.iter().map(|m| { - serde_json::json!({ - "role": m.role, - "content": m.content.iter().map(|p| { - match p { - crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); - serde_json::json!({ - "type": "image_url", - "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } - }) - } - } - }).collect::>() - }) - }).collect::>(), - "stream": true, - }); + let messages_json = helpers::messages_to_openai_json(&request.messages).await?; + let body = helpers::build_openai_body(&request, messages_json, true); - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = serde_json::json!(max_tokens); - } + let es = reqwest_eventsource::EventSource::new( + self.client + .post(format!("{}/chat/completions", self.config.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&body), + ) + .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - // Create eventsource stream - use reqwest_eventsource::{EventSource, Event}; - let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body)) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let model = request.model.clone(); - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(choice) = chunk["choices"].get(0) { - let delta = &choice["delta"]; - let content = delta["content"].as_str().unwrap_or_default().to_string(); - let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string()); - let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); - - yield ProviderStreamChunk { - content, - reasoning_content, - finish_reason, - model: model.clone(), - }; - } - } - Ok(_) => continue, - Err(e) => { - Err(AppError::ProviderError(format!("Stream error: {}", e)))?; - } - } - } - }; - - Ok(Box::pin(stream)) + Ok(helpers::create_openai_stream(es, request.model, None)) } } diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs index ef9929b4..2b40e672 100644 --- a/src/rate_limiting/mod.rs +++ b/src/rate_limiting/mod.rs @@ -5,12 +5,12 @@ //! 2. Provider circuit breaking to handle API failures //! 3. Global rate limiting for overall system protection -use std::sync::Arc; +use anyhow::Result; use std::collections::HashMap; +use std::sync::Arc; use std::time::Instant; use tokio::sync::RwLock; use tracing::{info, warn}; -use anyhow::Result; /// Rate limiter configuration #[derive(Debug, Clone)] @@ -26,8 +26,8 @@ pub struct RateLimiterConfig { impl Default for RateLimiterConfig { fn default() -> Self { Self { - requests_per_minute: 60, // 1 request per second per client - burst_size: 10, // Allow bursts of up to 10 requests + requests_per_minute: 60, // 1 request per second per client + burst_size: 10, // Allow bursts of up to 10 requests global_requests_per_minute: 600, // 10 requests per second globally } } @@ -36,9 +36,9 @@ impl Default for RateLimiterConfig { /// Circuit breaker state #[derive(Debug, Clone, Copy, PartialEq)] pub enum CircuitState { - Closed, // Normal operation - Open, // Circuit is open, requests fail fast - HalfOpen, // Testing if service has recovered + Closed, // Normal operation + Open, // Circuit is open, requests fail fast + HalfOpen, // Testing if service has recovered } /// Circuit breaker configuration @@ -57,10 +57,10 @@ pub struct CircuitBreakerConfig { impl Default for CircuitBreakerConfig { fn default() -> Self { Self { - failure_threshold: 5, // 5 failures - failure_window_secs: 60, // within 60 seconds - reset_timeout_secs: 30, // wait 30 seconds before half-open - success_threshold: 3, // 3 successes to close circuit + failure_threshold: 5, // 5 failures + failure_window_secs: 60, // within 60 seconds + reset_timeout_secs: 30, // wait 30 seconds before half-open + success_threshold: 3, // 3 successes to close circuit } } } @@ -88,14 +88,14 @@ impl TokenBucket { let now = Instant::now(); let elapsed = now.duration_since(self.last_refill).as_secs_f64(); let new_tokens = elapsed * self.refill_rate; - + self.tokens = (self.tokens + new_tokens).min(self.capacity); self.last_refill = now; } fn try_acquire(&mut self, tokens: f64) -> bool { self.refill(); - + if self.tokens >= tokens { self.tokens -= tokens; true @@ -175,18 +175,18 @@ impl ProviderCircuitBreaker { /// Record a failed request pub fn record_failure(&mut self) { let now = std::time::Instant::now(); - + // Check if failure window has expired - if let Some(last_failure) = self.last_failure_time { - if now.duration_since(last_failure).as_secs() > self.config.failure_window_secs { - // Reset failure count if window expired - self.failure_count = 0; - } + if let Some(last_failure) = self.last_failure_time + && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs + { + // Reset failure count if window expired + self.failure_count = 0; } - + self.failure_count += 1; self.last_failure_time = Some(now); - + if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed { self.state = CircuitState::Open; self.last_state_change = now; @@ -220,7 +220,7 @@ impl RateLimitManager { pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self { // Convert requests per minute to tokens per second let global_refill_rate = config.global_requests_per_minute as f64 / 60.0; - + Self { client_buckets: Arc::new(RwLock::new(HashMap::new())), global_bucket: Arc::new(RwLock::new(TokenBucket::new( @@ -243,18 +243,16 @@ impl RateLimitManager { return Ok(false); } } - + // Check client-specific rate limit let mut buckets = self.client_buckets.write().await; - let bucket = buckets - .entry(client_id.to_string()) - .or_insert_with(|| { - TokenBucket::new( - self.config.burst_size as f64, - self.config.requests_per_minute as f64 / 60.0, - ) - }); - + let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| { + TokenBucket::new( + self.config.burst_size as f64, + self.config.requests_per_minute as f64 / 60.0, + ) + }); + Ok(bucket.try_acquire(1.0)) } @@ -264,7 +262,7 @@ impl RateLimitManager { let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); - + Ok(breaker.allow_request()) } @@ -282,7 +280,7 @@ impl RateLimitManager { let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); - + breaker.record_failure(); } @@ -299,14 +297,13 @@ impl RateLimitManager { /// Axum middleware for rate limiting pub mod middleware { use super::*; + use crate::errors::AppError; + use crate::state::AppState; use axum::{ extract::{Request, State}, middleware::Next, response::Response, }; - use crate::errors::AppError; - use crate::state::AppState; - /// Rate limiting middleware pub async fn rate_limit_middleware( @@ -319,41 +316,35 @@ pub mod middleware { // Check rate limits if !state.rate_limit_manager.check_client_request(&client_id).await? { - return Err(AppError::RateLimitError( - "Rate limit exceeded".to_string() - )); + return Err(AppError::RateLimitError("Rate limit exceeded".to_string())); } Ok(next.run(request).await) } - + /// Extract client ID from request (helper function) fn extract_client_id_from_request(request: &Request) -> String { // Try to extract from Authorization header - if let Some(auth_header) = request.headers().get("Authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - let token = &auth_str[7..]; - // Use token hash as client ID (same logic as auth module) - return format!("client_{}", &token[..8.min(token.len())]); - } - } + if let Some(auth_header) = request.headers().get("Authorization") + && let Ok(auth_str) = auth_header.to_str() + && let Some(token) = auth_str.strip_prefix("Bearer ") + { + // Use token hash as client ID (same logic as auth module) + return format!("client_{}", &token[..8.min(token.len())]); } - + // Fallback to anonymous "anonymous".to_string() } /// Circuit breaker middleware for provider requests - pub async fn circuit_breaker_middleware( - provider_name: &str, - state: &AppState, - ) -> Result<(), AppError> { + pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> { if !state.rate_limit_manager.check_provider_request(provider_name).await? { - return Err(AppError::ProviderError( - format!("Provider {} is currently unavailable (circuit breaker open)", provider_name) - )); + return Err(AppError::ProviderError(format!( + "Provider {} is currently unavailable (circuit breaker open)", + provider_name + ))); } Ok(()) } -} \ No newline at end of file +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 94a0fa33..fe3a3888 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,22 +1,25 @@ -use std::sync::Arc; -use sqlx::Row; -use uuid::Uuid; use axum::{ - extract::State, - routing::post, Json, Router, - response::sse::{Event, Sse}, + extract::State, response::IntoResponse, + response::sse::{Event, Sse}, + routing::post, }; use futures::stream::StreamExt; +use sqlx::Row; +use std::sync::Arc; use tracing::{info, warn}; +use uuid::Uuid; use crate::{ auth::AuthenticatedClient, errors::AppError, - models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, ChatStreamDelta, ChatMessage, ChatChoice, Usage}, - state::AppState, + models::{ + ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, + ChatStreamChoice, ChatStreamDelta, Usage, + }, rate_limiting, + state::AppState, }; pub fn router(state: AppState) -> Router { @@ -65,13 +68,13 @@ async fn chat_completions( if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) { return Err(AppError::AuthError("Invalid authentication token".to_string())); } - + let start_time = std::time::Instant::now(); let client_id = auth.client_id.clone(); let model = request.model.clone(); - + info!("Chat completion request from client {} for model {}", client_id, model); - + // Check if model is enabled in database and get potential mapping let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?") .bind(&model) @@ -85,7 +88,10 @@ async fn chat_completions( }; if !model_enabled { - return Err(AppError::ValidationError(format!("Model {} is currently disabled", model))); + return Err(AppError::ValidationError(format!( + "Model {} is currently disabled", + model + ))); } // Apply mapping if present @@ -95,53 +101,61 @@ async fn chat_completions( } // Find appropriate provider for the model - let provider = state.provider_manager.get_provider_for_model(&request.model).await + let provider = state + .provider_manager + .get_provider_for_model(&request.model) + .await .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?; - + let provider_name = provider.name().to_string(); - + // Check circuit breaker for this provider rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?; // Convert to unified request format - let mut unified_request = crate::models::UnifiedRequest::try_from(request) - .map_err(|e| AppError::ValidationError(e.to_string()))?; - + let mut unified_request = + crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?; + // Set client_id from authentication unified_request.client_id = client_id.clone(); // Hydrate images if present if unified_request.has_images { - unified_request.hydrate_images().await + unified_request + .hydrate_images() + .await .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?; } + let has_images = unified_request.has_images; + // Check if streaming is requested if unified_request.stream { // Estimate prompt tokens for logging later let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request); - let has_images = unified_request.has_images; // Handle streaming response let stream_result = provider.chat_completion_stream(unified_request).await; - + match stream_result { Ok(stream) => { // Record provider success state.rate_limit_manager.record_provider_success(&provider_name).await; - + // Wrap with AggregatingStream for token counting and database logging let aggregating_stream = crate::utils::streaming::AggregatingStream::new( stream, - client_id.clone(), - provider.clone(), - model.clone(), - prompt_tokens, - has_images, - state.request_logger.clone(), - state.client_manager.clone(), - state.model_registry.clone(), - state.db_pool.clone(), + crate::utils::streaming::StreamConfig { + client_id: client_id.clone(), + provider: provider.clone(), + model: model.clone(), + prompt_tokens, + has_images, + logger: state.request_logger.clone(), + client_manager: state.client_manager.clone(), + model_registry: state.model_registry.clone(), + db_pool: state.db_pool.clone(), + }, ); // Create SSE stream from aggregating stream @@ -164,8 +178,14 @@ async fn chat_completions( finish_reason: chunk.finish_reason, }], }; - - Ok(Event::default().json_data(response).unwrap()) + + match Event::default().json_data(response) { + Ok(event) => Ok(event), + Err(e) => { + warn!("Failed to serialize SSE event: {}", e); + Err(AppError::InternalError("SSE serialization failed".to_string())) + } + } } Err(e) => { warn!("Error in streaming response: {}", e); @@ -173,17 +193,17 @@ async fn chat_completions( } } }); - + Ok(Sse::new(sse_stream).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; - + // Log failed request let duration = start_time.elapsed(); warn!("Streaming request failed after {:?}: {}", duration, e); - + Err(e) } } @@ -193,12 +213,19 @@ async fn chat_completions( match result { Ok(response) => { - // Record provider success - state.rate_limit_manager.record_provider_success(&provider_name).await; - - let duration = start_time.elapsed(); - let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await; - // Log request to database + // Record provider success + state.rate_limit_manager.record_provider_success(&provider_name).await; + + let duration = start_time.elapsed(); + let cost = get_model_cost( + &response.model, + response.prompt_tokens, + response.completion_tokens, + &provider, + &state, + ) + .await; + // Log request to database state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), @@ -208,18 +235,17 @@ async fn chat_completions( completion_tokens: response.completion_tokens, total_tokens: response.total_tokens, cost, - has_images: false, // TODO: check images + has_images, status: "success".to_string(), error_message: None, duration_ms: duration.as_millis() as u64, }); // Update client usage - let _ = state.client_manager.update_client_usage( - &client_id, - response.total_tokens as i64, - cost, - ).await; + let _ = state + .client_manager + .update_client_usage(&client_id, response.total_tokens as i64, cost) + .await; // Convert ProviderResponse to ChatCompletionResponse let chat_response = ChatCompletionResponse { @@ -231,8 +257,8 @@ async fn chat_completions( index: 0, message: ChatMessage { role: "assistant".to_string(), - content: crate::models::MessageContent::Text { - content: response.content + content: crate::models::MessageContent::Text { + content: response.content, }, reasoning_content: response.reasoning_content, }, @@ -244,16 +270,16 @@ async fn chat_completions( total_tokens: response.total_tokens, }), }; - + // Log successful request info!("Request completed successfully in {:?}", duration); - + Ok(Json(chat_response).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; - + // Log failed request to database let duration = start_time.elapsed(); state.request_logger.log_request(crate::logging::RequestLog { @@ -272,7 +298,7 @@ async fn chat_completions( }); warn!("Request failed after {:?}: {}", duration, e); - + Err(e) } } diff --git a/src/state/mod.rs b/src/state/mod.rs index a2c6a175..3412c6a4 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -2,9 +2,8 @@ use std::sync::Arc; use tokio::sync::broadcast; use crate::{ - client::ClientManager, database::DbPool, providers::ProviderManager, - rate_limiting::RateLimitManager, logging::RequestLogger, - models::registry::ModelRegistry, config::AppConfig, + client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger, + models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager, }; /// Shared application state diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7ad86e8d..0f0d1351 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,3 @@ -pub mod tokens; pub mod registry; pub mod streaming; +pub mod tokens; diff --git a/src/utils/registry.rs b/src/utils/registry.rs index 80d855fe..ca7f8d52 100644 --- a/src/utils/registry.rs +++ b/src/utils/registry.rs @@ -1,24 +1,24 @@ +use crate::models::registry::ModelRegistry; use anyhow::Result; use tracing::info; -use crate::models::registry::ModelRegistry; const MODELS_DEV_URL: &str = "https://models.dev/api.json"; pub async fn fetch_registry() -> Result { info!("Fetching model registry from {}", MODELS_DEV_URL); - + let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(10)) .build()?; - + let response = client.get(MODELS_DEV_URL).send().await?; - + if !response.status().is_success() { return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status())); } - + let registry: ModelRegistry = response.json().await?; info!("Successfully loaded model registry"); - + Ok(registry) } diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs index b97e58ee..4f557a76 100644 --- a/src/utils/streaming.rs +++ b/src/utils/streaming.rs @@ -1,13 +1,26 @@ -use futures::stream::Stream; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::sync::Arc; -use sqlx::Row; -use crate::logging::{RequestLogger, RequestLog}; use crate::client::ClientManager; -use crate::providers::{Provider, ProviderStreamChunk}; use crate::errors::AppError; +use crate::logging::{RequestLog, RequestLogger}; +use crate::providers::{Provider, ProviderStreamChunk}; use crate::utils::tokens::estimate_completion_tokens; +use futures::stream::Stream; +use sqlx::Row; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +/// Configuration for creating an AggregatingStream. +pub struct StreamConfig { + pub client_id: String, + pub provider: Arc, + pub model: String, + pub prompt_tokens: u32, + pub has_images: bool, + pub logger: Arc, + pub client_manager: Arc, + pub model_registry: Arc, + pub db_pool: crate::database::DbPool, +} pub struct AggregatingStream { inner: S, @@ -26,35 +39,24 @@ pub struct AggregatingStream { has_logged: bool, } -impl AggregatingStream -where - S: Stream> + Unpin +impl AggregatingStream +where + S: Stream> + Unpin, { - pub fn new( - inner: S, - client_id: String, - provider: Arc, - model: String, - prompt_tokens: u32, - has_images: bool, - logger: Arc, - client_manager: Arc, - model_registry: Arc, - db_pool: crate::database::DbPool, - ) -> Self { + pub fn new(inner: S, config: StreamConfig) -> Self { Self { inner, - client_id, - provider, - model, - prompt_tokens, - has_images, + client_id: config.client_id, + provider: config.provider, + model: config.model, + prompt_tokens: config.prompt_tokens, + has_images: config.has_images, accumulated_content: String::new(), accumulated_reasoning: String::new(), - logger, - client_manager, - model_registry, - db_pool, + logger: config.logger, + client_manager: config.client_manager, + model_registry: config.model_registry, + db_pool: config.db_pool, start_time: std::time::Instant::now(), has_logged: false, } @@ -77,7 +79,7 @@ where let has_images = self.has_images; let registry = self.model_registry.clone(); let pool = self.db_pool.clone(); - + // Estimate completion tokens (including reasoning if present) let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model); let reasoning_tokens = if !self.accumulated_reasoning.is_empty() { @@ -85,18 +87,19 @@ where } else { 0 }; - + let completion_tokens = content_tokens + reasoning_tokens; let total_tokens = prompt_tokens + completion_tokens; // Spawn a background task to log the completion tokio::spawn(async move { // Check database for cost overrides - let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") - .bind(&model) - .fetch_optional(&pool) - .await - .unwrap_or(None); + let db_cost = + sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") + .bind(&model) + .fetch_optional(&pool) + .await + .unwrap_or(None); let cost = if let Some(row) = db_cost { let prompt_rate = row.get::, _>("prompt_cost_per_m"); @@ -128,24 +131,22 @@ where }); // Update client usage - let _ = client_manager.update_client_usage( - &client_id, - total_tokens as i64, - cost, - ).await; + let _ = client_manager + .update_client_usage(&client_id, total_tokens as i64, cost) + .await; }); } } impl Stream for AggregatingStream where - S: Stream> + Unpin + S: Stream> + Unpin, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let result = Pin::new(&mut self.inner).poll_next(cx); - + match &result { Poll::Ready(Some(Ok(chunk))) => { self.accumulated_content.push_str(&chunk.content); @@ -165,7 +166,7 @@ where } Poll::Pending => {} } - + result } } @@ -173,52 +174,87 @@ where #[cfg(test)] mod tests { use super::*; - use futures::stream::{self, StreamExt}; use anyhow::Result; - + use futures::stream::{self, StreamExt}; + // Simple mock provider for testing struct MockProvider; #[async_trait::async_trait] impl Provider for MockProvider { - fn name(&self) -> &str { "mock" } - fn supports_model(&self, _model: &str) -> bool { true } - fn supports_multimodal(&self) -> bool { false } - async fn chat_completion(&self, _req: crate::models::UnifiedRequest) -> Result { unimplemented!() } - async fn chat_completion_stream(&self, _req: crate::models::UnifiedRequest) -> Result>, AppError> { unimplemented!() } - fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result { Ok(10) } - fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { 0.05 } + fn name(&self) -> &str { + "mock" + } + fn supports_model(&self, _model: &str) -> bool { + true + } + fn supports_multimodal(&self) -> bool { + false + } + async fn chat_completion( + &self, + _req: crate::models::UnifiedRequest, + ) -> Result { + unimplemented!() + } + async fn chat_completion_stream( + &self, + _req: crate::models::UnifiedRequest, + ) -> Result>, AppError> { + unimplemented!() + } + fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result { + Ok(10) + } + fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { + 0.05 + } } #[tokio::test] async fn test_aggregating_stream() { let chunks = vec![ - Ok(ProviderStreamChunk { content: "Hello".to_string(), finish_reason: None, model: "test".to_string() }), - Ok(ProviderStreamChunk { content: " World".to_string(), finish_reason: Some("stop".to_string()), model: "test".to_string() }), + Ok(ProviderStreamChunk { + content: "Hello".to_string(), + reasoning_content: None, + finish_reason: None, + model: "test".to_string(), + }), + Ok(ProviderStreamChunk { + content: " World".to_string(), + reasoning_content: None, + finish_reason: Some("stop".to_string()), + model: "test".to_string(), + }), ]; let inner_stream = stream::iter(chunks); - + let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); - let logger = Arc::new(RequestLogger::new(pool.clone())); + let (dashboard_tx, _) = tokio::sync::broadcast::channel(16); + let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx)); let client_manager = Arc::new(ClientManager::new(pool.clone())); - let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new() }); - + let registry = Arc::new(crate::models::registry::ModelRegistry { + providers: std::collections::HashMap::new(), + }); + let mut agg_stream = AggregatingStream::new( inner_stream, - "client_1".to_string(), - Arc::new(MockProvider), - "test".to_string(), - 10, - false, - logger, - client_manager, - registry, - pool.clone(), + StreamConfig { + client_id: "client_1".to_string(), + provider: Arc::new(MockProvider), + model: "test".to_string(), + prompt_tokens: 10, + has_images: false, + logger, + client_manager, + model_registry: registry, + db_pool: pool.clone(), + }, ); - + while let Some(item) = agg_stream.next().await { assert!(item.is_ok()); } - + assert_eq!(agg_stream.accumulated_content, "Hello World"); assert!(agg_stream.has_logged); } diff --git a/src/utils/tokens.rs b/src/utils/tokens.rs index ad4b5887..dbf020d2 100644 --- a/src/utils/tokens.rs +++ b/src/utils/tokens.rs @@ -1,27 +1,26 @@ -use tiktoken_rs::get_bpe_from_model; use crate::models::UnifiedRequest; +use tiktoken_rs::get_bpe_from_model; /// Count tokens for a given model and text pub fn count_tokens(model: &str, text: &str) -> u32 { // If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1) - let bpe = get_bpe_from_model(model).unwrap_or_else(|_| { - tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding") - }); - + let bpe = get_bpe_from_model(model) + .unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding")); + bpe.encode_with_special_tokens(text).len() as u32 } /// Estimate tokens for a unified request pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 { let mut total_tokens = 0; - + // Base tokens per message for OpenAI (approximate) let tokens_per_message = 3; let _tokens_per_name = 1; - + for msg in &request.messages { total_tokens += tokens_per_message; - + for part in &msg.content { match part { crate::models::ContentPart::Text { text } => { @@ -34,14 +33,14 @@ pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 { } } } - + // Add name tokens if we had names (we don't in UnifiedMessage yet) // total_tokens += tokens_per_name; } - + // Add 3 tokens for the assistant reply header total_tokens += 3; - + total_tokens } diff --git a/test_dashboard.sh b/test_dashboard.sh index 11f06ba5..5a6afa94 100755 --- a/test_dashboard.sh +++ b/test_dashboard.sh @@ -30,7 +30,7 @@ curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON r echo "" echo "Dashboard should be available at: http://localhost:8080" -echo "Default login: admin / admin123" +echo "Default login: admin / admin" echo "" echo "Press Ctrl+C to stop the server" diff --git a/tests/integration_tests.rs.bak b/tests/integration_tests.rs.bak deleted file mode 100644 index b5c85745..00000000 --- a/tests/integration_tests.rs.bak +++ /dev/null @@ -1,188 +0,0 @@ -// Integration tests for LLM Proxy Gateway - -use llm_proxy::config::Config; -use llm_proxy::database::Database; -use llm_proxy::state::AppState; -use llm_proxy::rate_limiting::RateLimitManager; -use tempfile::TempDir; -use std::fs; - -#[tokio::test] -async fn test_config_loading() { - // Create a temporary config file - let temp_dir = TempDir::new().unwrap(); - let config_path = temp_dir.path().join("config.toml"); - - let config_content = r#" -[server] -port = 8080 -host = "0.0.0.0" - -[database] -path = "./data/test.db" -max_connections = 5 - -[providers.openai] -enabled = true -base_url = "https://api.openai.com/v1" - -[providers.gemini] -enabled = true -base_url = "https://generativelanguage.googleapis.com/v1" - -[providers.deepseek] -enabled = true -base_url = "https://api.deepseek.com" - -[providers.grok] -enabled = false -base_url = "https://api.x.ai/v1" - -[model_mapping] -"gpt-*" = "openai" -"gemini-*" = "gemini" -"deepseek-*" = "deepseek" -"grok-*" = "grok" - -[pricing] -openai = { input = 0.01, output = 0.03 } -gemini = { input = 0.0005, output = 0.0015 } -deepseek = { input = 0.00014, output = 0.00028 } -grok = { input = 0.001, output = 0.003 } -"#; - - fs::write(&config_path, config_content).unwrap(); - - // Test loading config - let config = Config::load_from_path(&config_path); - assert!(config.is_ok()); - - let config = config.unwrap(); - assert_eq!(config.server.port, 8080); - assert!(config.providers.openai.is_some()); - assert!(config.providers.grok.is_none()); -} - -#[tokio::test] -async fn test_database_initialization() { - // Create a temporary database file - let temp_dir = TempDir::new().unwrap(); - let db_path = temp_dir.path().join("test.db"); - - // Test database initialization - let database = Database::new(&db_path).await; - assert!(database.is_ok()); - - let database = database.unwrap(); - - // Test connection - let test_result = database.test_connection().await; - assert!(test_result.is_ok()); -} - -#[tokio::test] -async fn test_provider_manager() { - // Create a provider manager - use llm_proxy::providers::{ProviderManager, Provider}; - use llm_proxy::config::OpenAIConfig; - - let mut manager = ProviderManager::new(); - assert_eq!(manager.providers.len(), 0); - - // Test adding providers (we can't actually add real providers without API keys) - // This test just verifies the manager structure works - assert!(manager.get_provider_for_model("gpt-4").is_none()); - assert!(manager.get_provider("openai").is_none()); -} - -#[tokio::test] -async fn test_rate_limit_manager() { - let manager = RateLimitManager::new(60, 10); - - // Test client rate limiting - let allowed = manager.check_request("test-client").await; - assert!(allowed); // First request should be allowed - - // Test provider circuit breaker - let allowed = manager.check_provider("openai").await; - assert!(allowed); // Circuit should be closed initially - - // Record some failures - manager.record_provider_failure("openai").await; - manager.record_provider_failure("openai").await; - manager.record_provider_failure("openai").await; - manager.record_provider_failure("openai").await; - manager.record_provider_failure("openai").await; - - // After 5 failures, circuit should be open - let allowed = manager.check_provider("openai").await; - assert!(!allowed); // Circuit should be open - - // Record success to close circuit - manager.record_provider_success("openai").await; - manager.record_provider_success("openai").await; - manager.record_provider_success("openai").await; - - // After 3 successes in half-open state, circuit should be closed - let allowed = manager.check_provider("openai").await; - assert!(allowed); // Circuit should be closed again -} - -#[tokio::test] -async fn test_app_state_creation() { - // Create a temporary database - let temp_dir = TempDir::new().unwrap(); - let db_path = temp_dir.path().join("test.db"); - - let database = Database::new(&db_path).await.unwrap(); - - // Test AppState creation using test utilities - use llm_proxy::test_utils::create_test_state; - let state = create_test_state().await; - - // Verify state components are initialized - assert!(state.database.test_connection().await.is_ok()); -} - -#[tokio::test] -async fn test_multimodal_image_converter() { - use llm_proxy::multimodal::{ImageConverter, ImageInput}; - - // Test model detection - assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview")); - assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision")); - assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo")); - assert!(!ImageConverter::model_supports_multimodal("gemini-pro")); - - // Test data URL parsing (utility function) - let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; - let parts: Vec<&str> = test_url[5..].split(";base64,").collect(); - assert_eq!(parts.len(), 2); - assert_eq!(parts[0], "image/jpeg"); - assert_eq!(parts[1], "SGVsbG8gV29ybGQ="); -} - -#[tokio::test] -async fn test_error_conversions() { - use llm_proxy::errors::AppError; - use anyhow::anyhow; - - // Test anyhow error conversion - let anyhow_error = anyhow!("Test error"); - let app_error: AppError = anyhow_error.into(); - - match app_error { - AppError::InternalError(msg) => assert_eq!(msg, "Test error"), - _ => panic!("Expected InternalError"), - } - - // Test sqlx error conversion - use sqlx::Error as SqlxError; - let sqlx_error = SqlxError::PoolClosed; - let app_error: AppError = sqlx_error.into(); - - match app_error { - AppError::DatabaseError(msg) => assert!(msg.contains("pool closed")), - _ => panic!("Expected DatabaseError"), - } -} \ No newline at end of file diff --git a/tests/streaming_test.rs b/tests/streaming_test.rs deleted file mode 100644 index e69de29b..00000000