fix(streaming): use async_stream to ensure [DONE] is always sent
This commit is contained in:
@@ -6,9 +6,9 @@ use axum::{
|
|||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use std::time::Duration;
|
|
||||||
use sqlx;
|
use sqlx;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -246,40 +246,49 @@ async fn chat_completions(
|
|||||||
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
|
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||||
let stream_created = chrono::Utc::now().timestamp() as u64;
|
let stream_created = chrono::Utc::now().timestamp() as u64;
|
||||||
|
|
||||||
// Map chunks to SSE events
|
// Map chunks to SSE events - clone stream_id for the async block
|
||||||
let sse_stream = aggregating_stream
|
let stream_id_for_sse = stream_id.clone();
|
||||||
.map(move |chunk_result| {
|
|
||||||
match chunk_result {
|
// Use async stream macro to ensure proper sequencing
|
||||||
Ok(chunk) => {
|
let final_stream = async_stream::stream! {
|
||||||
let response = ChatCompletionStreamResponse {
|
// First, process and yield all chunks from aggregator
|
||||||
id: stream_id.clone(),
|
let mut stream = Box::pin(aggregating_stream
|
||||||
object: "chat.completion.chunk".to_string(),
|
.map(move |chunk_result| {
|
||||||
created: stream_created,
|
match chunk_result {
|
||||||
model: chunk.model.clone(),
|
Ok(chunk) => {
|
||||||
choices: vec![ChatStreamChoice {
|
let response = ChatCompletionStreamResponse {
|
||||||
index: 0,
|
id: stream_id_for_sse.clone(),
|
||||||
delta: ChatStreamDelta {
|
object: "chat.completion.chunk".to_string(),
|
||||||
role: None,
|
created: stream_created,
|
||||||
content: Some(chunk.content),
|
model: chunk.model.clone(),
|
||||||
reasoning_content: chunk.reasoning_content,
|
choices: vec![ChatStreamChoice {
|
||||||
tool_calls: chunk.tool_calls,
|
index: 0,
|
||||||
},
|
delta: ChatStreamDelta {
|
||||||
finish_reason: chunk.finish_reason,
|
role: None,
|
||||||
}],
|
content: Some(chunk.content),
|
||||||
};
|
reasoning_content: chunk.reasoning_content,
|
||||||
Event::default().json_data(response)
|
tool_calls: chunk.tool_calls,
|
||||||
.map_err(|e| AppError::InternalError(format!("SSE error: {}", e)))
|
},
|
||||||
|
finish_reason: chunk.finish_reason,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
Event::default().json_data(response)
|
||||||
|
.map_err(|e| AppError::InternalError(format!("SSE error: {}", e)))
|
||||||
|
}
|
||||||
|
Err(e) => Err(e),
|
||||||
}
|
}
|
||||||
Err(e) => Err(e),
|
}));
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Chain [DONE] - using repeat_with to ensure it gets polled
|
// Yield all chunks
|
||||||
let done_stream = futures::stream::repeat_with(|| Ok::<Event, AppError>(Event::default().data("[DONE]")))
|
while let Some(item) = stream.next().await {
|
||||||
.take(1);
|
yield item;
|
||||||
let out = sse_stream.chain(done_stream);
|
}
|
||||||
|
|
||||||
Ok(Sse::new(out).into_response())
|
// Finally yield [DONE]
|
||||||
|
yield Ok::<Event, AppError>(Event::default().data("[DONE]"));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Sse::new(final_stream).into_response())
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Record provider failure
|
// Record provider failure
|
||||||
|
|||||||
Reference in New Issue
Block a user