Files
kira/backend/services/hybrid.py
T
hobokenchicken 66e799a655 fix: connect directly via websockets to bypass OpenAI-Beta header
The openai library's beta.realtime.connect() hardcodes the obsolete
'OpenAI-Beta: realtime=v1' header which the GA API rejects. Connecting
directly via the websockets library with only the Authorization header
resolves the 'beta_api_shape_disabled' error.
2026-06-04 13:49:56 -04:00

237 lines
8.5 KiB
Python

"""Hybrid pipeline: streaming STT (gpt-realtime-whisper) + cheap LLM + TTS.
Uses gpt-realtime-whisper for low-latency streaming transcription,
gpt-5.4-nano as the brain, and OpenAI TTS for voice output.
"""
import json
import base64
import logging
import asyncio
from typing import Callable, Awaitable
from openai import AsyncOpenAI
from config import settings
logger = logging.getLogger("kira.hybrid")

# Base system prompt defining Kira's persona. HybridPipeline appends any
# optional memory context to it before each LLM call. The fragments are
# joined with no separator, so every fragment except the last ends with
# its own trailing space.
KIRA_INSTRUCTIONS = "".join(
    (
        "You are Kira, a warm, kind, and encouraging AI body double. ",
        "You speak in a friendly, girly-pop tone. You are helping someone with ADHD ",
        "stay focused and on task. Keep responses short, supportive, and uplifting. ",
        "Check in on them. Remind them to take breaks. Celebrate small wins. ",
        "Use occasional emoji but don't overdo it. Never be judgmental.",
    )
)
class HybridPipeline:
    """Streaming STT via gpt-realtime-whisper → gpt-5.4-nano LLM → OpenAI TTS.

    ``connect()`` opens the Realtime websocket and then keeps receiving server
    events until the socket closes or ``disconnect()`` is called, so it does
    not return while the session is live. While connected, ``send_audio()``
    streams PCM16 chunks upstream; a server-VAD "speech stopped" event turns
    the buffered transcript into one LLM + TTS turn. ``send_text()`` runs the
    same turn for typed input and does not need the websocket.

    Every ``on_*`` argument is an async callback. Exceptions raised while a
    server event is being handled are logged and do not close the connection.
    """

    def __init__(
        self,
        on_transcript_delta: Callable[[str], Awaitable[None]],
        on_transcript_done: Callable[[str], Awaitable[None]],
        on_audio_delta: Callable[[bytes], Awaitable[None]],
        on_speech_start: Callable[[], Awaitable[None]],
        on_speech_end: Callable[[], Awaitable[None]],
        on_ready: Callable[[], Awaitable[None]],
        on_error: Callable[[str], Awaitable[None]],
        memory_suffix: str = "",
    ):
        """Store the callbacks; no network activity happens until it is needed.

        Args:
            on_transcript_delta: receives transcript text as it streams in.
            on_transcript_done: receives the final text of each utterance.
            on_audio_delta: receives the synthesized reply audio (encoded in
                the TTS ``response_format``, currently opus).
            on_speech_start: called before a spoken reply is synthesized.
            on_speech_end: called after a spoken reply, even if TTS failed.
            on_ready: called once the transcription session is configured.
            on_error: receives an error message string.
            memory_suffix: text appended to the system prompt (optional
                memory context); an empty string means none.
        """
        self._on_transcript_delta = on_transcript_delta
        self._on_transcript_done = on_transcript_done
        self._on_audio_delta = on_audio_delta
        self._on_speech_start = on_speech_start
        self._on_speech_end = on_speech_end
        self._on_ready = on_ready
        self._on_error = on_error
        self._memory_suffix = memory_suffix
        # OpenAI HTTP client for LLM + TTS; created on first use by _client().
        self._openai = None
        # Realtime websocket; set only while connect() holds an open socket.
        self._conn = None
        self._connected = False
        # Transcript text accumulated for the utterance in progress.
        self._transcript_buffer = ""

    def _client(self):
        """Return the shared AsyncOpenAI client, creating it on first use.

        Creating it lazily lets send_text() work even when connect() has
        never been called.
        """
        if self._openai is None:
            self._openai = AsyncOpenAI(api_key=settings.openai_api_key)
        return self._openai

    async def connect(self):
        """Connect to gpt-realtime-whisper and process events until closed.

        Returns immediately if already connected; otherwise does not return
        until the websocket closes. Failures to connect or configure the
        session, and unexpected drops of an established connection, are
        logged and reported through ``on_error``.
        """
        if self._connected:
            return
        try:
            import websockets

            # Build the LLM/TTS client up front so a bad API key surfaces here.
            self._client()
            logger.info("Connecting to gpt-realtime-whisper...")
            # Connect directly via websockets rather than openai's
            # beta.realtime.connect(), which also sends the obsolete
            # 'OpenAI-Beta: realtime=v1' header that the GA API rejects.
            url = "wss://api.openai.com/v1/realtime?model=gpt-realtime-whisper"
            async with websockets.connect(
                url,
                additional_headers={
                    "Authorization": f"Bearer {settings.openai_api_key}",
                },
            ) as conn:
                self._conn = conn
                self._connected = True
                logger.info("Connected to gpt-realtime-whisper")
                # NOTE(review): these keys follow the beta Realtime session
                # shape; OpenAI's GA docs describe a different layout (session
                # "type", audio.input.transcription.model). Confirm the server
                # accepts this update -- a rejection arrives as an "error"
                # event, which _handle_event only logs.
                await self._send({
                    "type": "session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_transcription": {
                            "enabled": True,
                        },
                        "turn_detection": {
                            "type": "server_vad",
                            "threshold": 0.5,
                            "prefix_padding_ms": 300,
                            "silence_duration_ms": 600,
                        },
                    },
                })
                await self._on_ready()
                await self._receive_loop(conn)
        except Exception as e:
            logger.error("Connection error: %s", e)
            await self._on_error(str(e))
        finally:
            self._connected = False
            self._conn = None

    async def _receive_loop(self, conn):
        """Read and dispatch server events until the socket closes.

        A receive failure ends the loop because the socket is unusable; it is
        reported through ``on_error`` unless disconnect() caused it. Malformed
        messages and handler errors are logged and skipped so that one bad
        event cannot end the session.
        """
        while self._connected:
            try:
                raw = await conn.recv()
            except Exception as e:
                if self._connected:
                    logger.warning("recv error: %s", e)
                    await self._on_error(f"Transcription connection lost: {e}")
                break
            if not isinstance(raw, (str, bytes)):
                continue
            try:
                data = json.loads(raw if isinstance(raw, str) else raw.decode())
            except ValueError as e:  # includes JSONDecodeError / UnicodeDecodeError
                logger.warning("Ignoring malformed server event: %s", e)
                continue
            try:
                await self._handle_event(data)
            except Exception:
                # E.g. a callback failing because the client went away mid-reply;
                # keep transcribing instead of silently dropping the socket.
                logger.exception("Error while handling server event")

    async def _handle_event(self, event):
        """Dispatch one decoded server event from the transcription socket.

        Per utterance: "speech_started" resets the transcript buffer,
        transcription deltas are appended to it and forwarded to
        ``on_transcript_delta``, and "speech_stopped" hands the buffered text
        to ``_process_transcript`` for an LLM + TTS turn. Server "error"
        events are only logged.

        NOTE(review): OpenAI's documented Realtime transcription flow delivers
        text in "conversation.item.input_audio_transcription.delta" /
        ".completed" events, usually after "speech_stopped". Confirm that the
        event names below match what this model actually emits; otherwise the
        buffer may still be empty when the turn is triggered.
        """
        event_type = self._get_field(event, "type", "")
        if event_type == "input_audio_buffer.speech_started":
            self._transcript_buffer = ""
        elif event_type == "input_audio_buffer.speech_stopped":
            transcript = self._transcript_buffer.strip()
            # Reset before the slow LLM + TTS turn so the buffer is cleared
            # even if that turn raises.
            self._transcript_buffer = ""
            if transcript:
                await self._process_transcript(transcript)
        elif event_type == "input_audio_buffer.transcription_delta":
            delta_text = self._get_field(event, "delta", "")
            if delta_text:
                self._transcript_buffer += delta_text
                await self._on_transcript_delta(delta_text)
        elif event_type == "conversation.item.created":
            item = self._get_field(event, "item", {})
            content = self._get_field(item, "content", [])
            for part in content or []:
                part_type = self._get_field(part, "type", "")
                part_transcript = self._get_field(part, "transcript", "")
                if part_type == "transcript" and part_transcript:
                    # A complete transcript replaces any streamed partials.
                    self._transcript_buffer = part_transcript
                    await self._on_transcript_delta(part_transcript)
        elif event_type == "error":
            err = self._get_field(event, "error", {})
            msg = self._get_field(err, "message", str(event))
            logger.warning("Whisper error: %s", msg)

    async def _process_transcript(self, transcript: str):
        """Run one conversational turn for a finished utterance.

        Reports the transcript via ``on_transcript_done``, asks the LLM for a
        reply (falling back to a canned line on failure), then speaks it.
        """
        await self._on_transcript_done(transcript)
        logger.info("User: %s", transcript)
        kira_text = await self._generate_reply(transcript)
        logger.info("Kira: %s", kira_text)
        await self._speak(kira_text)

    async def _generate_reply(self, transcript: str) -> str:
        """Return Kira's text reply to *transcript*, or a fallback line on error.

        LLM failures are logged and reported through ``on_error``.
        """
        # System prompt = persona + optional memory context.
        system_content = KIRA_INSTRUCTIONS
        if self._memory_suffix:
            system_content += self._memory_suffix
        try:
            resp = await self._client().chat.completions.create(
                model="gpt-5.4-nano",
                messages=[
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": transcript},
                ],
                # max_completion_tokens supersedes the deprecated max_tokens,
                # which reasoning-family chat models reject.
                # NOTE(review): for reasoning models this budget also covers
                # reasoning tokens, and some accept only the default
                # temperature -- verify both against gpt-5.4-nano.
                max_completion_tokens=300,
                temperature=0.7,
            )
            return resp.choices[0].message.content or "Mhm, I'm here!"
        except Exception as e:
            logger.error("LLM error: %s", e)
            await self._on_error(str(e))
            return "Sorry, let me try that again!"

    async def _speak(self, text: str):
        """Synthesize *text* with TTS and deliver it between speech start/end.

        TTS failures are logged; ``on_speech_end`` always runs, even on
        cancellation, so the client is never left in a "speaking" state.
        """
        await self._on_speech_start()
        try:
            tts_resp = await self._client().audio.speech.create(
                model="tts-1",
                voice="nova",
                input=text,
                response_format="opus",
            )
            audio_bytes = tts_resp.content
            if audio_bytes:
                await self._on_audio_delta(audio_bytes)
        except Exception as e:
            logger.error("TTS error: %s", e)
        finally:
            await self._on_speech_end()

    async def send_audio(self, pcm16_bytes: bytes):
        """Append a PCM16 audio chunk to the server's input audio buffer.

        Ignored while not connected; encoding and send errors are logged
        rather than raised.
        """
        if not self._connected or self._conn is None:
            return
        try:
            audio_b64 = base64.b64encode(pcm16_bytes).decode("utf-8")
            await self._send({
                "type": "input_audio_buffer.append",
                "audio": audio_b64,
            })
        except Exception as e:
            logger.warning("Send audio error: %s", e)

    async def send_text(self, text: str):
        """Run an LLM + TTS turn for typed input (no transcription needed).

        Works without connect(), since the OpenAI client is created on demand.
        """
        await self._process_transcript(text)

    async def _send(self, data: dict):
        """JSON-encode *data* and send it on the websocket; errors are logged."""
        try:
            await self._conn.send(json.dumps(data))
        except Exception as e:
            logger.warning("Send error: %s", e)

    async def disconnect(self):
        """Close the transcription websocket; safe to call when not connected.

        Clearing ``_connected`` first makes the receive loop exit quietly
        instead of reporting the close as an error.
        """
        self._connected = False
        conn, self._conn = self._conn, None
        if conn is not None:
            try:
                await conn.close()
            except Exception as e:
                # Best effort: the socket may already be closed.
                logger.debug("Error closing websocket: %s", e)

    @staticmethod
    def _get_field(obj, field: str, default=None):
        """Get *field* from a dict (by key) or any other object (by attribute).

        Dicts are checked first: server events are decoded JSON dicts, and an
        attribute lookup on a dict would return bound methods for keys that
        shadow dict methods, such as ``"items"`` or ``"values"``.
        """
        if isinstance(obj, dict):
            return obj.get(field, default)
        return getattr(obj, field, default)