281 lines
11 KiB
Python
281 lines
11 KiB
Python
"""Kira — AI body double backend
|
|
|
|
Gemini Live API (gemini-3.1-flash-live-preview) for real-time voice.
|
|
Typed chat messages received on /api/ws are sent into the same Live session as clientContent turns.
|
|
"""
|
|
|
|
import json
|
|
import base64
|
|
import uuid
|
|
import logging
|
|
import asyncio
|
|
import struct
|
|
|
|
import websockets
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from config import settings
|
|
from services.memory import kira_memory
|
|
|
|
# Root logging at INFO so the per-session logger.info(...) lines below are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="kira")
|
|
|
|
app = FastAPI(title="Kira Backend")

# Accept cross-origin requests from anywhere, for every method and header,
# with credentials allowed.
# NOTE(review): combining "*" origins with allow_credentials=True means any
# site may make credentialed requests; narrow allow_origins before exposing
# this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Persona instructions installed as the Live session's systemInstruction.
# Every fragment except the last carries its own trailing space, so joining
# them with "" yields a single paragraph.
BASE_SYSTEM_PROMPT = "".join((
    "You are Kira, a warm, kind, and encouraging AI body double. ",
    "You speak in a friendly, girly-pop tone. You are helping someone with ADHD ",
    "stay focused and on task. Keep responses short, supportive, and uplifting. ",
    "Check in on them. Remind them to take breaks. Celebrate small wins. ",
    "Use occasional emoji but don't overdo it. Never be judgmental. ",
    "You are speaking out loud via voice, so keep natural conversational flow.",
))
|
|
|
|
# Bidirectional-streaming (Live API) WebSocket endpoint; the API key is
# appended as a ?key= query parameter when connecting.
GEMINI_WS_URL = (
    "wss://generativelanguage.googleapis.com/ws/"
    "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
)
# Model resource name sent in the session's setup frame.
GEMINI_MODEL = "models/gemini-3.1-flash-live-preview"
|
|
|
|
|
|
@app.on_event("startup")
async def startup():
    """Initialise the Honcho memory backend once, when the app starts.

    The truthiness of ``kira_memory.init()`` selects which status line is
    logged; the app starts either way.
    """
    memory_ready = kira_memory.init()
    logger.info(
        "Honcho memory initialized" if memory_ready else "Honcho memory not configured"
    )
|
|
|
|
|
|
@app.get("/api/health")
async def health():
    """Health-check endpoint: service name plus whether Honcho memory is enabled."""
    return {
        "status": "ok",
        "name": "kira",
        "memory": "active" if kira_memory.enabled else "disabled",
    }
|
|
|
|
|
|
# Placeholder identity used until the frontend sends an "identify" frame.
_DEFAULT_USER_ID = "default-user"


def _text_field(msg: dict, key: str, default: str = "") -> str:
    """Return ``msg[key]`` as a stripped string.

    Frontend frames are untrusted JSON: a missing or ``null`` field yields
    *default*, and non-string values (numbers, booleans) are coerced with
    ``str`` so a malformed frame cannot crash the relay loop on ``.strip()``.
    """
    value = msg.get(key)
    if value is None:
        return default
    return str(value).strip()


def _build_setup_message(system_prompt: str) -> dict:
    """Build the ``setup`` frame sent as the first message of a Live session.

    Requests audio-only responses spoken with the prebuilt "Aoede" voice and
    installs *system_prompt* as the session's system instruction.
    """
    return {
        "setup": {
            "model": GEMINI_MODEL,
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {"voiceName": "Aoede"},
                    },
                },
            },
            "systemInstruction": {"parts": [{"text": system_prompt}]},
        }
    }


async def _forward_server_content(websocket: WebSocket, server_content: dict) -> None:
    """Translate one Gemini ``serverContent`` payload into frontend frames.

    Emits a ``transcript`` frame per text part, an ``audio`` frame per inline
    audio part (base64 PCM16 24kHz), then ``turn_complete`` / ``interrupted``
    when those flags are set.
    """
    model_turn = server_content.get("modelTurn") or {}
    for part in model_turn.get("parts") or []:
        # NOTE(review): the Gemini API flags thought-summary parts with
        # "thought": true; they are not Kira's spoken reply, so skip them.
        if part.get("thought"):
            continue
        if "text" in part:
            await websocket.send_json({
                "type": "transcript",
                "role": "kira",
                "text": part["text"],
            })
        audio_data = (part.get("inlineData") or {}).get("data", "")
        if audio_data:
            await websocket.send_json({"type": "audio", "data": audio_data})

    if server_content.get("turnComplete"):
        await websocket.send_json({"type": "turn_complete"})
    if server_content.get("interrupted"):
        await websocket.send_json({"type": "interrupted"})


