fix(streaming): collect chunks then stream with explicit [DONE]
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

This commit is contained in:
2026-03-03 13:02:10 -05:00
parent 2508a745c6
commit 2a7a380977

View File

@@ -5,7 +5,7 @@ use axum::{
response::sse::{Event, Sse}, response::sse::{Event, Sse},
routing::{get, post}, routing::{get, post},
}; };
use futures::stream::StreamExt; use futures::stream::{StreamExt, self};
use std::time::Duration; use std::time::Duration;
use sqlx; use sqlx;
use std::sync::Arc; use std::sync::Arc;
@@ -19,6 +19,7 @@ use crate::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, Usage, ChatStreamChoice, ChatStreamDelta, Usage,
}, },
providers::ProviderStreamChunk,
rate_limiting, rate_limiting,
state::AppState, state::AppState,
}; };
@@ -241,17 +242,23 @@ async fn chat_completions(
}, },
); );
// Create SSE stream from aggregating stream // Create SSE stream with explicit [DONE] termination
let stream_id = format!("chatcmpl-{}", Uuid::new_v4()); let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream_created = chrono::Utc::now().timestamp() as u64; let stream_created = chrono::Utc::now().timestamp() as u64;
let stream_id_clone = stream_id.clone();
let sse_stream = aggregating_stream // Convert aggregator to a Vec first, then stream with [DONE]
.map(move |chunk_result| { let chunks: Vec<Result<ProviderStreamChunk, AppError>> = aggregating_stream.collect().await;
// Create stream that yields SSE events then [DONE]
let final_stream = stream::iter(chunks)
.then(move |chunk_result| {
let sid = stream_id_clone.clone();
async move {
match chunk_result { match chunk_result {
Ok(chunk) => { Ok(chunk) => {
// Convert provider chunk to OpenAI-compatible SSE event
let response = ChatCompletionStreamResponse { let response = ChatCompletionStreamResponse {
id: stream_id.clone(), id: sid,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
created: stream_created, created: stream_created,
model: chunk.model.clone(), model: chunk.model.clone(),
@@ -266,28 +273,18 @@ async fn chat_completions(
finish_reason: chunk.finish_reason, finish_reason: chunk.finish_reason,
}], }],
}; };
Event::default().json_data(response)
.map_err(|e| AppError::InternalError(format!("SSE serialization failed: {}", e)))
}
Err(e) => Err(e),
}
}
})
.chain(stream::once(async {
Ok::<Event, AppError>(Event::default().data("[DONE]"))
}));
match Event::default().json_data(response) { Ok(Sse::new(final_stream).into_response())
Ok(event) => Ok(event),
Err(e) => {
warn!("Failed to serialize SSE event: {}", e);
Err(AppError::InternalError("SSE serialization failed".to_string()))
}
}
}
Err(e) => {
warn!("Error in streaming response: {}", e);
Err(e)
}
}
});
// Append [DONE] using iter (not once) - should ensure it gets polled
let done_vec = vec![Ok::<Event, AppError>(Event::default().data("[DONE]"))];
let done_stream = futures::stream::iter(done_vec);
let out = sse_stream.chain(done_stream);
Ok(Sse::new(out).into_response())
} }
Err(e) => { Err(e) => {
// Record provider failure // Record provider failure