feat: Realtime WebSocket STT via gpt-realtime-whisper
Replaces REST-based transcription (gpt-4o-transcribe) with WebSocket streaming via gpt-realtime-whisper. Frontend captures PCM16 audio and streams it through the backend to a Realtime transcription session. - Server-side VAD detects utterance boundaries automatically - Word-level transcript deltas stream to the client in real-time - On utterance end, gpt-5.4-nano generates a response - TTS streams back via with_streaming_response - Total pipeline: PCM16 → Realtime WS → LLM → streaming TTS
This commit is contained in:
+124
-122
@@ -1,19 +1,21 @@
|
||||
"""Kira — AI body double backend
|
||||
|
||||
Cheapest pipeline: gpt-4o-mini-transcribe STT → gpt-5.4-nano LLM → OpenAI TTS
|
||||
~$0.019/min total, simple 3-step chat completions.
|
||||
Realtime WebSocket STT (gpt-realtime-whisper) → gpt-5.4-nano → streaming TTS
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import settings
|
||||
from services.memory import kira_memory
|
||||
from services.whisper_stream import WhisperStream
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("kira")
|
||||
@@ -61,59 +63,6 @@ async def health():
|
||||
return {"status": "ok", "name": "kira", "memory": mem_status}
|
||||
|
||||
|
||||
async def run_conversation(text: str, memory_suffix: str = "") -> str:
|
||||
"""LLM call with optional Honcho memory context injected into system prompt."""
|
||||
system_prompt = BASE_SYSTEM_PROMPT
|
||||
if memory_suffix:
|
||||
system_prompt += memory_suffix
|
||||
|
||||
client = get_openai()
|
||||
resp = await client.chat.completions.create(
|
||||
model="gpt-5.4-nano",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text},
|
||||
],
|
||||
max_completion_tokens=300,
|
||||
temperature=0.7,
|
||||
)
|
||||
return resp.choices[0].message.content or "Mhm, I'm here!"
|
||||
|
||||
|
||||
async def transcribe_audio(audio_bytes: bytes) -> str | None:
|
||||
"""Transcribe Opus/webm audio using cheapest STT model."""
|
||||
client = get_openai()
|
||||
try:
|
||||
transcript = await client.audio.transcriptions.create(
|
||||
model="gpt-4o-transcribe",
|
||||
file=("audio.webm", audio_bytes, "audio/webm"),
|
||||
response_format="text",
|
||||
)
|
||||
return transcript.strip() if transcript and transcript.strip() else None
|
||||
except Exception as e:
|
||||
logger.warning(f"STT error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def synthesize_speech(text: str, websocket, speaking_start_sent: bool = False) -> None:
|
||||
"""Generate TTS audio from text, streaming chunks to the client."""
|
||||
client = get_openai()
|
||||
try:
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model="tts-1",
|
||||
voice="sage",
|
||||
input=text,
|
||||
response_format="opus",
|
||||
) as resp:
|
||||
async for chunk in resp.iter_bytes():
|
||||
if chunk:
|
||||
audio_b64 = base64.b64encode(chunk).decode("utf-8")
|
||||
await websocket.send_json({"type": "audio", "data": audio_b64, "text": text if speaking_start_sent else ""})
|
||||
speaking_start_sent = True
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS error: {e}")
|
||||
|
||||
|
||||
@app.websocket("/api/ws")
|
||||
async def conversation_ws(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@@ -123,8 +72,85 @@ async def conversation_ws(websocket: WebSocket):
|
||||
memory_suffix = ""
|
||||
logger.info(f"[{session_id}] WebSocket connected")
|
||||
|
||||
audio_buffer = bytearray()
|
||||
conversation_history: list[dict] = []
|
||||
pending_transcript: str | None = None
|
||||
transcript_lock = asyncio.Lock()
|
||||
|
||||
# ── Whisper stream callbacks ──
|
||||
|
||||
async def on_ready():
|
||||
logger.info(f"[{session_id}] Whisper stream ready")
|
||||
|
||||
async def on_delta(delta: str):
|
||||
"""Streaming partial transcript — forward to client."""
|
||||
try:
|
||||
await websocket.send_json({"type": "transcript_delta", "text": delta})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def on_done(full: str):
|
||||
"""Full utterance from VAD. Kick off LLM + TTS."""
|
||||
nonlocal pending_transcript
|
||||
logger.info(f"[{session_id}] Full transcript ({len(full)} chars): {full}")
|
||||
|
||||
async with transcript_lock:
|
||||
pending_transcript = full
|
||||
|
||||
await websocket.send_json({"type": "transcript", "role": "user", "text": full})
|
||||
conversation_history.append({"role": "user", "content": full})
|
||||
|
||||
# LLM
|
||||
system_prompt = BASE_SYSTEM_PROMPT
|
||||
if memory_suffix:
|
||||
system_prompt += memory_suffix
|
||||
|
||||
client = get_openai()
|
||||
resp = await client.chat.completions.create(
|
||||
model="gpt-5.4-nano",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": full},
|
||||
],
|
||||
max_completion_tokens=300,
|
||||
temperature=0.7,
|
||||
)
|
||||
kira_text = resp.choices[0].message.content or "Mhm, I'm here!"
|
||||
conversation_history.append({"role": "assistant", "content": kira_text})
|
||||
logger.info(f"[{session_id}] Kira: {kira_text}")
|
||||
|
||||
# Store in Honcho
|
||||
if kira_memory.enabled and identified:
|
||||
try:
|
||||
kira_memory.store_messages(full, kira_text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Streaming TTS
|
||||
await websocket.send_json({"type": "speaking_start", "text": kira_text})
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model="tts-1",
|
||||
voice="sage",
|
||||
input=kira_text,
|
||||
response_format="opus",
|
||||
) as tts_resp:
|
||||
async for chunk in tts_resp.iter_bytes():
|
||||
if chunk:
|
||||
b64 = base64.b64encode(chunk).decode("utf-8")
|
||||
await websocket.send_json({"type": "audio", "data": b64})
|
||||
await websocket.send_json({"type": "speaking_end"})
|
||||
|
||||
async def on_error(msg: str):
|
||||
logger.warning(f"Whisper error: {msg}")
|
||||
|
||||
# Start WhisperStream
|
||||
stream = WhisperStream(
|
||||
on_transcript_delta=on_delta,
|
||||
on_transcript_done=on_done,
|
||||
on_ready=on_ready,
|
||||
on_error=on_error,
|
||||
)
|
||||
stream_task = asyncio.create_task(stream.connect())
|
||||
await asyncio.sleep(2) # brief wait for connection
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -145,7 +171,6 @@ async def conversation_ws(websocket: WebSocket):
|
||||
if kira_memory.enabled:
|
||||
kira_memory.ensure_peers(user_id)
|
||||
kira_memory.ensure_session(session_id)
|
||||
# Build memory context ONCE on identify (not per-turn — too slow)
|
||||
try:
|
||||
ctx = kira_memory.build_system_prompt_suffix()
|
||||
if ctx:
|
||||
@@ -165,92 +190,69 @@ async def conversation_ws(websocket: WebSocket):
|
||||
value = msg.get("value", "").strip()
|
||||
if key and user_id and user_id != "default-user":
|
||||
kira_memory.set_user_preference(user_id, key, value)
|
||||
await websocket.send_json({
|
||||
"type": "preference_saved",
|
||||
"key": key,
|
||||
"success": True,
|
||||
})
|
||||
await websocket.send_json({"type": "preference_saved", "key": key, "success": True})
|
||||
continue
|
||||
|
||||
# ── Conversation ──
|
||||
if msg_type == "audio_chunk":
|
||||
chunk = base64.b64decode(msg["data"])
|
||||
audio_buffer.extend(chunk)
|
||||
# ── PCM16 audio → WhisperStream ──
|
||||
if msg_type == "audio":
|
||||
pcm16 = base64.b64decode(msg["data"])
|
||||
await stream.send_audio(pcm16)
|
||||
continue
|
||||
|
||||
elif msg_type == "transcribe":
|
||||
if not audio_buffer:
|
||||
await websocket.send_json({"type": "error", "message": "No audio data"})
|
||||
# ── Text input → direct LLM + TTS ──
|
||||
if msg_type == "conversation_text":
|
||||
text = msg.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
import time
|
||||
t0 = time.time()
|
||||
logger.info(f"[{session_id}] Transcribing {len(audio_buffer)} bytes...")
|
||||
logger.info(f"[{session_id}] User (text): {text}")
|
||||
conversation_history.append({"role": "user", "content": text})
|
||||
|
||||
# 1. STT
|
||||
transcript = await transcribe_audio(bytes(audio_buffer))
|
||||
t1 = time.time()
|
||||
audio_buffer.clear()
|
||||
system_prompt = BASE_SYSTEM_PROMPT
|
||||
if memory_suffix:
|
||||
system_prompt += memory_suffix
|
||||
|
||||
if not transcript:
|
||||
await websocket.send_json({"type": "error", "message": "Could not transcribe"})
|
||||
continue
|
||||
|
||||
logger.info(f"[{session_id}] STT took {t1-t0:.1f}s")
|
||||
|
||||
await websocket.send_json({"type": "transcript", "role": "user", "text": transcript})
|
||||
conversation_history.append({"role": "user", "content": transcript})
|
||||
|
||||
# 2. LLM (uses cached memory_suffix from identify)
|
||||
logger.info(f"[{session_id}] User: {transcript}")
|
||||
kira_text = await run_conversation(transcript, memory_suffix)
|
||||
t2 = time.time()
|
||||
logger.info(f"[{session_id}] LLM took {t2-t1:.1f}s")
|
||||
client = get_openai()
|
||||
resp = await client.chat.completions.create(
|
||||
model="gpt-5.4-nano",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text},
|
||||
],
|
||||
max_completion_tokens=300,
|
||||
temperature=0.7,
|
||||
)
|
||||
kira_text = resp.choices[0].message.content or "Mhm!"
|
||||
conversation_history.append({"role": "assistant", "content": kira_text})
|
||||
logger.info(f"[{session_id}] Kira: {kira_text}")
|
||||
|
||||
if kira_memory.enabled and identified:
|
||||
try:
|
||||
kira_memory.store_messages(transcript, kira_text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. TTS
|
||||
await websocket.send_json({"type": "speaking_start", "text": kira_text})
|
||||
await synthesize_speech(kira_text, websocket)
|
||||
t3 = time.time()
|
||||
logger.info(f"[{session_id}] TTS took {t3-t2:.1f}s. Total: {t3-t0:.1f}s")
|
||||
await websocket.send_json({"type": "speaking_end"})
|
||||
|
||||
elif msg_type == "conversation_text":
|
||||
user_text = msg.get("text", "").strip()
|
||||
if not user_text:
|
||||
continue
|
||||
|
||||
conversation_history.append({"role": "user", "content": user_text})
|
||||
logger.info(f"[{session_id}] User (text): {user_text}")
|
||||
|
||||
kira_text = await run_conversation(user_text, memory_suffix)
|
||||
conversation_history.append({"role": "assistant", "content": kira_text})
|
||||
logger.info(f"[{session_id}] Kira: {kira_text}")
|
||||
|
||||
if kira_memory.enabled and identified:
|
||||
try:
|
||||
kira_memory.store_messages(user_text, kira_text)
|
||||
kira_memory.store_messages(text, kira_text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await websocket.send_json({"type": "speaking_start", "text": kira_text})
|
||||
await synthesize_speech(kira_text, websocket)
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model="tts-1",
|
||||
voice="sage",
|
||||
input=kira_text,
|
||||
response_format="opus",
|
||||
) as tts_resp:
|
||||
async for chunk in tts_resp.iter_bytes():
|
||||
if chunk:
|
||||
b64 = base64.b64encode(chunk).decode("utf-8")
|
||||
await websocket.send_json({"type": "audio", "data": b64})
|
||||
await websocket.send_json({"type": "speaking_end"})
|
||||
continue
|
||||
|
||||
elif msg_type == "ping":
|
||||
if msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"[{session_id}] Disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Error: {e}")
|
||||
try:
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await stream.disconnect()
|
||||
stream_task.cancel()
|
||||
|
||||
Reference in New Issue
Block a user