From 656a6f31ce189032887aa25f24d6bc31846cf08f Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Tue, 3 Mar 2026 13:40:57 -0500 Subject: [PATCH] fix(streaming): use async_stream to ensure [DONE] is always sent --- src/server/mod.rs | 75 ++++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/src/server/mod.rs b/src/server/mod.rs index bcff3be7..b2d4904f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,9 +6,9 @@ use axum::{ routing::{get, post}, }; use futures::stream::StreamExt; -use std::time::Duration; use sqlx; use std::sync::Arc; +use std::time::Duration; use tracing::{info, warn}; use uuid::Uuid; @@ -246,40 +246,49 @@ async fn chat_completions( let stream_id = format!("chatcmpl-{}", Uuid::new_v4()); let stream_created = chrono::Utc::now().timestamp() as u64; - // Map chunks to SSE events - let sse_stream = aggregating_stream - .map(move |chunk_result| { - match chunk_result { - Ok(chunk) => { - let response = ChatCompletionStreamResponse { - id: stream_id.clone(), - object: "chat.completion.chunk".to_string(), - created: stream_created, - model: chunk.model.clone(), - choices: vec![ChatStreamChoice { - index: 0, - delta: ChatStreamDelta { - role: None, - content: Some(chunk.content), - reasoning_content: chunk.reasoning_content, - tool_calls: chunk.tool_calls, - }, - finish_reason: chunk.finish_reason, - }], - }; - Event::default().json_data(response) - .map_err(|e| AppError::InternalError(format!("SSE error: {}", e))) + // Map chunks to SSE events - clone stream_id for the async block + let stream_id_for_sse = stream_id.clone(); + + // Use async stream macro to ensure proper sequencing + let final_stream = async_stream::stream! { + // First, process and yield all chunks from aggregator + let mut stream = Box::pin(aggregating_stream + .map(move |chunk_result| { + match chunk_result { + Ok(chunk) => { + let response = ChatCompletionStreamResponse { + id: stream_id_for_sse.clone(), + object: "chat.completion.chunk".to_string(), + created: stream_created, + model: chunk.model.clone(), + choices: vec![ChatStreamChoice { + index: 0, + delta: ChatStreamDelta { + role: None, + content: Some(chunk.content), + reasoning_content: chunk.reasoning_content, + tool_calls: chunk.tool_calls, + }, + finish_reason: chunk.finish_reason, + }], + }; + Event::default().json_data(response) + .map_err(|e| AppError::InternalError(format!("SSE error: {}", e))) + } + Err(e) => Err(e), } - Err(e) => Err(e), - } - }); + })); + + // Yield all chunks + while let Some(item) = stream.next().await { + yield item; + } + + // Finally yield [DONE] + yield Ok::(Event::default().data("[DONE]")); + }; - // Chain [DONE] - using repeat_with to ensure it gets polled - let done_stream = futures::stream::repeat_with(|| Ok::(Event::default().data("[DONE]"))) - .take(1); - let out = sse_stream.chain(done_stream); - - Ok(Sse::new(out).into_response()) + Ok(Sse::new(final_stream).into_response()) } Err(e) => { // Record provider failure