66e799a655
The openai library's beta.realtime.connect() hardcodes the obsolete 'OpenAI-Beta: realtime=v1' header which the GA API rejects. Connecting directly via the websockets library with only the Authorization header resolves the 'beta_api_shape_disabled' error.
237 lines
8.5 KiB
Python
237 lines
8.5 KiB
Python
"""Hybrid pipeline: streaming STT (gpt-realtime-whisper) + cheap LLM + TTS.
|
|
|
|
Uses gpt-realtime-whisper for low-latency streaming transcription,
|
|
gpt-5.4-nano as the brain, and OpenAI TTS for voice output.
|
|
"""
|
|
|
|
import json
|
|
import base64
|
|
import logging
|
|
import asyncio
|
|
from typing import Callable, Awaitable
|
|
from openai import AsyncOpenAI
|
|
from config import settings
|
|
|
|
# Module-level logger shared by everything in the hybrid pipeline.
logger = logging.getLogger("kira.hybrid")


# ─── System instructions for Kira's personality ───

# Persona prompt used as the system message for every LLM call. The sentences
# are kept as separate fragments for readability and joined with single spaces.
KIRA_INSTRUCTIONS = " ".join(
    [
        "You are Kira, a warm, kind, and encouraging AI body double.",
        "You speak in a friendly, girly-pop tone. You are helping someone with ADHD",
        "stay focused and on task. Keep responses short, supportive, and uplifting.",
        "Check in on them. Remind them to take breaks. Celebrate small wins.",
        "Use occasional emoji but don't overdo it. Never be judgmental.",
    ]
)
|
|
|
|
|
|
class HybridPipeline:
    """Streaming STT via gpt-realtime-whisper → gpt-5.4-nano LLM → OpenAI TTS.

    Audio chunks pushed through ``send_audio`` are forwarded over a websocket
    to the realtime endpoint. Transcript text is accumulated per utterance;
    when the server reports that speech stopped, the utterance is sent to the
    chat model and the reply is synthesised to audio. All results are
    delivered through the async callbacks supplied to the constructor.
    """

    def __init__(
        self,
        on_transcript_delta: Callable[[str], Awaitable[None]],
        on_transcript_done: Callable[[str], Awaitable[None]],
        on_audio_delta: Callable[[bytes], Awaitable[None]],
        on_speech_start: Callable[[], Awaitable[None]],
        on_speech_end: Callable[[], Awaitable[None]],
        on_ready: Callable[[], Awaitable[None]],
        on_error: Callable[[str], Awaitable[None]],
        memory_suffix: str = "",
    ):
        """Store the callbacks and initialise connection state.

        Args:
            on_transcript_delta: Awaited with transcript text found in a
                ``conversation.item.created`` event.
            on_transcript_done: Awaited with the full utterance before the
                LLM is called.
            on_audio_delta: Awaited with the synthesised audio bytes.
            on_speech_start: Awaited just before TTS synthesis begins.
            on_speech_end: Awaited once the TTS step has finished.
            on_ready: Awaited after the session configuration has been sent.
            on_error: Awaited with an error message on connection or LLM
                failure.
            memory_suffix: Text appended verbatim to the system prompt.
        """
        self._on_transcript_delta = on_transcript_delta
        self._on_transcript_done = on_transcript_done
        self._on_audio_delta = on_audio_delta
        self._on_speech_start = on_speech_start
        self._on_speech_end = on_speech_end
        self._on_ready = on_ready
        self._on_error = on_error
        self._memory_suffix = memory_suffix
        self._openai = None  # AsyncOpenAI client, created on first use
        self._conn = None  # live websocket connection, if any
        self._connected = False
        self._transcript_buffer = ""  # text of the utterance in progress

    def _get_client(self):
        """Return the shared AsyncOpenAI client, creating it on first use.

        Creating it lazily lets ``send_text`` work even when ``connect`` has
        not been called yet.
        """
        if self._openai is None:
            self._openai = AsyncOpenAI(api_key=settings.openai_api_key)
        return self._openai

    async def connect(self) -> None:
        """Connect to gpt-realtime-whisper via OpenAI Realtime API.

        Opens the websocket, sends the session configuration, awaits
        ``on_ready`` and then receives events until the socket fails, a
        handler raises, or ``disconnect`` clears the connected flag. Failures
        while connecting or configuring are logged and reported through
        ``on_error`` instead of being raised. Returns immediately if already
        connected.
        """
        if self._connected:
            return

        try:
            import websockets

            self._get_client()

            logger.info("Connecting to gpt-realtime-whisper...")

            # Connect directly via websockets to avoid the OpenAI-Beta header
            url = "wss://api.openai.com/v1/realtime?model=gpt-realtime-whisper"
            headers = {"Authorization": f"Bearer {settings.openai_api_key}"}

            async with websockets.connect(url, additional_headers=headers) as conn:
                self._conn = conn
                self._connected = True
                logger.info("Connected to gpt-realtime-whisper")

                # Configure session for transcription.
                # NOTE(review): payload shape is taken as-is from the original
                # code — confirm against the current Realtime API reference.
                await self._send({
                    "type": "session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_transcription": {
                            "enabled": True,
                        },
                        "turn_detection": {
                            "type": "server_vad",
                            "threshold": 0.5,
                            "prefix_padding_ms": 300,
                            "silence_duration_ms": 600,
                        },
                    },
                })

                await self._on_ready()

                # Listen for transcription events
                while self._connected:
                    try:
                        raw = await conn.recv()
                        if isinstance(raw, (str, bytes)):
                            text = raw if isinstance(raw, str) else raw.decode()
                            await self._handle_event(json.loads(text))
                    except Exception as e:
                        # After disconnect() the failing recv is expected, so
                        # only warn while we still believe we are connected.
                        if self._connected:
                            logger.warning(f"recv error: {e}")
                        break

        except Exception as e:
            logger.error(f"Connection error: {e}")
            await self._on_error(str(e))
        finally:
            self._connected = False
            self._conn = None

    async def _handle_event(self, event) -> None:
        """Process events from gpt-realtime-whisper.

        ``event`` may be a decoded JSON dict or an object exposing the same
        fields as attributes. Unknown event types are ignored.
        """
        # NOTE(review): the event names below are kept from the original
        # code; verify them against the server events the API actually emits.
        event_type = self._get_field(event, "type", "")

        if event_type == "input_audio_buffer.speech_started":
            self._transcript_buffer = ""

        elif event_type == "input_audio_buffer.speech_stopped":
            # Clear the buffer before the slow LLM/TTS round trip so a failure
            # there cannot leave stale text behind for the next utterance.
            utterance = self._transcript_buffer.strip()
            self._transcript_buffer = ""
            if utterance:
                await self._process_transcript(utterance)

        elif event_type == "input_audio_buffer.transcription_delta":
            delta_text = self._get_field(event, "delta", "")
            if delta_text:
                self._transcript_buffer += delta_text

        elif event_type == "conversation.item.created":
            item = self._get_field(event, "item", {})
            content = self._get_field(item, "content", [])
            for part in (content or []):
                part_type = self._get_field(part, "type", "")
                part_transcript = self._get_field(part, "transcript", "")
                if part_type == "transcript" and part_transcript:
                    # A transcript part replaces whatever was accumulated.
                    self._transcript_buffer = part_transcript
                    await self._on_transcript_delta(part_transcript)

        elif event_type == "error":
            err = self._get_field(event, "error", {})
            msg = self._get_field(err, "message", str(event))
            logger.warning(f"Whisper error: {msg}")

    async def _process_transcript(self, transcript: str) -> None:
        """Full utterance received. Call LLM, then TTS."""
        await self._on_transcript_done(transcript)
        logger.info(f"User: {transcript}")

        kira_text = await self._generate_reply(transcript)
        logger.info(f"Kira: {kira_text}")

        await self._speak(kira_text)

    async def _generate_reply(self, transcript: str) -> str:
        """Ask the chat model for Kira's reply to ``transcript``.

        On any failure the error is logged and reported through ``on_error``,
        and a fixed apology string is returned so the caller can still speak.
        """
        # Build system prompt with optional memory context
        system_content = KIRA_INSTRUCTIONS
        if self._memory_suffix:
            system_content += self._memory_suffix

        # Call gpt-5.4-nano
        # NOTE(review): confirm this model accepts max_tokens/temperature.
        try:
            resp = await self._get_client().chat.completions.create(
                model="gpt-5.4-nano",
                messages=[
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": transcript},
                ],
                max_tokens=300,
                temperature=0.7,
            )
            return resp.choices[0].message.content or "Mhm, I'm here!"
        except Exception as e:
            logger.error(f"LLM error: {e}")
            await self._on_error(str(e))
            return "Sorry, let me try that again!"

    async def _speak(self, text: str) -> None:
        """Synthesise ``text`` and deliver the audio via ``on_audio_delta``.

        TTS failures are logged only. ``on_speech_end`` is awaited in a
        ``finally`` so it always pairs with ``on_speech_start``.
        """
        await self._on_speech_start()
        try:
            tts_resp = await self._get_client().audio.speech.create(
                model="tts-1",
                voice="nova",
                input=text,
                response_format="opus",
            )
            audio_bytes = tts_resp.content
            if audio_bytes:
                await self._on_audio_delta(audio_bytes)
        except Exception as e:
            logger.error(f"TTS error: {e}")
        finally:
            await self._on_speech_end()

    async def send_audio(self, pcm16_bytes: bytes) -> None:
        """Send PCM16 audio chunk for transcription.

        Silently does nothing while no connection is open.
        """
        if not self._connected or not self._conn:
            return
        try:
            audio_b64 = base64.b64encode(pcm16_bytes).decode("utf-8")
            await self._send({
                "type": "input_audio_buffer.append",
                "audio": audio_b64,
            })
        except Exception as e:
            logger.warning(f"Send audio error: {e}")

    async def send_text(self, text: str) -> None:
        """Process text input directly (no transcription needed)."""
        await self._process_transcript(text)

    async def _send(self, data: dict) -> None:
        """JSON-encode ``data`` and send it; failures are logged, not raised."""
        try:
            await self._conn.send(json.dumps(data))
        except Exception as e:
            logger.warning(f"Send error: {e}")

    async def disconnect(self) -> None:
        """Stop the receive loop and close the websocket (best effort)."""
        # Clear the flag first so the receive loop exits quietly.
        self._connected = False
        if self._conn:
            try:
                await self._conn.close()
            except Exception as e:
                # Best effort: the socket may already be gone.
                logger.debug(f"Ignoring close error: {e}")
            self._conn = None

    @staticmethod
    def _get_field(obj, field: str, default=None):
        """Get a field from either an object or dict.

        Dicts are looked up by key first, so payload keys that happen to match
        dict attribute names (``items``, ``keys``, ``get`` …) return the
        payload value rather than a bound method. Anything else is read with
        ``getattr``. ``default`` is returned when the field is absent.
        """
        if isinstance(obj, dict):
            return obj.get(field, default)
        return getattr(obj, field, default)
|