async def _notify_and_close(websocket: WebSocket, session_id: str, error_message: str = "") -> None:
    """Best-effort: optionally send an ``error`` frame, then close the browser socket.

    The browser may already be gone during teardown, in which case the send or
    close raises; that is expected here, so it is logged at debug level only.
    """
    try:
        if error_message:
            await websocket.send_json({"type": "error", "message": error_message})
        await websocket.close()
    except Exception as close_err:
        logger.debug(f"[{session_id}] Frontend socket already closed: {close_err!r}")


@app.websocket("/api/ws")
async def gemini_voice_ws(websocket: WebSocket):
    """WebSocket proxy between frontend and Gemini Live API.

    Protocol (frontend ↔ this proxy):
      → {"type": "audio", "data": "<base64 PCM16 16kHz>"}
      → {"type": "conversation_text", "text": "..."}
      → {"type": "identify", "user_id": "...", "name": "..."}
      → {"type": "set_preference", "key": "...", "value": "..."}
      → {"type": "ping"}
      ← {"type": "audio", "data": "<base64 PCM16 24kHz>"}
      ← {"type": "transcript", "role": "user"|"kira", "text": "..."}
      ← {"type": "turn_complete"}
      ← {"type": "interrupted"}
      ← {"type": "identified", "user_id": "...", "preferences": <user preferences>}
      ← {"type": "preference_saved", "key": "...", "success": true|false}
      ← {"type": "pong"}
      ← {"type": "error", "message": "..."}

    The session runs until either side stops. Frames that are not a JSON
    object are logged and ignored instead of ending the session, and unknown
    ``type`` values are ignored. If the upstream Gemini side ends first (or
    never comes up), the browser socket is closed as well.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())[:8]
    user_id = _DEFAULT_USER_ID
    memory_suffix = ""
    logger.info(f"[{session_id}] WebSocket connected")

    gemini_ws = None
    gemini_task = None
    frontend_task = None
    # Whether teardown should close the browser socket (and with which error).
    close_frontend = False
    frontend_error = ""

    def gemini_is_open() -> bool:
        return gemini_ws is not None and gemini_ws.state.name == "OPEN"

    # ── Gemini → Frontend relay ──
    async def relay_gemini():
        try:
            async for raw in gemini_ws:
                msg = json.loads(raw)

                if "serverContent" in msg:
                    await _forward_server_content(websocket, msg["serverContent"])
                elif "toolCall" in msg or "toolCallCancellation" in msg:
                    pass  # No tools are declared in the setup frame yet.
                elif "error" in msg:
                    err = msg["error"]
                    logger.error(f"[{session_id}] Gemini error: {err}")
                    detail = err.get("message", err) if isinstance(err, dict) else err
                    await websocket.send_json({"type": "error", "message": str(detail)})
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"[{session_id}] Gemini WS closed")
        except Exception as e:
            logger.error(f"[{session_id}] Gemini relay error: {e!r}")

    # ── Frontend → Gemini relay ──
    async def relay_frontend():
        nonlocal user_id, memory_suffix
        try:
            while True:
                raw = await websocket.receive_text()
                try:
                    msg = json.loads(raw)
                except json.JSONDecodeError:
                    logger.warning(f"[{session_id}] Ignoring non-JSON frontend frame")
                    continue
                if not isinstance(msg, dict):
                    logger.warning(f"[{session_id}] Ignoring non-object frontend frame")
                    continue
                msg_type = msg.get("type", "")

                if msg_type == "identify":
                    # An empty/blank id falls back to the anonymous placeholder.
                    user_id = _text_field(msg, "user_id") or _DEFAULT_USER_ID
                    user_name = _text_field(msg, "name")
                    if user_name:
                        kira_memory.set_user_preference(user_id, "name", user_name)
                    # NOTE(review): the client-supplied user_id is trusted as-is,
                    # so any client can read that user's preferences.
                    prefs = kira_memory.get_user_preferences(user_id)
                    if kira_memory.enabled:
                        kira_memory.ensure_peers(user_id)
                        kira_memory.ensure_session(session_id)
                        try:
                            ctx = kira_memory.build_system_prompt_suffix()
                        except Exception as mem_err:
                            # Memory is best-effort: keep the session, leave a trace.
                            logger.warning(f"[{session_id}] Could not load memory context: {mem_err!r}")
                        else:
                            if ctx:
                                memory_suffix = ctx
                    await websocket.send_json({
                        "type": "identified",
                        "user_id": user_id,
                        "preferences": prefs,
                    })
                    continue

                if msg_type == "set_preference":
                    key = _text_field(msg, "key")
                    value = _text_field(msg, "value")
                    saved = False
                    # Preferences are only persisted for identified users.
                    if key and user_id != _DEFAULT_USER_ID:
                        try:
                            kira_memory.set_user_preference(user_id, key, value)
                            saved = True
                        except Exception as pref_err:
                            logger.warning(f"[{session_id}] Could not save preference {key!r}: {pref_err!r}")
                    await websocket.send_json({"type": "preference_saved", "key": key, "success": saved})
                    continue

                if msg_type == "audio":
                    # Forward PCM16 audio to Gemini as realtimeInput.
                    audio_b64 = msg.get("data", "")
                    if audio_b64 and gemini_is_open():
                        await gemini_ws.send(json.dumps({
                            "realtimeInput": {
                                "audio": {
                                    "mimeType": "audio/pcm;rate=16000",
                                    "data": audio_b64,
                                }
                            }
                        }))
                    continue

                if msg_type == "conversation_text":
                    text = _text_field(msg, "text")
                    if not text:
                        continue
                    logger.info(f"[{session_id}] User (text): {text}")
                    # Send as a complete user turn; memory context (if any) is
                    # prepended to the text Gemini sees, not to the echoed transcript.
                    if gemini_is_open():
                        turn_text = f"[Context: {memory_suffix}]\n{text}" if memory_suffix else text
                        await gemini_ws.send(json.dumps({
                            "clientContent": {
                                "turns": [{"role": "user", "parts": [{"text": turn_text}]}],
                                "turnComplete": True,
                            }
                        }))
                    await websocket.send_json({"type": "transcript", "role": "user", "text": text})
                    continue

                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})
        except WebSocketDisconnect:
            pass  # Normal browser disconnect; teardown is logged in `finally`.
        except Exception as e:
            logger.error(f"[{session_id}] Frontend relay error: {e!r}")

    try:
        # ── Connect to Gemini Live API and send the setup frame ──
        gemini_url = f"{GEMINI_WS_URL}?key={settings.gemini_api_key}"
        # max_size is in bytes: allow frames up to 16 MiB.
        gemini_ws = await websockets.connect(gemini_url, max_size=2**24)
        await gemini_ws.send(json.dumps(_build_setup_message(BASE_SYSTEM_PROMPT)))
        logger.info(f"[{session_id}] Connected to Gemini Live API")

        # Wait (bounded) for Gemini to acknowledge the setup.
        raw = await asyncio.wait_for(gemini_ws.recv(), timeout=10)
        setup_resp = json.loads(raw)
        if "setupComplete" in setup_resp:
            logger.info(f"[{session_id}] Gemini setup complete")
        else:
            logger.warning(f"[{session_id}] Unexpected setup response: {list(setup_resp.keys())}")

        gemini_task = asyncio.create_task(relay_gemini())
        frontend_task = asyncio.create_task(relay_frontend())

        # Run until either relay stops; the other one is torn down in `finally`.
        done, _ = await asyncio.wait(
            {gemini_task, frontend_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        # If only the upstream side ended, the browser is still attached to a
        # dead session: close it during teardown.
        close_frontend = frontend_task not in done
    except Exception as e:
        # Log the repr: an asyncio timeout stringifies to "".
        logger.exception(f"[{session_id}] Connection error: {e!r}")
        close_frontend = True
        # Generic text on purpose: exception details may contain the keyed URL.
        frontend_error = "Voice service connection failed"
    finally:
        pending = [t for t in (gemini_task, frontend_task) if t is not None and not t.done()]
        for task in pending:
            task.cancel()
        # Await the cancelled relays so they unwind before the upstream socket
        # is closed and asyncio does not report destroyed pending tasks.
        await asyncio.gather(*pending, return_exceptions=True)
        if gemini_ws is not None:
            await gemini_ws.close()
        if close_frontend:
            await _notify_and_close(websocket, session_id, frontend_error)
        logger.info(f"[{session_id}] Disconnected")
|