fix(streaming): use async_stream to ensure [DONE] is always sent
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

This commit is contained in:
2026-03-03 13:40:57 -05:00
parent e0948a3e7f
commit 656a6f31ce

View File

@@ -6,9 +6,9 @@ use axum::{
routing::{get, post}, routing::{get, post},
}; };
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::time::Duration;
use sqlx; use sqlx;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
use uuid::Uuid; use uuid::Uuid;
@@ -246,40 +246,49 @@ async fn chat_completions(
let stream_id = format!("chatcmpl-{}", Uuid::new_v4()); let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream_created = chrono::Utc::now().timestamp() as u64; let stream_created = chrono::Utc::now().timestamp() as u64;
// Map chunks to SSE events // Map chunks to SSE events - clone stream_id for the async block
let sse_stream = aggregating_stream let stream_id_for_sse = stream_id.clone();
.map(move |chunk_result| {
match chunk_result { // Use async stream macro to ensure proper sequencing
Ok(chunk) => { let final_stream = async_stream::stream! {
let response = ChatCompletionStreamResponse { // First, process and yield all chunks from aggregator
id: stream_id.clone(), let mut stream = Box::pin(aggregating_stream
object: "chat.completion.chunk".to_string(), .map(move |chunk_result| {
created: stream_created, match chunk_result {
model: chunk.model.clone(), Ok(chunk) => {
choices: vec![ChatStreamChoice { let response = ChatCompletionStreamResponse {
index: 0, id: stream_id_for_sse.clone(),
delta: ChatStreamDelta { object: "chat.completion.chunk".to_string(),
role: None, created: stream_created,
content: Some(chunk.content), model: chunk.model.clone(),
reasoning_content: chunk.reasoning_content, choices: vec![ChatStreamChoice {
tool_calls: chunk.tool_calls, index: 0,
}, delta: ChatStreamDelta {
finish_reason: chunk.finish_reason, role: None,
}], content: Some(chunk.content),
}; reasoning_content: chunk.reasoning_content,
Event::default().json_data(response) tool_calls: chunk.tool_calls,
.map_err(|e| AppError::InternalError(format!("SSE error: {}", e))) },
finish_reason: chunk.finish_reason,
}],
};
Event::default().json_data(response)
.map_err(|e| AppError::InternalError(format!("SSE error: {}", e)))
}
Err(e) => Err(e),
} }
Err(e) => Err(e), }));
}
}); // Yield all chunks
while let Some(item) = stream.next().await {
yield item;
}
// Finally yield [DONE]
yield Ok::<Event, AppError>(Event::default().data("[DONE]"));
};
// Chain [DONE] - using repeat_with to ensure it gets polled Ok(Sse::new(final_stream).into_response())
let done_stream = futures::stream::repeat_with(|| Ok::<Event, AppError>(Event::default().data("[DONE]")))
.take(1);
let out = sse_stream.chain(done_stream);
Ok(Sse::new(out).into_response())
} }
Err(e) => { Err(e) => {
// Record provider failure // Record provider failure