224 lines
7.1 KiB
Rust
224 lines
7.1 KiB
Rust
use chrono::{DateTime, Utc};
|
|
use serde::Serialize;
|
|
use sqlx::SqlitePool;
|
|
use tokio::sync::broadcast;
|
|
use tracing::{info, warn};
|
|
|
|
use crate::errors::AppError;
|
|
|
|
/// Request log entry for database storage
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct RequestLog {
|
|
pub timestamp: DateTime<Utc>,
|
|
pub client_id: String,
|
|
pub provider: String,
|
|
pub model: String,
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
pub cache_read_tokens: u32,
|
|
pub cache_write_tokens: u32,
|
|
pub cost: f64,
|
|
pub has_images: bool,
|
|
pub status: String, // "success", "error"
|
|
pub error_message: Option<String>,
|
|
pub duration_ms: u64,
|
|
}
|
|
|
|
/// Database operations for request logging
|
|
pub struct RequestLogger {
|
|
db_pool: SqlitePool,
|
|
dashboard_tx: broadcast::Sender<serde_json::Value>,
|
|
}
|
|
|
|
impl RequestLogger {
|
|
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
|
|
Self { db_pool, dashboard_tx }
|
|
}
|
|
|
|
/// Log a request to the database (async, spawns a task)
|
|
pub fn log_request(&self, log: RequestLog) {
|
|
let pool = self.db_pool.clone();
|
|
let tx = self.dashboard_tx.clone();
|
|
|
|
// Spawn async task to log without blocking response
|
|
tokio::spawn(async move {
|
|
// Broadcast to dashboard
|
|
let broadcast_result = tx.send(serde_json::json!({
|
|
"type": "request",
|
|
"channel": "requests",
|
|
"payload": log
|
|
}));
|
|
match broadcast_result {
|
|
Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers),
|
|
Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open
|
|
}
|
|
|
|
match Self::insert_log(&pool, log).await {
|
|
Ok(()) => info!("Request logged to database successfully"),
|
|
Err(e) => warn!("Failed to log request to database: {}", e),
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Insert a log entry into the database
|
|
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
|
|
let mut tx = pool.begin().await?;
|
|
|
|
// Ensure the client row exists (FK constraint requires it)
|
|
sqlx::query(
|
|
"INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
|
|
)
|
|
.bind(&log.client_id)
|
|
.bind(&log.client_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO llm_requests
|
|
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
"#,
|
|
)
|
|
.bind(log.timestamp)
|
|
.bind(log.client_id)
|
|
.bind(&log.provider)
|
|
.bind(log.model)
|
|
.bind(log.prompt_tokens as i64)
|
|
.bind(log.completion_tokens as i64)
|
|
.bind(log.total_tokens as i64)
|
|
.bind(log.cache_read_tokens as i64)
|
|
.bind(log.cache_write_tokens as i64)
|
|
.bind(log.cost)
|
|
.bind(log.has_images)
|
|
.bind(log.status)
|
|
.bind(log.error_message)
|
|
.bind(log.duration_ms as i64)
|
|
.bind(None::<String>) // request_body - optional, not stored to save disk space
|
|
.bind(None::<String>) // response_body - optional, not stored to save disk space
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
// Deduct from provider balance if successful.
|
|
// Providers configured with billing_mode = 'postpaid' will not have their
|
|
// credit_balance decremented. Use a conditional UPDATE so we don't need
|
|
// a prior SELECT and avoid extra round-trips.
|
|
if log.cost > 0.0 {
|
|
sqlx::query(
|
|
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
|
|
)
|
|
.bind(log.cost)
|
|
.bind(&log.provider)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
}
|
|
|
|
tx.commit().await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// /// Middleware to log LLM API requests
|
|
// /// TODO: Implement proper middleware that can extract response body details
|
|
// pub async fn request_logging_middleware(
|
|
// // Extract the authenticated client (if available)
|
|
// auth_result: Result<AuthenticatedClient, AppError>,
|
|
// request: Request,
|
|
// next: Next,
|
|
// ) -> Response {
|
|
// let start_time = std::time::Instant::now();
|
|
//
|
|
// // Extract client_id from auth or use "unknown"
|
|
// let client_id = match auth_result {
|
|
// Ok(auth) => auth.client_id,
|
|
// Err(_) => "unknown".to_string(),
|
|
// };
|
|
//
|
|
// // Try to extract request details
|
|
// let (request_parts, request_body) = request.into_parts();
|
|
//
|
|
// // Clone request parts for logging
|
|
// let path = request_parts.uri.path().to_string();
|
|
//
|
|
// // Check if this is a chat completion request
|
|
// let is_chat_completion = path == "/v1/chat/completions";
|
|
//
|
|
// // Reconstruct request for downstream handlers
|
|
// let request = Request::from_parts(request_parts, request_body);
|
|
//
|
|
// // Process request and get response
|
|
// let response = next.run(request).await;
|
|
//
|
|
// // Calculate duration
|
|
// let duration = start_time.elapsed();
|
|
// let duration_ms = duration.as_millis() as u64;
|
|
//
|
|
// // Log basic request info
|
|
// info!(
|
|
// "Request from {} to {} - Status: {} - Duration: {}ms",
|
|
// client_id,
|
|
// path,
|
|
// response.status().as_u16(),
|
|
// duration_ms
|
|
// );
|
|
//
|
|
// // TODO: Extract more details from request/response for logging
|
|
// // For now, we'll need to modify the server handler to pass additional context
|
|
//
|
|
// response
|
|
// }
|
|
|
|
/// Context for request logging that can be passed through extensions
|
|
#[derive(Clone)]
|
|
pub struct LoggingContext {
|
|
pub client_id: String,
|
|
pub provider_name: String,
|
|
pub model: String,
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
pub cost: f64,
|
|
pub has_images: bool,
|
|
pub error: Option<AppError>,
|
|
}
|
|
|
|
impl LoggingContext {
|
|
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
|
|
Self {
|
|
client_id,
|
|
provider_name,
|
|
model,
|
|
prompt_tokens: 0,
|
|
completion_tokens: 0,
|
|
total_tokens: 0,
|
|
cost: 0.0,
|
|
has_images: false,
|
|
error: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
|
|
self.prompt_tokens = prompt_tokens;
|
|
self.completion_tokens = completion_tokens;
|
|
self.total_tokens = prompt_tokens + completion_tokens;
|
|
self
|
|
}
|
|
|
|
pub fn with_cost(mut self, cost: f64) -> Self {
|
|
self.cost = cost;
|
|
self
|
|
}
|
|
|
|
pub fn with_images(mut self, has_images: bool) -> Self {
|
|
self.has_images = has_images;
|
|
self
|
|
}
|
|
|
|
pub fn with_error(mut self, error: AppError) -> Self {
|
|
self.error = Some(error);
|
|
self
|
|
}
|
|
}
|