feat(tasks): Gemini tool calling for task list management
5 tools: add_task, remove_task, complete_task, get_tasks, clear_completed_tasks Backend stores tasks in-memory per session. Frontend TaskList component syncs via WS. Kira can manage tasks via voice or text conversation.
This commit is contained in:
+158
-27
@@ -1,15 +1,13 @@
|
||||
"""Kira — AI body double backend
|
||||
|
||||
Gemini Live API (gemini-3.1-flash-live-preview) for real-time voice.
|
||||
Text chat still goes through Gemini generateContent REST endpoint.
|
||||
Task list management via Gemini function calling.
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
import struct
|
||||
|
||||
import websockets
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
@@ -37,12 +35,119 @@ BASE_SYSTEM_PROMPT = (
|
||||
"stay focused and on task. Keep responses short, supportive, and uplifting. "
|
||||
"Check in on them. Remind them to take breaks. Celebrate small wins. "
|
||||
"Use occasional emoji but don't overdo it. Never be judgmental. "
|
||||
"You are speaking out loud via voice, so keep natural conversational flow."
|
||||
"You are speaking out loud via voice, so keep natural conversational flow. "
|
||||
"You have tools to manage a task list for the user. When they mention tasks, "
|
||||
"todos, things to do, or things to remember, use the task tools to help. "
|
||||
"Always confirm when you add or remove something. When asked what's on the "
|
||||
"list, read it back to them."
|
||||
)
|
||||
|
||||
GEMINI_WS_URL = "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
|
||||
GEMINI_MODEL = "models/gemini-3.1-flash-live-preview"
|
||||
|
||||
# ── Gemini tool declarations ──
|
||||
TOOLS = [
|
||||
{
|
||||
"functionDeclarations": [
|
||||
{
|
||||
"name": "add_task",
|
||||
"description": "Add a new task to the user's task list.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "STRING",
|
||||
"description": "The task description.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "remove_task",
|
||||
"description": "Remove a task from the list by its ID.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "STRING",
|
||||
"description": "The ID of the task to remove.",
|
||||
}
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "complete_task",
|
||||
"description": "Mark a task as completed.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "STRING",
|
||||
"description": "The ID of the task to mark complete.",
|
||||
}
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_tasks",
|
||||
"description": "Get the current task list.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "clear_completed_tasks",
|
||||
"description": "Remove all completed tasks from the list.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def execute_tool(name: str, args: dict, tasks: list[dict]) -> dict:
|
||||
"""Execute a tool call and return the result."""
|
||||
if name == "add_task":
|
||||
text = args.get("text", "").strip()
|
||||
if not text:
|
||||
return {"error": "Task text cannot be empty."}
|
||||
task = {"id": str(uuid.uuid4())[:8], "text": text, "completed": False}
|
||||
tasks.append(task)
|
||||
return {"status": "added", "task": task, "total": len(tasks)}
|
||||
|
||||
elif name == "remove_task":
|
||||
task_id = args.get("task_id", "")
|
||||
for i, t in enumerate(tasks):
|
||||
if t["id"] == task_id:
|
||||
removed = tasks.pop(i)
|
||||
return {"status": "removed", "task": removed, "total": len(tasks)}
|
||||
return {"error": f"Task {task_id} not found."}
|
||||
|
||||
elif name == "complete_task":
|
||||
task_id = args.get("task_id", "")
|
||||
for t in tasks:
|
||||
if t["id"] == task_id:
|
||||
t["completed"] = True
|
||||
return {"status": "completed", "task": t}
|
||||
return {"error": f"Task {task_id} not found."}
|
||||
|
||||
elif name == "get_tasks":
|
||||
return {"tasks": tasks, "total": len(tasks)}
|
||||
|
||||
elif name == "clear_completed_tasks":
|
||||
before = len(tasks)
|
||||
tasks[:] = [t for t in tasks if not t["completed"]]
|
||||
return {"status": "cleared", "removed": before - len(tasks), "total": len(tasks)}
|
||||
|
||||
return {"error": f"Unknown tool: {name}"}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
@@ -62,33 +167,37 @@ async def health():
|
||||
async def gemini_voice_ws(websocket: WebSocket):
|
||||
"""WebSocket proxy between frontend and Gemini Live API.
|
||||
|
||||
Protocol (frontend ↔ this proxy):
|
||||
→ {"type": "audio", "data": "<base64 PCM16 16kHz>"}
|
||||
→ {"type": "conversation_text", "text": "..."}
|
||||
→ {"type": "identify", "user_id": "...", "name": "..."}
|
||||
→ {"type": "ping"}
|
||||
← {"type": "audio", "data": "<base64 PCM16 24kHz>"}
|
||||
← {"type": "transcript", "role": "user"|"kira", "text": "..."}
|
||||
← {"type": "turn_complete"}
|
||||
← {"type": "interrupted"}
|
||||
← {"type": "error", "message": "..."}
|
||||
Protocol (frontend <-> this proxy):
|
||||
-> {"type": "audio", "data": "<base64 PCM16 16kHz>"}
|
||||
-> {"type": "conversation_text", "text": "..."}
|
||||
-> {"type": "identify", "user_id": "...", "name": "..."}
|
||||
-> {"type": "ping"}
|
||||
<- {"type": "audio", "data": "<base64 PCM16 24kHz>"}
|
||||
<- {"type": "transcript", "role": "user"|"kira", "text": "..."}
|
||||
<- {"type": "turn_complete"}
|
||||
<- {"type": "interrupted"}
|
||||
<- {"type": "tasks", "tasks": [...]}
|
||||
<- {"type": "error", "message": "..."}
|
||||
"""
|
||||
await websocket.accept()
|
||||
session_id = str(uuid.uuid4())[:8]
|
||||
user_id = "default-user"
|
||||
memory_suffix = ""
|
||||
tasks: list[dict] = []
|
||||
logger.info(f"[{session_id}] WebSocket connected")
|
||||
|
||||
gemini_ws = None
|
||||
gemini_task = None
|
||||
frontend_task = None
|
||||
|
||||
async def send_tasks_to_frontend():
|
||||
"""Push current task list to frontend."""
|
||||
await websocket.send_json({"type": "tasks", "tasks": tasks})
|
||||
|
||||
try:
|
||||
# ── Connect to Gemini Live API ──
|
||||
gemini_url = f"{GEMINI_WS_URL}?key={settings.gemini_api_key}"
|
||||
gemini_ws = await websockets.connect(gemini_url, max_size=2**24)
|
||||
|
||||
# ── Send setup ──
|
||||
# ── Send setup with tools ──
|
||||
system_prompt = BASE_SYSTEM_PROMPT
|
||||
setup_msg = {
|
||||
"setup": {
|
||||
@@ -106,6 +215,7 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
"systemInstruction": {
|
||||
"parts": [{"text": system_prompt}]
|
||||
},
|
||||
"tools": TOOLS,
|
||||
}
|
||||
}
|
||||
await gemini_ws.send(json.dumps(setup_msg))
|
||||
@@ -119,7 +229,7 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
else:
|
||||
logger.warning(f"[{session_id}] Unexpected setup response: {list(setup_resp.keys())}")
|
||||
|
||||
# ── Gemini → Frontend relay ──
|
||||
# ── Gemini -> Frontend relay ──
|
||||
async def relay_gemini():
|
||||
try:
|
||||
async for raw in gemini_ws:
|
||||
@@ -131,7 +241,6 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
parts = model_turn.get("parts", [])
|
||||
|
||||
for part in parts:
|
||||
# Text response
|
||||
if "text" in part:
|
||||
await websocket.send_json({
|
||||
"type": "transcript",
|
||||
@@ -139,7 +248,6 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
"text": part["text"],
|
||||
})
|
||||
|
||||
# Audio response (PCM16 24kHz)
|
||||
if "inlineData" in part:
|
||||
audio_data = part["inlineData"].get("data", "")
|
||||
if audio_data:
|
||||
@@ -148,16 +256,40 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
"data": audio_data,
|
||||
})
|
||||
|
||||
# Turn complete
|
||||
if sc.get("turnComplete"):
|
||||
await websocket.send_json({"type": "turn_complete"})
|
||||
|
||||
# Interrupted
|
||||
if sc.get("interrupted"):
|
||||
await websocket.send_json({"type": "interrupted"})
|
||||
|
||||
elif "toolCall" in msg:
|
||||
pass # future: tool use
|
||||
# Execute tool calls and send responses back
|
||||
tool_call = msg["toolCall"]
|
||||
function_calls = tool_call.get("functionCalls", [])
|
||||
tool_results = []
|
||||
|
||||
for fc in function_calls:
|
||||
call_id = fc.get("id", "")
|
||||
fn_name = fc.get("name", "")
|
||||
fn_args = fc.get("args", {})
|
||||
logger.info(f"[{session_id}] Tool call: {fn_name}({fn_args})")
|
||||
|
||||
result = execute_tool(fn_name, fn_args, tasks)
|
||||
tool_results.append({
|
||||
"id": call_id,
|
||||
"name": fn_name,
|
||||
"response": result,
|
||||
})
|
||||
|
||||
# Push updated task list to frontend after any mutation
|
||||
if fn_name in ("add_task", "remove_task", "complete_task", "clear_completed_tasks"):
|
||||
await send_tasks_to_frontend()
|
||||
|
||||
# Send tool response back to Gemini
|
||||
if tool_results:
|
||||
resp = {"toolResponse": {"functionResponses": tool_results}}
|
||||
await gemini_ws.send(json.dumps(resp))
|
||||
logger.info(f"[{session_id}] Sent {len(tool_results)} tool responses")
|
||||
|
||||
elif "toolCallCancellation" in msg:
|
||||
pass
|
||||
@@ -175,7 +307,7 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Gemini relay error: {e}")
|
||||
|
||||
# ── Frontend → Gemini relay ──
|
||||
# ── Frontend -> Gemini relay ──
|
||||
async def relay_frontend():
|
||||
nonlocal user_id, memory_suffix
|
||||
try:
|
||||
@@ -205,6 +337,8 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
"user_id": user_id,
|
||||
"preferences": prefs,
|
||||
})
|
||||
# Send current tasks on identify
|
||||
await send_tasks_to_frontend()
|
||||
continue
|
||||
|
||||
if msg_type == "set_preference":
|
||||
@@ -219,7 +353,6 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
continue
|
||||
|
||||
if msg_type == "audio":
|
||||
# Forward PCM16 audio to Gemini as realtimeInput
|
||||
audio_b64 = msg.get("data", "")
|
||||
if audio_b64 and gemini_ws and gemini_ws.state.name == "OPEN":
|
||||
gemini_msg = {
|
||||
@@ -238,7 +371,6 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
if not text:
|
||||
continue
|
||||
logger.info(f"[{session_id}] User (text): {text}")
|
||||
# Send as a text turn to Gemini
|
||||
if gemini_ws and gemini_ws.state.name == "OPEN":
|
||||
user_part = {"text": text}
|
||||
if memory_suffix:
|
||||
@@ -268,7 +400,6 @@ async def gemini_voice_ws(websocket: WebSocket):
|
||||
gemini_task = asyncio.create_task(relay_gemini())
|
||||
frontend_task = asyncio.create_task(relay_frontend())
|
||||
|
||||
# Wait for either to finish
|
||||
done, pending = await asyncio.wait(
|
||||
[gemini_task, frontend_task],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
|
||||
Reference in New Issue
Block a user