"""Persistent memory: SQLite storage + brute-force cosine recall over embeddings. Each exchange is stored with its OpenAI embedding as a float32 BLOB. Recall loads all embeddings (optionally scoped to a session) into a matrix and returns the top-k by cosine similarity. Brute force is fine up to tens of thousands of rows; swap in a vector index when that stops being true. """ from __future__ import annotations import sqlite3 from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path import numpy as np from lyra import llm from lyra.config import load SCHEMA = """ CREATE TABLE IF NOT EXISTS exchanges ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, embedding BLOB NOT NULL, created_at TEXT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_session_created ON exchanges(session_id, created_at); CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, name TEXT, created_at TEXT NOT NULL ); """ _conn: sqlite3.Connection | None = None _conn_path: Path | None = None def _connection() -> sqlite3.Connection: """Lazily open the SQLite connection. Reopens if LYRA_DB_PATH changed (for tests).""" global _conn, _conn_path cfg = load() if _conn is None or _conn_path != cfg.db_path: if _conn is not None: _conn.close() cfg.db_path.parent.mkdir(parents=True, exist_ok=True) # check_same_thread=False: the web server runs blocking work in a thread # pool, so the singleton connection is touched from threads other than # the one that created it. Safe here under single-user, low-concurrency use. _conn = sqlite3.connect(cfg.db_path, check_same_thread=False) _conn.row_factory = sqlite3.Row _conn.executescript(SCHEMA) _conn_path = cfg.db_path return _conn @dataclass class Exchange: id: int session_id: str role: str content: str created_at: str score: float | None = None def _to_blob(vec: list[float]) -> bytes: return np.asarray(vec, dtype=np.float32).tobytes() def _from_blob(blob: bytes) -> np.ndarray: return np.frombuffer(blob, dtype=np.float32) def remember(session_id: str, role: str, content: str) -> int: """Embed and persist a single exchange. Returns the new row id.""" [embedding] = llm.embed([content]) now = datetime.now(timezone.utc).isoformat() conn = _connection() with conn: cur = conn.execute( "INSERT INTO exchanges (session_id, role, content, embedding, created_at) " "VALUES (?, ?, ?, ?, ?)", (session_id, role, content, _to_blob(embedding), now), ) return int(cur.lastrowid) def recent(session_id: str, n: int = 10) -> list[Exchange]: """Last `n` exchanges from a session, oldest first.""" conn = _connection() rows = conn.execute( "SELECT id, session_id, role, content, created_at FROM exchanges " "WHERE session_id = ? ORDER BY id DESC LIMIT ?", (session_id, n), ).fetchall() return [ Exchange( id=r["id"], session_id=r["session_id"], role=r["role"], content=r["content"], created_at=r["created_at"], ) for r in reversed(rows) ] def ensure_session(session_id: str, name: str | None = None) -> None: """Create the session row if absent; set its name if one is given.""" now = datetime.now(timezone.utc).isoformat() conn = _connection() with conn: conn.execute( "INSERT INTO sessions (id, name, created_at) VALUES (?, ?, ?) " "ON CONFLICT(id) DO NOTHING", (session_id, name, now), ) if name is not None: conn.execute("UPDATE sessions SET name = ? WHERE id = ?", (name, session_id)) def list_sessions() -> list[dict]: """All known sessions (named rows + any session that has exchanges), newest first.""" conn = _connection() rows = conn.execute( """ SELECT s.id AS id, s.name AS name, COALESCE(s.created_at, MIN(e.created_at)) AS created_at FROM sessions s LEFT JOIN exchanges e ON e.session_id = s.id GROUP BY s.id UNION SELECT e.session_id AS id, NULL AS name, MIN(e.created_at) AS created_at FROM exchanges e WHERE e.session_id NOT IN (SELECT id FROM sessions) GROUP BY e.session_id ORDER BY created_at DESC """ ).fetchall() return [{"id": r["id"], "name": r["name"]} for r in rows] def history(session_id: str) -> list[Exchange]: """Full conversation for a session, oldest first.""" conn = _connection() rows = conn.execute( "SELECT id, session_id, role, content, created_at FROM exchanges " "WHERE session_id = ? ORDER BY id ASC", (session_id,), ).fetchall() return [ Exchange( id=r["id"], session_id=r["session_id"], role=r["role"], content=r["content"], created_at=r["created_at"], ) for r in rows ] def delete_session(session_id: str) -> None: """Remove a session and all its exchanges.""" conn = _connection() with conn: conn.execute("DELETE FROM exchanges WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) def recall(query: str, k: int = 5, session_id: str | None = None) -> list[Exchange]: """Top-k exchanges semantically similar to `query`, optionally scoped to a session.""" [q_vec] = llm.embed([query]) q = np.asarray(q_vec, dtype=np.float32) conn = _connection() sql = "SELECT id, session_id, role, content, embedding, created_at FROM exchanges" params: tuple = () if session_id is not None: sql += " WHERE session_id = ?" params = (session_id,) rows = conn.execute(sql, params).fetchall() if not rows: return [] matrix = np.stack([_from_blob(r["embedding"]) for r in rows]) norms = np.linalg.norm(matrix, axis=1) scores = (matrix @ q) / (norms * np.linalg.norm(q) + 1e-9) top_idx = np.argsort(scores)[::-1][:k] return [ Exchange( id=rows[i]["id"], session_id=rows[i]["session_id"], role=rows[i]["role"], content=rows[i]["content"], created_at=rows[i]["created_at"], score=float(scores[i]), ) for i in top_idx ]