Files
project-lyra/lyra/memory.py
T
serversdown f3037b7879 feat: ChatGPT chat-log importer
Import the parser's {title, messages} JSON into Lyra's memory so past
conversations seed recall (and, later, the era-rollup tier).

- lyra/ingest.py: one conversation -> one session, text messages -> exchanges;
  skips non-text (image asset) messages and non user/assistant roles; embeddings
  batched; idempotent by filename-derived session id; `lyra-import <dir>` CLI
- memory.add_exchanges_bulk: batched insert of pre-embedded rows

Format has no timestamps yet, so imports are stamped at import time; a future
dated export will let era memory group by real calendar time.

Verified on the 68-file lyra dev set: 7519 exchanges, idempotent re-run, recall
returns relevant history.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 00:51:45 +00:00

323 lines
10 KiB
Python

"""Persistent memory: SQLite storage + brute-force cosine recall over embeddings.
Each exchange is stored with its OpenAI embedding as a float32 BLOB. Recall
loads all embeddings (optionally scoped to a session) into a matrix and
returns the top-k by cosine similarity. Brute force is fine up to tens of
thousands of rows; swap in a vector index when that stops being true.
"""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
from lyra import llm
from lyra.config import load
# DDL executed by _connection() every time it opens a connection. Every
# statement uses IF NOT EXISTS, so re-running it on an existing database
# leaves existing tables and data untouched.
SCHEMA = """
CREATE TABLE IF NOT EXISTS exchanges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_session_created ON exchanges(session_id, created_at);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
name TEXT,
created_at TEXT NOT NULL
);
-- One compacted "gist" per session. last_exchange_id marks how far the summary
-- covers, so we know when enough new turns have accumulated to re-summarize.
CREATE TABLE IF NOT EXISTS summaries (
session_id TEXT PRIMARY KEY,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
last_exchange_id INTEGER NOT NULL,
created_at TEXT NOT NULL
);
"""
# Module-level connection singleton and the database path it was opened
# against; both are (re)assigned by _connection(), which compares the path
# to the current config to decide whether to reopen.
_conn: sqlite3.Connection | None = None
_conn_path: Path | None = None
def _connection() -> sqlite3.Connection:
    """Return the shared SQLite connection, (re)opening it when needed.

    The connection is opened lazily on first use and reopened whenever the
    configured ``db_path`` no longer matches the path it was opened with
    (tests point LYRA_DB_PATH elsewhere). Each freshly opened connection
    yields ``sqlite3.Row`` rows and has ``SCHEMA`` applied to it.
    """
    global _conn, _conn_path
    cfg = load()
    # Fast path: already connected to the currently configured database.
    if _conn is not None and _conn_path == cfg.db_path:
        return _conn
    # The configured path moved (or this is the first call): drop any stale
    # handle before opening the new database.
    if _conn is not None:
        _conn.close()
    cfg.db_path.parent.mkdir(parents=True, exist_ok=True)
    # check_same_thread=False: the web server runs blocking work in a thread
    # pool, so the singleton connection is touched from threads other than
    # the one that created it. Safe here under single-user, low-concurrency use.
    _conn = sqlite3.connect(cfg.db_path, check_same_thread=False)
    _conn.row_factory = sqlite3.Row
    _conn.executescript(SCHEMA)
    _conn_path = cfg.db_path
    return _conn
@dataclass
class Exchange:
    """One stored message: a row of the ``exchanges`` table without its embedding."""

    id: int  # exchanges.id (AUTOINCREMENT rowid); also gives insertion order
    session_id: str  # session the message belongs to
    role: str  # speaker role string exactly as the caller stored it
    content: str  # message text
    created_at: str  # timestamp string as stored; remember() writes UTC isoformat()
    score: float | None = None  # cosine similarity when produced by recall(); None otherwise
