From 7502f201c703a0438c7e511854bc2f455708e93c Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Thu, 4 Jun 2026 14:26:19 -0400 Subject: [PATCH] feat: Realtime WebSocket STT via gpt-realtime-whisper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces REST-based transcription (gpt-4o-transcribe) with WebSocket streaming via gpt-realtime-whisper. Frontend captures PCM16 audio and streams it through the backend to a Realtime transcription session. - Server-side VAD detects utterance boundaries automatically - Word-level transcript deltas stream to the client in real-time - On utterance end, gpt-5.4-nano generates a response - TTS streams back via with_streaming_response - Total pipeline: PCM16 → Realtime WS → LLM → streaming TTS --- backend/main.py | 246 +++++++++++++------------- backend/services/whisper_stream.py | 144 +++++++++++++++ frontend/src/hooks/useConversation.ts | 89 ++++++---- 3 files changed, 326 insertions(+), 153 deletions(-) create mode 100644 backend/services/whisper_stream.py diff --git a/backend/main.py b/backend/main.py index df1e68d..aa26980 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,19 +1,21 @@ """Kira — AI body double backend -Cheapest pipeline: gpt-4o-mini-transcribe STT → gpt-5.4-nano LLM → OpenAI TTS -~$0.019/min total, simple 3-step chat completions. +Realtime WebSocket STT (gpt-realtime-whisper) → gpt-5.4-nano → streaming TTS """ import json import base64 import uuid import logging +import time +import asyncio from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from config import settings from services.memory import kira_memory +from services.whisper_stream import WhisperStream logging.basicConfig(level=logging.INFO) logger = logging.getLogger("kira") @@ -61,59 +63,6 @@ async def health(): return {"status": "ok", "name": "kira", "memory": mem_status} -async def run_conversation(text: str, memory_suffix: str = "") -> str: - """LLM call with optional Honcho memory context injected into system prompt.""" - system_prompt = BASE_SYSTEM_PROMPT - if memory_suffix: - system_prompt += memory_suffix - - client = get_openai() - resp = await client.chat.completions.create( - model="gpt-5.4-nano", - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": text}, - ], - max_completion_tokens=300, - temperature=0.7, - ) - return resp.choices[0].message.content or "Mhm, I'm here!" - - -async def transcribe_audio(audio_bytes: bytes) -> str | None: - """Transcribe Opus/webm audio using cheapest STT model.""" - client = get_openai() - try: - transcript = await client.audio.transcriptions.create( - model="gpt-4o-transcribe", - file=("audio.webm", audio_bytes, "audio/webm"), - response_format="text", - ) - return transcript.strip() if transcript and transcript.strip() else None - except Exception as e: - logger.warning(f"STT error: {e}") - return None - - -async def synthesize_speech(text: str, websocket, speaking_start_sent: bool = False) -> None: - """Generate TTS audio from text, streaming chunks to the client.""" - client = get_openai() - try: - async with client.audio.speech.with_streaming_response.create( - model="tts-1", - voice="sage", - input=text, - response_format="opus", - ) as resp: - async for chunk in resp.iter_bytes(): - if chunk: - audio_b64 = base64.b64encode(chunk).decode("utf-8") - await websocket.send_json({"type": "audio", "data": audio_b64, "text": text if speaking_start_sent else ""}) - speaking_start_sent = True - except Exception as e: - logger.warning(f"TTS error: {e}") - - @app.websocket("/api/ws") async def conversation_ws(websocket: WebSocket): await websocket.accept() @@ -123,8 +72,85 @@ async def conversation_ws(websocket: WebSocket): memory_suffix = "" logger.info(f"[{session_id}] WebSocket connected") - audio_buffer = bytearray() conversation_history: list[dict] = [] + pending_transcript: str | None = None + transcript_lock = asyncio.Lock() + + # ── Whisper stream callbacks ── + + async def on_ready(): + logger.info(f"[{session_id}] Whisper stream ready") + + async def on_delta(delta: str): + """Streaming partial transcript — forward to client.""" + try: + await websocket.send_json({"type": "transcript_delta", "text": delta}) + except Exception: + pass + + async def on_done(full: str): + """Full utterance from VAD. Kick off LLM + TTS.""" + nonlocal pending_transcript + logger.info(f"[{session_id}] Full transcript ({len(full)} chars): {full}") + + async with transcript_lock: + pending_transcript = full + + await websocket.send_json({"type": "transcript", "role": "user", "text": full}) + conversation_history.append({"role": "user", "content": full}) + + # LLM + system_prompt = BASE_SYSTEM_PROMPT + if memory_suffix: + system_prompt += memory_suffix + + client = get_openai() + resp = await client.chat.completions.create( + model="gpt-5.4-nano", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": full}, + ], + max_completion_tokens=300, + temperature=0.7, + ) + kira_text = resp.choices[0].message.content or "Mhm, I'm here!" + conversation_history.append({"role": "assistant", "content": kira_text}) + logger.info(f"[{session_id}] Kira: {kira_text}") + + # Store in Honcho + if kira_memory.enabled and identified: + try: + kira_memory.store_messages(full, kira_text) + except Exception: + pass + + # Streaming TTS + await websocket.send_json({"type": "speaking_start", "text": kira_text}) + async with client.audio.speech.with_streaming_response.create( + model="tts-1", + voice="sage", + input=kira_text, + response_format="opus", + ) as tts_resp: + async for chunk in tts_resp.iter_bytes(): + if chunk: + b64 = base64.b64encode(chunk).decode("utf-8") + await websocket.send_json({"type": "audio", "data": b64}) + await websocket.send_json({"type": "speaking_end"}) + + async def on_error(msg: str): + logger.warning(f"Whisper error: {msg}") + + # Start WhisperStream + stream = WhisperStream( + on_transcript_delta=on_delta, + on_transcript_done=on_done, + on_ready=on_ready, + on_error=on_error, + ) + stream_task = asyncio.create_task(stream.connect()) + await asyncio.sleep(2) # brief wait for connection try: while True: @@ -145,7 +171,6 @@ async def conversation_ws(websocket: WebSocket): if kira_memory.enabled: kira_memory.ensure_peers(user_id) kira_memory.ensure_session(session_id) - # Build memory context ONCE on identify (not per-turn — too slow) try: ctx = kira_memory.build_system_prompt_suffix() if ctx: @@ -165,92 +190,69 @@ async def conversation_ws(websocket: WebSocket): value = msg.get("value", "").strip() if key and user_id and user_id != "default-user": kira_memory.set_user_preference(user_id, key, value) - await websocket.send_json({ - "type": "preference_saved", - "key": key, - "success": True, - }) + await websocket.send_json({"type": "preference_saved", "key": key, "success": True}) continue - # ── Conversation ── - if msg_type == "audio_chunk": - chunk = base64.b64decode(msg["data"]) - audio_buffer.extend(chunk) + # ── PCM16 audio → WhisperStream ── + if msg_type == "audio": + pcm16 = base64.b64decode(msg["data"]) + await stream.send_audio(pcm16) + continue - elif msg_type == "transcribe": - if not audio_buffer: - await websocket.send_json({"type": "error", "message": "No audio data"}) + # ── Text input → direct LLM + TTS ── + if msg_type == "conversation_text": + text = msg.get("text", "").strip() + if not text: continue - import time - t0 = time.time() - logger.info(f"[{session_id}] Transcribing {len(audio_buffer)} bytes...") + logger.info(f"[{session_id}] User (text): {text}") + conversation_history.append({"role": "user", "content": text}) - # 1. STT - transcript = await transcribe_audio(bytes(audio_buffer)) - t1 = time.time() - audio_buffer.clear() + system_prompt = BASE_SYSTEM_PROMPT + if memory_suffix: + system_prompt += memory_suffix - if not transcript: - await websocket.send_json({"type": "error", "message": "Could not transcribe"}) - continue - - logger.info(f"[{session_id}] STT took {t1-t0:.1f}s") - - await websocket.send_json({"type": "transcript", "role": "user", "text": transcript}) - conversation_history.append({"role": "user", "content": transcript}) - - # 2. LLM (uses cached memory_suffix from identify) - logger.info(f"[{session_id}] User: {transcript}") - kira_text = await run_conversation(transcript, memory_suffix) - t2 = time.time() - logger.info(f"[{session_id}] LLM took {t2-t1:.1f}s") + client = get_openai() + resp = await client.chat.completions.create( + model="gpt-5.4-nano", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text}, + ], + max_completion_tokens=300, + temperature=0.7, + ) + kira_text = resp.choices[0].message.content or "Mhm!" conversation_history.append({"role": "assistant", "content": kira_text}) logger.info(f"[{session_id}] Kira: {kira_text}") if kira_memory.enabled and identified: try: - kira_memory.store_messages(transcript, kira_text) - except Exception: - pass - - # 3. TTS - await websocket.send_json({"type": "speaking_start", "text": kira_text}) - await synthesize_speech(kira_text, websocket) - t3 = time.time() - logger.info(f"[{session_id}] TTS took {t3-t2:.1f}s. Total: {t3-t0:.1f}s") - await websocket.send_json({"type": "speaking_end"}) - - elif msg_type == "conversation_text": - user_text = msg.get("text", "").strip() - if not user_text: - continue - - conversation_history.append({"role": "user", "content": user_text}) - logger.info(f"[{session_id}] User (text): {user_text}") - - kira_text = await run_conversation(user_text, memory_suffix) - conversation_history.append({"role": "assistant", "content": kira_text}) - logger.info(f"[{session_id}] Kira: {kira_text}") - - if kira_memory.enabled and identified: - try: - kira_memory.store_messages(user_text, kira_text) + kira_memory.store_messages(text, kira_text) except Exception: pass await websocket.send_json({"type": "speaking_start", "text": kira_text}) - await synthesize_speech(kira_text, websocket) + async with client.audio.speech.with_streaming_response.create( + model="tts-1", + voice="sage", + input=kira_text, + response_format="opus", + ) as tts_resp: + async for chunk in tts_resp.iter_bytes(): + if chunk: + b64 = base64.b64encode(chunk).decode("utf-8") + await websocket.send_json({"type": "audio", "data": b64}) await websocket.send_json({"type": "speaking_end"}) + continue - elif msg_type == "ping": + if msg_type == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: logger.info(f"[{session_id}] Disconnected") except Exception as e: logger.error(f"[{session_id}] Error: {e}") - try: - await websocket.send_json({"type": "error", "message": str(e)}) - except Exception: - pass + finally: + await stream.disconnect() + stream_task.cancel() diff --git a/backend/services/whisper_stream.py b/backend/services/whisper_stream.py new file mode 100644 index 0000000..a1f1d0a --- /dev/null +++ b/backend/services/whisper_stream.py @@ -0,0 +1,144 @@ +"""Realtime streaming transcription service via gpt-realtime-whisper. + +Connects to OpenAI Realtime API via WebSocket, configures the session +for pure transcription (no model responses), and streams word-level +transcript deltas back. Full utterances are then processed by the +cheap LLM + TTS pipeline. +""" + +import json +import base64 +import logging +import asyncio +from typing import Callable, Awaitable +from config import settings + +logger = logging.getLogger("kira.whisper") + + +class WhisperStream: + """Streaming transcription via gpt-realtime-whisper over WebSocket.""" + + def __init__( + self, + on_transcript_delta: Callable[[str], Awaitable[None]], + on_transcript_done: Callable[[str], Awaitable[None]], + on_ready: Callable[[], Awaitable[None]], + on_error: Callable[[str], Awaitable[None]], + ): + self._on_delta = on_transcript_delta + self._on_done = on_transcript_done + self._on_ready = on_ready + self._on_error = on_error + self._conn = None + self._connected = False + self._transcript = "" + + async def connect(self): + if self._connected: + return + + try: + import websockets + + url = "wss://api.openai.com/v1/realtime?model=gpt-4o-mini-realtime-preview" + ws = await websockets.connect( + url, + additional_headers={ + "Authorization": f"Bearer {settings.openai_api_key}", + }, + ) + + async with ws as conn: + self._conn = conn + self._connected = True + logger.info("Connected to Realtime transcription session") + + # Configure: transcribe only with gpt-realtime-whisper, no model responses + await self._send({ + "type": "session.update", + "session": { + "modalities": ["text"], # no audio output + "input_audio_format": "pcm16", + "input_audio_transcription": { + "model": "gpt-realtime-whisper", + "enabled": True, + }, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 600, + }, + }, + }) + + await self._on_ready() + + while self._connected: + try: + raw = await conn.recv() + if isinstance(raw, (str, bytes)): + data = json.loads(raw if isinstance(raw, str) else raw.decode()) + await self._handle(data) + except Exception as e: + if self._connected: + logger.warning(f"recv: {e}") + break + + except Exception as e: + logger.error(f"Whisper stream error: {e}") + await self._on_error(str(e)) + finally: + self._connected = False + self._conn = None + + async def _handle(self, data: dict): + et = data.get("type", "") + + if et == "input_audio_buffer.speech_started": + self._transcript = "" + + elif et == "input_audio_buffer.speech_stopped": + if self._transcript.strip(): + await self._on_done(self._transcript.strip()) + self._transcript = "" + + elif et == "conversation.item.created": + item = data.get("item", {}) + content = item.get("content", []) + for part in (content or []): + pt = part.get("type", "") + txt = part.get("transcript", "") or part.get("text", "") + if pt == "transcript" and txt: + self._transcript = txt + await self._on_delta(txt) + + elif et == "error": + err = data.get("error", {}) + msg = err.get("message", str(data)) + logger.warning(f"Whisper error: {msg}") + + async def send_audio(self, pcm16_bytes: bytes): + if not self._connected: + return + try: + b64 = base64.b64encode(pcm16_bytes).decode("utf-8") + await self._send({"type": "input_audio_buffer.append", "audio": b64}) + except Exception as e: + logger.warning(f"send audio: {e}") + + async def _send(self, data: dict): + try: + await self._conn.send(json.dumps(data)) + except Exception as e: + logger.warning(f"send: {e}") + + async def disconnect(self): + self._connected = False + if self._conn: + try: + await self._conn.close() + except Exception: + pass + self._conn = None diff --git a/frontend/src/hooks/useConversation.ts b/frontend/src/hooks/useConversation.ts index 5763bef..d4cb058 100644 --- a/frontend/src/hooks/useConversation.ts +++ b/frontend/src/hooks/useConversation.ts @@ -42,6 +42,7 @@ export function useConversation() { const wsRef = useRef(null); const audioRef = useRef(null); + const captureRef = useRef<{ stop: () => void } | null>(null); const recorderRef = useRef(null); const streamRef = useRef(null); const audioBufferRef = useRef([]); @@ -193,7 +194,6 @@ export function useConversation() { // ── Audio (Realtime PCM16) ── const startRecording = useCallback(async () => { - // Check HTTPS if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) { addMessage('kira', 'Mic requires HTTPS. Try accessing via HTTPS!'); return; @@ -211,37 +211,14 @@ export function useConversation() { return; } - // Record Opus/webm — much more efficient than PCM16 - const chunks: BlobPart[] = []; - const recorder = new MediaRecorder(stream, { - mimeType: MediaRecorder.isTypeSupported('audio/webm;codecs=opus') - ? 'audio/webm;codecs=opus' - : 'audio/webm', + // PCM16 capture for Realtime WebSocket STT + captureRef.current = startPCMCapture(stream, (pcm16) => { + if (ws.readyState === WebSocket.OPEN) { + const base64 = arrayBufferToBase64(pcm16.buffer); + ws.send(JSON.stringify({ type: 'audio', data: base64 })); + } }); - recorder.ondataavailable = (e) => { - if (e.data.size > 0) chunks.push(e.data); - }; - - recorder.onstop = () => { - // Send recorded audio as one blob, then transcribe - const blob = new Blob(chunks, { type: 'audio/webm' }); - const reader = new FileReader(); - reader.onload = () => { - const base64 = (reader.result as string).split(',')[1]; - if (ws.readyState === WebSocket.OPEN) { - ws.send(JSON.stringify({ type: 'audio_chunk', data: base64 })); - ws.send(JSON.stringify({ type: 'transcribe' })); - } - }; - reader.readAsDataURL(blob); - - stream.getTracks().forEach((t) => t.stop()); - setIsRecording(false); - }; - - recorder.start(); - recorderRef.current = recorder; setIsRecording(true); } catch (err) { const msg = err instanceof Error ? err.message : String(err); @@ -251,7 +228,11 @@ export function useConversation() { }, [addMessage]); const stopRecording = useCallback(() => { - recorderRef.current?.stop(); + captureRef.current?.stop(); + captureRef.current = null; + streamRef.current?.getTracks().forEach((t) => t.stop()); + streamRef.current = null; + setIsRecording(false); }, []); // ── Text ── @@ -268,6 +249,7 @@ export function useConversation() { connect(); return () => { wsRef.current?.close(); + captureRef.current?.stop(); recorderRef.current?.stop(); streamRef.current?.getTracks().forEach((t) => t.stop()); }; @@ -289,3 +271,48 @@ export function useConversation() { stopRecording, }; } + +// ── Helpers ── + +function arrayBufferToBase64(buffer: ArrayBufferLike): string { + const bytes = new Uint8Array(buffer); + let binary = ''; + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i]); + } + return btoa(binary); +} + +/** Capture PCM16 mono 24kHz audio from mic and send via callback. */ +function startPCMCapture( + stream: MediaStream, + onChunk: (pcm16: Uint8Array) => void, +): { stop: () => void } { + const ctx = new AudioContext({ sampleRate: 24000 }); + const source = ctx.createMediaStreamSource(stream); + const processor = ctx.createScriptProcessor(4096, 1, 1); + let running = true; + + processor.onaudioprocess = (e) => { + if (!running) return; + const input = e.inputBuffer.getChannelData(0); + const pcm16 = new Int16Array(input.length); + for (let i = 0; i < input.length; i++) { + const s = Math.max(-1, Math.min(1, input[i])); + pcm16[i] = s < 0 ? s * 0x8000 : s * 0x7fff; + } + onChunk(new Uint8Array(pcm16.buffer)); + }; + + source.connect(processor); + processor.connect(ctx.destination); + + return { + stop: () => { + running = false; + source.disconnect(); + processor.disconnect(); + ctx.close(); + }, + }; +}