diff --git a/migrations/001-add-billing-mode.sql b/migrations/001-add-billing-mode.sql new file mode 100644 index 00000000..4a323305 --- /dev/null +++ b/migrations/001-add-billing-mode.sql @@ -0,0 +1,13 @@ +-- Migration: add billing_mode to provider_configs +-- Adds a billing_mode TEXT column with default 'prepaid' +-- After applying, set Gemini to postpaid with: +-- UPDATE provider_configs SET billing_mode = 'postpaid' WHERE id = 'gemini'; + +BEGIN TRANSACTION; + +ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT DEFAULT 'prepaid'; + +COMMIT; + +-- NOTE: If you use a production SQLite file, run the following to set Gemini to postpaid: +-- sqlite3 /path/to/db.sqlite "UPDATE provider_configs SET billing_mode='postpaid' WHERE id='gemini';" diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs index 3f176e6d..c5c3ce24 100644 --- a/src/dashboard/providers.rs +++ b/src/dashboard/providers.rs @@ -17,6 +17,7 @@ pub(super) struct UpdateProviderRequest { pub(super) api_key: Option, pub(super) credit_balance: Option, pub(super) low_credit_threshold: Option, + pub(super) billing_mode: Option, } pub(super) async fn handle_get_providers(State(state): State) -> Json> { @@ -24,11 +25,12 @@ pub(super) async fn handle_get_providers(State(state): State) -> let config = &state.app_state.config; let pool = &state.app_state.db_pool; - // Load all overrides from database - let db_configs_result = - sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs") - .fetch_all(pool) - .await; + // Load all overrides from database (including billing_mode) + let db_configs_result = sqlx::query( + "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs", + ) + .fetch_all(pool) + .await; let mut db_configs = HashMap::new(); if let Ok(rows) = db_configs_result { @@ -38,7 +40,8 @@ pub(super) async fn handle_get_providers(State(state): State) -> let base_url: Option = row.get("base_url"); let balance: f64 = row.get("credit_balance"); let threshold: f64 = row.get("low_credit_threshold"); - db_configs.insert(id, (enabled, base_url, balance, threshold)); + let billing_mode: Option = row.get("billing_mode"); + db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode)); } } @@ -80,15 +83,17 @@ pub(super) async fn handle_get_providers(State(state): State) -> let mut balance = 0.0; let mut threshold = 5.0; + let mut billing_mode: Option = None; // Apply database overrides - if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) { + if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) { enabled = *db_enabled; if let Some(url) = db_url { base_url = url.clone(); } balance = *db_balance; threshold = *db_threshold; + billing_mode = db_billing.clone(); } // Find models for this provider in registry @@ -138,6 +143,7 @@ pub(super) async fn handle_get_providers(State(state): State) -> "base_url": base_url, "credit_balance": balance, "low_credit_threshold": threshold, + "billing_mode": billing_mode, "last_used": None::, })); } @@ -185,10 +191,11 @@ pub(super) async fn handle_get_provider( let mut balance = 0.0; let mut threshold = 5.0; + let mut billing_mode: Option = None; // Apply database overrides let db_config = sqlx::query( - "SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?", + "SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?", ) .bind(&name) .fetch_optional(pool) @@ -201,6 +208,7 @@ pub(super) async fn handle_get_provider( } balance = row.get::("credit_balance"); threshold = row.get::("low_credit_threshold"); + billing_mode = row.get::, _>("billing_mode"); } // Find models for this provider @@ -246,6 +254,7 @@ pub(super) async fn handle_get_provider( "base_url": base_url, "credit_balance": balance, "low_credit_threshold": threshold, + "billing_mode": billing_mode, "last_used": None::, }))) } @@ -262,19 +271,20 @@ pub(super) async fn handle_update_provider( let pool = &state.app_state.db_pool; - // Update or insert into database + // Update or insert into database (include billing_mode) let result = sqlx::query( r#" - INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET enabled = excluded.enabled, base_url = excluded.base_url, api_key = COALESCE(excluded.api_key, provider_configs.api_key), credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), + billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), updated_at = CURRENT_TIMESTAMP - "# + "#, ) .bind(&name) .bind(name.to_uppercase()) @@ -283,6 +293,7 @@ pub(super) async fn handle_update_provider( .bind(&payload.api_key) .bind(payload.credit_balance) .bind(payload.low_credit_threshold) + .bind(payload.billing_mode) .execute(pool) .await; diff --git a/src/logging/mod.rs b/src/logging/mod.rs index feb74594..90c22458 100644 --- a/src/logging/mod.rs +++ b/src/logging/mod.rs @@ -100,13 +100,18 @@ impl RequestLogger { .execute(&mut *tx) .await?; - // Deduct from provider balance if successful (skip postpaid like Gemini) - if log.cost > 0.0 && log.provider != "gemini" { - sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ?") - .bind(log.cost) - .bind(&log.provider) - .execute(&mut *tx) - .await?; + // Deduct from provider balance if successful. + // Providers configured with billing_mode = 'postpaid' will not have their + // credit_balance decremented. Use a conditional UPDATE so we don't need + // a prior SELECT and avoid extra round-trips. + if log.cost > 0.0 { + sqlx::query( + "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')", + ) + .bind(log.cost) + .bind(&log.provider) + .execute(&mut *tx) + .await?; } tx.commit().await?; diff --git a/static/js/pages/providers.js b/static/js/pages/providers.js index 19bac5ea..41bd7996 100644 --- a/static/js/pages/providers.js +++ b/static/js/pages/providers.js @@ -70,6 +70,10 @@ class ProvidersPage { Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)} ${isLowBalance ? '' : ''} +
+ + Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'} +
Last used: ${provider.last_used ? window.api.formatTimeAgo(provider.last_used) : 'Never'} @@ -163,16 +167,23 @@ class ProvidersPage {
-
-
- - +
+
+ + +
+
+ + +
- - + +
-