@dataclass
class Summary:
    """A session's compacted gist: a row of the ``summaries`` table without its embedding."""

    session_id: str  # summarized session (primary key of summaries: one summary per session)
    content: str  # the summary text
    last_exchange_id: int  # highest exchanges.id covered; rows with a larger id are unsummarized
    created_at: str  # UTC isoformat() time the summary was last written by store_summary()
    score: float | None = None  # cosine similarity when produced by recall_summaries(); None otherwise
def _to_blob(vec: list[float]) -> bytes:
    """Pack an embedding vector into raw float32 bytes (native byte order) for a BLOB column."""
    packed = np.array(vec, dtype=np.float32)
    return packed.tobytes()
def _from_blob(blob: bytes) -> np.ndarray:
    """Decode bytes produced by ``_to_blob`` back into a float32 vector.

    The result is a zero-copy view over ``blob``; for a ``bytes`` input it is
    therefore read-only.
    """
    return np.frombuffer(blob, np.float32)
def remember(session_id: str, role: str, content: str) -> int:
    """Embed a single message and append it to ``exchanges``.

    The row is stamped with the current UTC time in ISO-8601 form.

    Returns:
        The id of the newly inserted row.
    """
    (vector,) = llm.embed([content])
    stamped_at = datetime.now(timezone.utc).isoformat()
    conn = _connection()
    record = (session_id, role, content, _to_blob(vector), stamped_at)
    with conn:
        cursor = conn.execute(
            "INSERT INTO exchanges (session_id, role, content, embedding, created_at) "
            "VALUES (?, ?, ?, ?, ?)",
            record,
        )
    return int(cursor.lastrowid)
def add_exchanges_bulk(session_id: str, rows: list[tuple[str, str, list[float], str]]) -> int:
    """Insert many already-embedded exchanges for one session in a single batch.

    The importer uses this to avoid one INSERT (and one embed round-trip) per
    message. All rows go through one ``executemany`` inside one ``with conn:``
    block, so they are committed together.

    Args:
        session_id: Session every row is filed under.
        rows: ``(role, content, embedding, created_at)`` tuples; ``created_at``
            is stored verbatim.

    Returns:
        The number of rows supplied (``len(rows)``).
    """
    conn = _connection()
    with conn:
        params = [
            (session_id, role, content, _to_blob(embedding), created_at)
            for (role, content, embedding, created_at) in rows
        ]
        conn.executemany(
            "INSERT INTO exchanges (session_id, role, content, embedding, created_at) "
            "VALUES (?, ?, ?, ?, ?)",
            params,
        )
    return len(rows)
def recent(session_id: str, n: int = 10) -> list[Exchange]:
    """Return the latest ``n`` exchanges of a session in chronological order.

    Rows are fetched newest-first by id, so LIMIT keeps the most recent ones.
    They are then flipped back to oldest-first for the caller.
    """
    latest_first = _connection().execute(
        "SELECT id, session_id, role, content, created_at FROM exchanges "
        "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
        (session_id, n),
    ).fetchall()
    latest_first.reverse()
    # Selected column names match Exchange's fields one-to-one; score stays None.
    return [Exchange(**dict(row)) for row in latest_first]
def ensure_session(session_id: str, name: str | None = None) -> None:
    """Make sure a ``sessions`` row exists for ``session_id``.

    A new row is stamped with the current UTC time. An existing row is never
    recreated. When ``name`` is given, it overwrites the session's name.
    When ``name`` is None, an existing name is left alone.
    """
    conn = _connection()
    created_at = datetime.now(timezone.utc).isoformat()
    with conn:
        conn.execute(
            "INSERT INTO sessions (id, name, created_at) VALUES (?, ?, ?) "
            "ON CONFLICT(id) DO NOTHING",
            (session_id, name, created_at),
        )
        if name is None:
            return
        conn.execute("UPDATE sessions SET name = ? WHERE id = ?", (name, session_id))
