diff --git a/src/server/mod.rs b/src/server/mod.rs index 96646e74..f1a835af 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -5,7 +5,7 @@ use axum::{ response::sse::{Event, Sse}, routing::{get, post}, }; -use futures::stream::StreamExt; +use futures::stream::{StreamExt, self}; use std::time::Duration; use sqlx; use std::sync::Arc; @@ -19,6 +19,7 @@ use crate::{ ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, ChatStreamChoice, ChatStreamDelta, Usage, }, + providers::ProviderStreamChunk, rate_limiting, state::AppState, }; @@ -241,17 +242,23 @@ async fn chat_completions( }, ); - // Create SSE stream from aggregating stream + // Create SSE stream with explicit [DONE] termination let stream_id = format!("chatcmpl-{}", Uuid::new_v4()); let stream_created = chrono::Utc::now().timestamp() as u64; + let stream_id_clone = stream_id.clone(); - let sse_stream = aggregating_stream - .map(move |chunk_result| { + // Convert aggregator to a Vec first, then stream with [DONE] + let chunks: Vec> = aggregating_stream.collect().await; + + // Create stream that yields SSE events then [DONE] + let final_stream = stream::iter(chunks) + .then(move |chunk_result| { + let sid = stream_id_clone.clone(); + async move { match chunk_result { Ok(chunk) => { - // Convert provider chunk to OpenAI-compatible SSE event let response = ChatCompletionStreamResponse { - id: stream_id.clone(), + id: sid, object: "chat.completion.chunk".to_string(), created: stream_created, model: chunk.model.clone(), @@ -266,28 +273,18 @@ async fn chat_completions( finish_reason: chunk.finish_reason, }], }; - - match Event::default().json_data(response) { - Ok(event) => Ok(event), - Err(e) => { - warn!("Failed to serialize SSE event: {}", e); - Err(AppError::InternalError("SSE serialization failed".to_string())) - } - } - } - Err(e) => { - warn!("Error in streaming response: {}", e); - Err(e) + Event::default().json_data(response) + .map_err(|e| AppError::InternalError(format!("SSE serialization failed: {}", e))) } + Err(e) => Err(e), } - }); + } + }) + .chain(stream::once(async { + Ok::(Event::default().data("[DONE]")) + })); - // Append [DONE] using iter (not once) - should ensure it gets polled - let done_vec = vec![Ok::(Event::default().data("[DONE]"))]; - let done_stream = futures::stream::iter(done_vec); - let out = sse_stream.chain(done_stream); - - Ok(Sse::new(out).into_response()) + Ok(Sse::new(final_stream).into_response()) } Err(e) => { // Record provider failure