feat: Realtime WebSocket STT via gpt-realtime-whisper

Replaces REST-based transcription (gpt-4o-transcribe) with WebSocket
streaming via gpt-realtime-whisper. Frontend captures PCM16 audio and
streams it through the backend to a Realtime transcription session.

- Server-side VAD detects utterance boundaries automatically
- Word-level transcript deltas stream to the client in real-time
- On utterance end, gpt-5.4-nano generates a response
- TTS streams back via with_streaming_response
- Total pipeline: PCM16 → Realtime WS → LLM → streaming TTS
This commit is contained in:
2026-06-04 14:26:19 -04:00
parent 25b12ee14f
commit 7502f201c7
3 changed files with 326 additions and 153 deletions
+124 -122
View File
@@ -1,19 +1,21 @@
"""Kira — AI body double backend
Cheapest pipeline: gpt-4o-mini-transcribe STT → gpt-5.4-nano LLM → OpenAI TTS
~$0.019/min total, simple 3-step chat completions.
Realtime WebSocket STT (gpt-realtime-whisper) → gpt-5.4-nano → streaming TTS
"""
import json
import base64
import uuid
import logging
import time
import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from services.memory import kira_memory
from services.whisper_stream import WhisperStream
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("kira")
@@ -61,59 +63,6 @@ async def health():
return {"status": "ok", "name": "kira", "memory": mem_status}
async def run_conversation(text: str, memory_suffix: str = "") -> str:
"""LLM call with optional Honcho memory context injected into system prompt."""
system_prompt = BASE_SYSTEM_PROMPT
if memory_suffix:
system_prompt += memory_suffix
client = get_openai()
resp = await client.chat.completions.create(
model="gpt-5.4-nano",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
],
max_completion_tokens=300,
temperature=0.7,
)
return resp.choices[0].message.content or "Mhm, I'm here!"
async def transcribe_audio(audio_bytes: bytes) -> str | None:
"""Transcribe Opus/webm audio using cheapest STT model."""
client = get_openai()
try:
transcript = await client.audio.transcriptions.create(
model="gpt-4o-transcribe",
file=("audio.webm", audio_bytes, "audio/webm"),
response_format="text",
)
return transcript.strip() if transcript and transcript.strip() else None
except Exception as e:
logger.warning(f"STT error: {e}")
return None
async def synthesize_speech(text: str, websocket, speaking_start_sent: bool = False) -> None:
"""Generate TTS audio from text, streaming chunks to the client."""
client = get_openai()
try:
async with client.audio.speech.with_streaming_response.create(
model="tts-1",
voice="sage",
input=text,
response_format="opus",
) as resp:
async for chunk in resp.iter_bytes():
if chunk:
audio_b64 = base64.b64encode(chunk).decode("utf-8")
await websocket.send_json({"type": "audio", "data": audio_b64, "text": text if speaking_start_sent else ""})
speaking_start_sent = True
except Exception as e:
logger.warning(f"TTS error: {e}")
@app.websocket("/api/ws")
async def conversation_ws(websocket: WebSocket):
await websocket.accept()
@@ -123,8 +72,85 @@ async def conversation_ws(websocket: WebSocket):
memory_suffix = ""
logger.info(f"[{session_id}] WebSocket connected")
audio_buffer = bytearray()
conversation_history: list[dict] = []
pending_transcript: str | None = None
transcript_lock = asyncio.Lock()
# ── Whisper stream callbacks ──
async def on_ready():
logger.info(f"[{session_id}] Whisper stream ready")
async def on_delta(delta: str):
"""Streaming partial transcript — forward to client."""
try:
await websocket.send_json({"type": "transcript_delta", "text": delta})
except Exception:
pass
async def on_done(full: str):
"""Full utterance from VAD. Kick off LLM + TTS."""
nonlocal pending_transcript
logger.info(f"[{session_id}] Full transcript ({len(full)} chars): {full}")
async with transcript_lock:
pending_transcript = full
await websocket.send_json({"type": "transcript", "role": "user", "text": full})
conversation_history.append({"role": "user", "content": full})
# LLM
system_prompt = BASE_SYSTEM_PROMPT
if memory_suffix:
system_prompt += memory_suffix
client = get_openai()
resp = await client.chat.completions.create(
model="gpt-5.4-nano",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": full},
],
max_completion_tokens=300,
temperature=0.7,
)
kira_text = resp.choices[0].message.content or "Mhm, I'm here!"
conversation_history.append({"role": "assistant", "content": kira_text})
logger.info(f"[{session_id}] Kira: {kira_text}")
# Store in Honcho
if kira_memory.enabled and identified:
try:
kira_memory.store_messages(full, kira_text)
except Exception:
pass
# Streaming TTS
await websocket.send_json({"type": "speaking_start", "text": kira_text})
async with client.audio.speech.with_streaming_response.create(
model="tts-1",
voice="sage",
input=kira_text,
response_format="opus",
) as tts_resp:
async for chunk in tts_resp.iter_bytes():
if chunk:
b64 = base64.b64encode(chunk).decode("utf-8")
await websocket.send_json({"type": "audio", "data": b64})
await websocket.send_json({"type": "speaking_end"})
async def on_error(msg: str):
logger.warning(f"Whisper error: {msg}")
# Start WhisperStream
stream = WhisperStream(
on_transcript_delta=on_delta,
on_transcript_done=on_done,
on_ready=on_ready,
on_error=on_error,
)
stream_task = asyncio.create_task(stream.connect())
await asyncio.sleep(2) # brief wait for connection
try:
while True:
@@ -145,7 +171,6 @@ async def conversation_ws(websocket: WebSocket):
if kira_memory.enabled:
kira_memory.ensure_peers(user_id)
kira_memory.ensure_session(session_id)
# Build memory context ONCE on identify (not per-turn — too slow)
try:
ctx = kira_memory.build_system_prompt_suffix()
if ctx:
@@ -165,92 +190,69 @@ async def conversation_ws(websocket: WebSocket):
value = msg.get("value", "").strip()
if key and user_id and user_id != "default-user":
kira_memory.set_user_preference(user_id, key, value)
await websocket.send_json({
"type": "preference_saved",
"key": key,
"success": True,
})
await websocket.send_json({"type": "preference_saved", "key": key, "success": True})
continue
# ── Conversation ──
if msg_type == "audio_chunk":
chunk = base64.b64decode(msg["data"])
audio_buffer.extend(chunk)
# ── PCM16 audio → WhisperStream ──
if msg_type == "audio":
pcm16 = base64.b64decode(msg["data"])
await stream.send_audio(pcm16)
continue
elif msg_type == "transcribe":
if not audio_buffer:
await websocket.send_json({"type": "error", "message": "No audio data"})
# ── Text input → direct LLM + TTS ──
if msg_type == "conversation_text":
text = msg.get("text", "").strip()
if not text:
continue
import time
t0 = time.time()
logger.info(f"[{session_id}] Transcribing {len(audio_buffer)} bytes...")
logger.info(f"[{session_id}] User (text): {text}")
conversation_history.append({"role": "user", "content": text})
# 1. STT
transcript = await transcribe_audio(bytes(audio_buffer))
t1 = time.time()
audio_buffer.clear()
system_prompt = BASE_SYSTEM_PROMPT
if memory_suffix:
system_prompt += memory_suffix
if not transcript:
await websocket.send_json({"type": "error", "message": "Could not transcribe"})
continue
logger.info(f"[{session_id}] STT took {t1-t0:.1f}s")
await websocket.send_json({"type": "transcript", "role": "user", "text": transcript})
conversation_history.append({"role": "user", "content": transcript})
# 2. LLM (uses cached memory_suffix from identify)
logger.info(f"[{session_id}] User: {transcript}")
kira_text = await run_conversation(transcript, memory_suffix)
t2 = time.time()
logger.info(f"[{session_id}] LLM took {t2-t1:.1f}s")
client = get_openai()
resp = await client.chat.completions.create(
model="gpt-5.4-nano",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
],
max_completion_tokens=300,
temperature=0.7,
)
kira_text = resp.choices[0].message.content or "Mhm!"
conversation_history.append({"role": "assistant", "content": kira_text})
logger.info(f"[{session_id}] Kira: {kira_text}")
if kira_memory.enabled and identified:
try:
kira_memory.store_messages(transcript, kira_text)
except Exception:
pass
# 3. TTS
await websocket.send_json({"type": "speaking_start", "text": kira_text})
await synthesize_speech(kira_text, websocket)
t3 = time.time()
logger.info(f"[{session_id}] TTS took {t3-t2:.1f}s. Total: {t3-t0:.1f}s")
await websocket.send_json({"type": "speaking_end"})
elif msg_type == "conversation_text":
user_text = msg.get("text", "").strip()
if not user_text:
continue
conversation_history.append({"role": "user", "content": user_text})
logger.info(f"[{session_id}] User (text): {user_text}")
kira_text = await run_conversation(user_text, memory_suffix)
conversation_history.append({"role": "assistant", "content": kira_text})
logger.info(f"[{session_id}] Kira: {kira_text}")
if kira_memory.enabled and identified:
try:
kira_memory.store_messages(user_text, kira_text)
kira_memory.store_messages(text, kira_text)
except Exception:
pass
await websocket.send_json({"type": "speaking_start", "text": kira_text})
await synthesize_speech(kira_text, websocket)
async with client.audio.speech.with_streaming_response.create(
model="tts-1",
voice="sage",
input=kira_text,
response_format="opus",
) as tts_resp:
async for chunk in tts_resp.iter_bytes():
if chunk:
b64 = base64.b64encode(chunk).decode("utf-8")
await websocket.send_json({"type": "audio", "data": b64})
await websocket.send_json({"type": "speaking_end"})
continue
elif msg_type == "ping":
if msg_type == "ping":
await websocket.send_json({"type": "pong"})
except WebSocketDisconnect:
logger.info(f"[{session_id}] Disconnected")
except Exception as e:
logger.error(f"[{session_id}] Error: {e}")
try:
await websocket.send_json({"type": "error", "message": str(e)})
except Exception:
pass
finally:
await stream.disconnect()
stream_task.cancel()
+144
View File
@@ -0,0 +1,144 @@
"""Realtime streaming transcription service via gpt-realtime-whisper.
Connects to OpenAI Realtime API via WebSocket, configures the session
for pure transcription (no model responses), and streams word-level
transcript deltas back. Full utterances are then processed by the
cheap LLM + TTS pipeline.
"""
import json
import base64
import logging
import asyncio
from typing import Callable, Awaitable
from config import settings
logger = logging.getLogger("kira.whisper")
class WhisperStream:
"""Streaming transcription via gpt-realtime-whisper over WebSocket."""
def __init__(
self,
on_transcript_delta: Callable[[str], Awaitable[None]],
on_transcript_done: Callable[[str], Awaitable[None]],
on_ready: Callable[[], Awaitable[None]],
on_error: Callable[[str], Awaitable[None]],
):
self._on_delta = on_transcript_delta
self._on_done = on_transcript_done
self._on_ready = on_ready
self._on_error = on_error
self._conn = None
self._connected = False
self._transcript = ""
async def connect(self):
if self._connected:
return
try:
import websockets
url = "wss://api.openai.com/v1/realtime?model=gpt-4o-mini-realtime-preview"
ws = await websockets.connect(
url,
additional_headers={
"Authorization": f"Bearer {settings.openai_api_key}",
},
)
async with ws as conn:
self._conn = conn
self._connected = True
logger.info("Connected to Realtime transcription session")
# Configure: transcribe only with gpt-realtime-whisper, no model responses
await self._send({
"type": "session.update",
"session": {
"modalities": ["text"], # no audio output
"input_audio_format": "pcm16",
"input_audio_transcription": {
"model": "gpt-realtime-whisper",
"enabled": True,
},
"turn_detection": {
"type": "server_vad",
"threshold": 0.5,
"prefix_padding_ms": 300,
"silence_duration_ms": 600,
},
},
})
await self._on_ready()
while self._connected:
try:
raw = await conn.recv()
if isinstance(raw, (str, bytes)):
data = json.loads(raw if isinstance(raw, str) else raw.decode())
await self._handle(data)
except Exception as e:
if self._connected:
logger.warning(f"recv: {e}")
break
except Exception as e:
logger.error(f"Whisper stream error: {e}")
await self._on_error(str(e))
finally:
self._connected = False
self._conn = None
async def _handle(self, data: dict):
et = data.get("type", "")
if et == "input_audio_buffer.speech_started":
self._transcript = ""
elif et == "input_audio_buffer.speech_stopped":
if self._transcript.strip():
await self._on_done(self._transcript.strip())
self._transcript = ""
elif et == "conversation.item.created":
item = data.get("item", {})
content = item.get("content", [])
for part in (content or []):
pt = part.get("type", "")
txt = part.get("transcript", "") or part.get("text", "")
if pt == "transcript" and txt:
self._transcript = txt
await self._on_delta(txt)
elif et == "error":
err = data.get("error", {})
msg = err.get("message", str(data))
logger.warning(f"Whisper error: {msg}")
async def send_audio(self, pcm16_bytes: bytes):
if not self._connected:
return
try:
b64 = base64.b64encode(pcm16_bytes).decode("utf-8")
await self._send({"type": "input_audio_buffer.append", "audio": b64})
except Exception as e:
logger.warning(f"send audio: {e}")
async def _send(self, data: dict):
try:
await self._conn.send(json.dumps(data))
except Exception as e:
logger.warning(f"send: {e}")
async def disconnect(self):
self._connected = False
if self._conn:
try:
await self._conn.close()
except Exception:
pass
self._conn = None