def list_sessions() -> list[dict]:
    """List every known session, newest first.

    A session is known if it has a row in ``sessions`` (see ``ensure_session``)
    or if exchanges reference it without such a row (e.g. ``remember`` was
    called directly). Registered sessions are ordered by their own
    ``created_at``. Unregistered ones use the timestamp of their earliest
    exchange.

    Returns:
        ``{"id": ..., "name": ...}`` dicts; ``name`` is None for unnamed or
        unregistered sessions.
    """
    conn = _connection()
    # sessions.created_at is NOT NULL, so there is no need to join exchanges
    # for a fallback timestamp on registered sessions. The two halves are
    # disjoint, so UNION ALL avoids a pointless dedup pass. NOT EXISTS (unlike
    # NOT IN) stays correct even if sessions.id ever holds a NULL, which
    # SQLite permits for non-INTEGER primary keys.
    rows = conn.execute(
        """
        SELECT s.id AS id, s.name AS name, s.created_at AS created_at
        FROM sessions s
        UNION ALL
        SELECT e.session_id AS id, NULL AS name, MIN(e.created_at) AS created_at
        FROM exchanges e
        WHERE NOT EXISTS (SELECT 1 FROM sessions s2 WHERE s2.id = e.session_id)
        GROUP BY e.session_id
        ORDER BY created_at DESC
        """
    ).fetchall()
    return [{"id": r["id"], "name": r["name"]} for r in rows]
def history(session_id: str) -> list[Exchange]:
    """Return every exchange recorded for a session, oldest first (ascending id)."""
    cursor = _connection().execute(
        "SELECT id, session_id, role, content, created_at FROM exchanges "
        "WHERE session_id = ? ORDER BY id ASC",
        (session_id,),
    )
    # Selected column names match Exchange's fields one-to-one; score stays None.
    return [Exchange(**dict(row)) for row in cursor.fetchall()]
def delete_session(session_id: str) -> None:
    """Delete a session everywhere it appears.

    This removes its exchanges, its ``sessions`` row and its summary, all in
    one transaction.
    """
    purge_statements = (
        "DELETE FROM exchanges WHERE session_id = ?",
        "DELETE FROM sessions WHERE id = ?",
        "DELETE FROM summaries WHERE session_id = ?",
    )
    conn = _connection()
    with conn:
        for statement in purge_statements:
            conn.execute(statement, (session_id,))
def recall(query: str, k: int = 5, session_id: str | None = None) -> list[Exchange]:
    """Return the ``k`` stored exchanges most similar to ``query``.

    Similarity is cosine similarity between the query embedding and each
    stored embedding, computed by brute force over all candidate rows.

    Args:
        query: Text to embed and compare against stored exchanges.
        k: Maximum number of results; ``k <= 0`` returns ``[]``.
        session_id: If given, only exchanges from this session are searched.

    Returns:
        Up to ``k`` exchanges, best match first, each with ``score`` set to
        its cosine similarity.
    """
    # Guard first: a negative k would slice below as "all but the last |k|"
    # rows, and k == 0 would still pay for an embedding call.
    if k <= 0:
        return []
    conn = _connection()
    sql = "SELECT id, session_id, role, content, embedding, created_at FROM exchanges"
    params: tuple = ()
    if session_id is not None:
        sql += " WHERE session_id = ?"
        params = (session_id,)
    rows = conn.execute(sql, params).fetchall()
    # Nothing to rank, so skip the embedding round-trip entirely.
    if not rows:
        return []
    [q_vec] = llm.embed([query])
    q = np.asarray(q_vec, dtype=np.float32)
    matrix = np.stack([_from_blob(r["embedding"]) for r in rows])
    norms = np.linalg.norm(matrix, axis=1)
    # The 1e-9 term avoids division by zero for all-zero vectors.
    scores = (matrix @ q) / (norms * np.linalg.norm(q) + 1e-9)
    top_idx = np.argsort(scores)[::-1][:k]
    return [
        Exchange(
            id=rows[i]["id"],
            session_id=rows[i]["session_id"],
            role=rows[i]["role"],
            content=rows[i]["content"],
            created_at=rows[i]["created_at"],
            score=float(scores[i]),
        )
        for i in top_idx
    ]
