# kira/backend/main.py
"""Kira — AI body double backend
Gemini Live API (gemini-3.1-flash-live-preview) for real-time voice.
Text chat still goes through Gemini generateContent REST endpoint.
"""
import json
import base64
import uuid
import logging
import asyncio
import struct
import websockets
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from services.memory import kira_memory
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("kira")

app = FastAPI(title="Kira Backend")

# Wide-open CORS: every origin, method and header is accepted, credentials included.
# NOTE(review): a wildcard origin list combined with allow_credentials=True may let
# any website make credentialed requests (Starlette reflects the request Origin in
# that case) — confirm, and restrict allow_origins to the real frontend origin(s)
# before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Persona instructions; used as the Live session's system instruction at setup.
BASE_SYSTEM_PROMPT = " ".join((
    "You are Kira, a warm, kind, and encouraging AI body double.",
    "You speak in a friendly, girly-pop tone.",
    "You are helping someone with ADHD stay focused and on task.",
    "Keep responses short, supportive, and uplifting.",
    "Check in on them.",
    "Remind them to take breaks.",
    "Celebrate small wins.",
    "Use occasional emoji but don't overdo it.",
    "Never be judgmental.",
    "You are speaking out loud via voice, so keep natural conversational flow.",
))

# Bidirectional streaming endpoint of the Gemini Live API (v1beta).
GEMINI_WS_URL = (
    "wss://generativelanguage.googleapis.com/ws/"
    "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
)
# Model name sent in the Live session's setup message.
GEMINI_MODEL = "models/gemini-3.1-flash-live-preview"
@app.on_event("startup")
async def startup():
    """Try to initialise the Honcho memory backend once, when the app starts.

    The truthiness of ``kira_memory.init()`` decides which status line is
    logged; neither outcome stops the server from starting.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favour
    # of lifespan handlers; migrating also requires changing the FastAPI(...) call.
    memory_ready = kira_memory.init()
    logger.info("Honcho memory initialized" if memory_ready else "Honcho memory not configured")
@app.get("/api/health")
async def health():
    """Health probe: static status plus whether the memory backend is enabled."""
    return {
        "status": "ok",
        "name": "kira",
        "memory": "active" if kira_memory.enabled else "disabled",
    }
# Identifier used until the frontend sends an "identify" message; shared by all
# anonymous sessions.
_DEFAULT_USER_ID = "default-user"


def _text_field(msg: dict, key: str, default: str = "") -> str:
    """Return ``msg[key]`` stripped of whitespace, or *default* if missing or not a string.

    Frontend payloads are untrusted JSON: a ``null``/number/list value must not
    raise ``AttributeError`` on ``.strip()`` and tear down the whole session.
    """
    value = msg.get(key)
    return value.strip() if isinstance(value, str) else default


def _parse_frontend_message(raw: str):
    """Decode one frontend text frame.

    Returns:
        The decoded dict, or ``None`` when the frame is not valid JSON or does
        not decode to a JSON object.
    """
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return msg if isinstance(msg, dict) else None


def _build_setup_message(system_prompt: str, voice_name: str = "Aoede") -> dict:
    """Return the BidiGenerateContent ``setup`` message that opens a Kira voice session.

    Args:
        system_prompt: Text sent as the session's system instruction.
        voice_name: Prebuilt Gemini voice used for spoken replies.
    """
    return {
        "setup": {
            "model": GEMINI_MODEL,
            "generationConfig": {
                # The Live API supports a single response modality per session, so
                # ["AUDIO", "TEXT"] is not a valid combination. Kira's text comes
                # from outputAudioTranscription below instead.
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": voice_name}},
                },
            },
            "systemInstruction": {"parts": [{"text": system_prompt}]},
            # Have Gemini transcribe both sides of the conversation so the frontend
            # still receives {"type": "transcript"} messages for spoken turns.
            "inputAudioTranscription": {},
            "outputAudioTranscription": {},
        }
    }


def _frontend_messages_for(server_msg: dict) -> list:
    """Translate one decoded Gemini Live server message into frontend messages.

    Args:
        server_msg: A decoded Gemini Live server message.

    Returns:
        The frontend-protocol messages to send, in order. Kinds the proxy does
        not forward (setupComplete, toolCall, toolCallCancellation, goAway, ...)
        yield an empty list.
    """
    content = server_msg.get("serverContent")
    if isinstance(content, dict):
        out = []
        # Transcription of the user's speech (requested via inputAudioTranscription).
        user_text = (content.get("inputTranscription") or {}).get("text")
        if user_text:
            out.append({"type": "transcript", "role": "user", "text": user_text})
        for part in (content.get("modelTurn") or {}).get("parts", []):
            if "text" in part:
                out.append({"type": "transcript", "role": "kira", "text": part["text"]})
            # Audio response (PCM16 24kHz), forwarded as the base64 string received.
            audio_data = (part.get("inlineData") or {}).get("data", "")
            if audio_data:
                out.append({"type": "audio", "data": audio_data})
        # Transcription of Kira's spoken reply (requested via outputAudioTranscription).
        kira_text = (content.get("outputTranscription") or {}).get("text")
        if kira_text:
            out.append({"type": "transcript", "role": "kira", "text": kira_text})
        if content.get("turnComplete"):
            out.append({"type": "turn_complete"})
        if content.get("interrupted"):
            out.append({"type": "interrupted"})
        return out
    if "error" in server_msg:
        err = server_msg["error"]
        # Expected to be an object with a "message" field; tolerate other shapes.
        detail = err.get("message", err) if isinstance(err, dict) else err
        return [{"type": "error", "message": str(detail)}]
    # toolCall / toolCallCancellation: tool use is not implemented yet.
    return []


@app.websocket("/api/ws")
async def gemini_voice_ws(websocket: WebSocket):
    """WebSocket proxy between the frontend and the Gemini Live API.

    Frontend -> proxy:
        {"type": "audio", "data": "<base64 PCM16 16kHz>"}
        {"type": "conversation_text", "text": "..."}
        {"type": "identify", "user_id": "...", "name": "..."}
        {"type": "set_preference", "key": "...", "value": "..."}
        {"type": "ping"}

    Proxy -> frontend:
        {"type": "audio", "data": "<base64 PCM16 24kHz>"}
        {"type": "transcript", "role": "user"|"kira", "text": "..."}
        {"type": "turn_complete"}
        {"type": "interrupted"}
        {"type": "identified", "user_id": "...", "preferences": ...}
        {"type": "preference_saved", "key": "...", "success": true|false}
        {"type": "pong"}
        {"type": "error", "message": "..."}

    The session ends as soon as either side disconnects. Both relay tasks and
    the Gemini connection are then torn down.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())[:8]
    user_id = _DEFAULT_USER_ID
    # Memory context; only prepended to typed (conversation_text) turns.
    memory_suffix = ""
    logger.info(f"[{session_id}] WebSocket connected")
    gemini_ws = None
    tasks = []

    def gemini_is_open():
        return gemini_ws is not None and gemini_ws.state.name == "OPEN"

    async def relay_gemini():
        """Forward Gemini server messages to the frontend until the Gemini socket closes."""
        try:
            async for raw_server in gemini_ws:
                server_msg = json.loads(raw_server)
                if "error" in server_msg:
                    logger.error(f"[{session_id}] Gemini error: {server_msg['error']}")
                elif "goAway" in server_msg:
                    # Gemini signals that it will disconnect this session soon.
                    logger.warning(f"[{session_id}] Gemini goAway: {server_msg['goAway']}")
                for out_msg in _frontend_messages_for(server_msg):
                    await websocket.send_json(out_msg)
        except websockets.exceptions.ConnectionClosed as exc:
            logger.info(f"[{session_id}] Gemini WS closed: {exc}")
        except Exception:
            logger.exception(f"[{session_id}] Gemini relay error")

    async def relay_frontend():
        """Dispatch frontend messages until the browser disconnects."""
        nonlocal user_id, memory_suffix
        # NOTE(review): the kira_memory calls below are synchronous; if they do
        # network I/O they block the event loop (and every other session) while
        # they run — consider asyncio.to_thread. TODO confirm.
        try:
            while True:
                raw = await websocket.receive_text()
                msg = _parse_frontend_message(raw)
                if msg is None:
                    # A single bad frame must not end the whole voice session.
                    logger.warning(f"[{session_id}] Ignoring malformed frontend message")
                    await websocket.send_json({
                        "type": "error",
                        "message": "Malformed message: expected a JSON object",
                    })
                    continue
                msg_type = msg.get("type", "")

                if msg_type == "identify":
                    # A missing, empty or non-string user_id falls back to the anonymous id.
                    user_id = _text_field(msg, "user_id") or _DEFAULT_USER_ID
                    user_name = _text_field(msg, "name")
                    if user_name:
                        # NOTE(review): unlike set_preference, this also writes for the
                        # shared anonymous id, so anonymous sessions share one "name".
                        kira_memory.set_user_preference(user_id, "name", user_name)
                    prefs = kira_memory.get_user_preferences(user_id)
                    if kira_memory.enabled:
                        # Memory enrichment is best-effort: failures are logged and the
                        # session continues without extra context.
                        # NOTE(review): build_system_prompt_suffix() takes no user/session
                        # argument, so it presumably relies on state set by the ensure_*
                        # calls on the shared kira_memory object; concurrent sessions may
                        # interfere. TODO confirm.
                        try:
                            kira_memory.ensure_peers(user_id)
                            kira_memory.ensure_session(session_id)
                            ctx = kira_memory.build_system_prompt_suffix()
                            if ctx:
                                memory_suffix = ctx
                        except Exception:
                            logger.exception(f"[{session_id}] Failed to load memory context")
                    await websocket.send_json({
                        "type": "identified",
                        "user_id": user_id,
                        "preferences": prefs,
                    })

                elif msg_type == "set_preference":
                    key = _text_field(msg, "key")
                    value = msg.get("value", "")
                    # Only identified users may persist preferences, and only string values.
                    saved = False
                    if key and isinstance(value, str) and user_id != _DEFAULT_USER_ID:
                        kira_memory.set_user_preference(user_id, key, value.strip())
                        saved = True
                    await websocket.send_json({"type": "preference_saved", "key": key, "success": saved})

                elif msg_type == "audio":
                    # Forward PCM16 audio to Gemini as realtimeInput; dropped if Gemini is not open.
                    audio_b64 = msg.get("data")
                    if isinstance(audio_b64, str) and audio_b64 and gemini_is_open():
                        await gemini_ws.send(json.dumps({
                            "realtimeInput": {
                                "audio": {
                                    "mimeType": "audio/pcm;rate=16000",
                                    "data": audio_b64,
                                }
                            }
                        }))

                elif msg_type == "conversation_text":
                    text = _text_field(msg, "text")
                    if not text:
                        continue
                    logger.info(f"[{session_id}] User (text): {text}")
                    if gemini_is_open():
                        # Send as a complete text turn; memory context rides along as a prefix.
                        turn_text = f"[Context: {memory_suffix}]\n{text}" if memory_suffix else text
                        await gemini_ws.send(json.dumps({
                            "clientContent": {
                                "turns": [{"role": "user", "parts": [{"text": turn_text}]}],
                                "turnComplete": True,
                            }
                        }))
                    # Echo the user's own text back as a transcript line.
                    await websocket.send_json({
                        "type": "transcript",
                        "role": "user",
                        "text": text,
                    })

                elif msg_type == "ping":
                    await websocket.send_json({"type": "pong"})

                else:
                    logger.debug(f"[{session_id}] Ignoring unknown message type: {msg_type!r}")
        except WebSocketDisconnect:
            pass
        except Exception:
            logger.exception(f"[{session_id}] Frontend relay error")

    try:
        # ── Connect to Gemini Live API ──
        gemini_url = f"{GEMINI_WS_URL}?key={settings.gemini_api_key}"
        gemini_ws = await websockets.connect(gemini_url, max_size=2**24)
        # ── Send setup ──
        # NOTE(review): setup is sent before the frontend's "identify", so memory
        # context never reaches the system instruction and voice turns never see it.
        await gemini_ws.send(json.dumps(_build_setup_message(BASE_SYSTEM_PROMPT)))
        logger.info(f"[{session_id}] Connected to Gemini Live API")
        # Wait for setup complete
        raw = await asyncio.wait_for(gemini_ws.recv(), timeout=10)
        setup_resp = json.loads(raw)
        if "setupComplete" in setup_resp:
            logger.info(f"[{session_id}] Gemini setup complete")
        else:
            logger.warning(f"[{session_id}] Unexpected setup response: {list(setup_resp.keys())}")

        tasks = [
            asyncio.create_task(relay_gemini()),
            asyncio.create_task(relay_frontend()),
        ]
        # Either side finishing ends the session; the rest is torn down in finally.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    except Exception as e:
        # Only the connect/setup phase can land here; the relays handle their own errors.
        logger.error(f"[{session_id}] Connection error: {e}")
        try:
            await websocket.send_json({"type": "error", "message": "Could not start the voice session"})
        except Exception:
            pass  # best effort: the frontend may already be gone
    finally:
        # Cancel whichever relay is still running (also when this handler itself is
        # cancelled) and wait for it, so no task outlives the session.
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        if gemini_ws is not None:
            try:
                await gemini_ws.close()
            except Exception as e:
                logger.warning(f"[{session_id}] Error closing Gemini WS: {e}")
        logger.info(f"[{session_id}] Disconnected")