# --- Summary tier (compacted per-session gists) ---
def store_summary(session_id: str, content: str, last_exchange_id: int) -> None:
    """Embed a session's gist and upsert it into ``summaries``.

    Any previous summary for ``session_id`` is fully replaced. That includes
    its ``created_at``, which is reset to the current UTC time.
    """
    (vector,) = llm.embed([content])
    written_at = datetime.now(timezone.utc).isoformat()
    conn = _connection()
    blob = _to_blob(vector)
    with conn:
        conn.execute(
            "INSERT INTO summaries (session_id, content, embedding, last_exchange_id, created_at) "
            "VALUES (?, ?, ?, ?, ?) "
            "ON CONFLICT(session_id) DO UPDATE SET "
            "content=excluded.content, embedding=excluded.embedding, "
            "last_exchange_id=excluded.last_exchange_id, created_at=excluded.created_at",
            (session_id, content, blob, last_exchange_id, written_at),
        )
def get_summary(session_id: str) -> Summary | None:
    """Fetch the stored summary for a session, or None if it has none yet.

    The returned ``Summary`` carries no ``score``, since it is a direct lookup
    rather than a search result.
    """
    row = _connection().execute(
        "SELECT session_id, content, last_exchange_id, created_at FROM summaries "
        "WHERE session_id = ?",
        (session_id,),
    ).fetchone()
    # Selected column names match Summary's fields one-to-one.
    return None if row is None else Summary(**dict(row))
def unsummarized_count(session_id: str) -> int:
    """Count a session's exchanges whose id is past its summary's coverage.

    With no summary yet, every exchange of the session counts.
    """
    conn = _connection()
    existing = get_summary(session_id)
    covered_up_to = 0 if existing is None else existing.last_exchange_id
    row = conn.execute(
        "SELECT COUNT(*) AS n FROM exchanges WHERE session_id = ? AND id > ?",
        (session_id, covered_up_to),
    ).fetchone()
    return int(row["n"])
def recall_summaries(query: str, k: int = 3, exclude_session: str | None = None) -> list[Summary]:
    """Return the ``k`` session summaries most similar to ``query`` (the gist tier).

    Similarity is cosine similarity between the query embedding and each
    stored summary embedding, computed by brute force.

    Args:
        query: Text to embed and compare against stored summaries.
        k: Maximum number of results; ``k <= 0`` returns ``[]``.
        exclude_session: If given, that session's own summary is skipped.

    Returns:
        Up to ``k`` summaries, best match first, each with ``score`` set to
        its cosine similarity.
    """
    # Guard first: a negative k would slice below as "all but the last |k|"
    # rows, and k == 0 would still pay for an embedding call.
    if k <= 0:
        return []
    conn = _connection()
    sql = "SELECT session_id, content, embedding, last_exchange_id, created_at FROM summaries"
    params: tuple = ()
    if exclude_session is not None:
        sql += " WHERE session_id != ?"
        params = (exclude_session,)
    rows = conn.execute(sql, params).fetchall()
    # Nothing to rank, so skip the embedding round-trip entirely.
    if not rows:
        return []
    [q_vec] = llm.embed([query])
    q = np.asarray(q_vec, dtype=np.float32)
    matrix = np.stack([_from_blob(r["embedding"]) for r in rows])
    norms = np.linalg.norm(matrix, axis=1)
    # The 1e-9 term avoids division by zero for all-zero vectors.
    scores = (matrix @ q) / (norms * np.linalg.norm(q) + 1e-9)
    top_idx = np.argsort(scores)[::-1][:k]
    return [
        Summary(
            session_id=rows[i]["session_id"],
            content=rows[i]["content"],
            last_exchange_id=rows[i]["last_exchange_id"],
            created_at=rows[i]["created_at"],
            score=float(scores[i]),
        )
        for i in top_idx
    ]