feat: SQLite-backed memory with brute-force cosine recall
- lyra.memory.remember(session_id, role, content) embeds and stores an exchange
- lyra.memory.recent(session_id, n) returns the last N exchanges from a session
- lyra.memory.recall(query, k, session_id=None) returns the top-k exchanges by cosine similarity across the chosen scope (all sessions by default)
- Embeddings live in the exchanges.embedding BLOB column as float32 bytes
- The connection reopens automatically if LYRA_DB_PATH changes (test-friendly)
This commit is contained in:
+133
@@ -0,0 +1,133 @@
|
||||
"""Persistent memory: SQLite storage + brute-force cosine recall over embeddings.
|
||||
|
||||
Each exchange is stored with its OpenAI embedding as a float32 BLOB. Recall
|
||||
loads all embeddings (optionally scoped to a session) into a matrix and
|
||||
returns the top-k by cosine similarity. Brute force is fine up to tens of
|
||||
thousands of rows; swap in a vector index when that stops being true.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lyra import llm
|
||||
from lyra.config import load
|
||||
|
||||
# DDL run via executescript() each time a connection is opened; the
# IF NOT EXISTS clauses make re-running it against an existing database safe.
SCHEMA = """
CREATE TABLE IF NOT EXISTS exchanges (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    role TEXT NOT NULL,
    content TEXT NOT NULL,
    embedding BLOB NOT NULL,
    created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_session_created ON exchanges(session_id, created_at);
"""

# Module-wide cached connection, plus the database path it was opened for so
# a change of configured path can be detected.
_conn: sqlite3.Connection | None = None
_conn_path: Path | None = None
|
||||
|
||||
|
||||
def _connection() -> sqlite3.Connection:
    """Return the module-wide SQLite connection, opening it on first use.

    The connection is cached in module globals together with the path it was
    opened for. When the configured ``db_path`` differs from the cached one
    (e.g. tests changing LYRA_DB_PATH), the old connection is closed and a
    new one is opened with the schema applied.

    Returns:
        An open connection whose rows are ``sqlite3.Row`` objects.
    """
    global _conn, _conn_path
    db_path = load().db_path
    if _conn is not None and _conn_path == db_path:
        return _conn

    if _conn is not None:
        stale = _conn
        # Clear the cache *before* closing: if reopening fails below, a later
        # call must not be handed the closed connection just because the
        # configured path matches the old one again.
        _conn, _conn_path = None, None
        stale.close()

    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        conn.executescript(SCHEMA)
    except Exception:
        # Do not leak a half-initialised connection; let the caller see the error.
        conn.close()
        raise

    # Publish only once the connection is fully set up.
    _conn, _conn_path = conn, db_path
    return conn
|
||||
|
||||
|
||||
@dataclass
class Exchange:
    """One stored conversation turn, as handed back by the memory queries.

    ``score`` holds the cosine similarity when the instance is produced by
    ``recall``; it keeps its default of None otherwise.
    """

    id: int  # row id in the exchanges table
    session_id: str
    role: str
    content: str
    created_at: str  # timestamp string exactly as stored in the created_at column
    score: float | None = None
|
||||
|
||||
|
||||
def _to_blob(vec: list[float]) -> bytes:
|
||||
return np.asarray(vec, dtype=np.float32).tobytes()
|
||||
|
||||
|
||||
def _from_blob(blob: bytes) -> np.ndarray:
|
||||
return np.frombuffer(blob, dtype=np.float32)
|
||||
|
||||
|
||||
def remember(session_id: str, role: str, content: str) -> int:
    """Embed one exchange, insert it into the exchanges table, and return its row id."""
    # llm.embed takes a batch; exactly one vector is expected back for one input.
    (vector,) = llm.embed([content])
    timestamp = datetime.now(timezone.utc).isoformat()
    db = _connection()
    insert_sql = (
        "INSERT INTO exchanges (session_id, role, content, embedding, created_at) "
        "VALUES (?, ?, ?, ?, ?)"
    )
    # Using the connection as a context manager commits on success and rolls
    # back if the insert raises.
    with db:
        cursor = db.execute(
            insert_sql,
            (session_id, role, content, _to_blob(vector), timestamp),
        )
    return int(cursor.lastrowid)
|
||||
|
||||
|
||||
def recent(session_id: str, n: int = 10) -> list[Exchange]:
    """Return the last `n` exchanges from a session, oldest first.

    Args:
        session_id: Session whose history is read.
        n: Maximum number of exchanges to return. Values <= 0 yield an
            empty list.

    Returns:
        Up to `n` exchanges in ascending id order; ``score`` is left at its
        default of None.
    """
    # SQLite treats a negative LIMIT as "no upper bound", which would
    # silently return the entire session; short-circuit instead.
    if n <= 0:
        return []

    conn = _connection()
    rows = conn.execute(
        "SELECT id, session_id, role, content, created_at FROM exchanges "
        "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
        (session_id, n),
    ).fetchall()
    # The query returns newest-first so LIMIT keeps the latest rows; flip the
    # result to present them chronologically.
    return [
        Exchange(
            id=r["id"],
            session_id=r["session_id"],
            role=r["role"],
            content=r["content"],
            created_at=r["created_at"],
        )
        for r in reversed(rows)
    ]
|
||||
|
||||
|
||||
def recall(query: str, k: int = 5, session_id: str | None = None) -> list[Exchange]:
    """Return the top-`k` stored exchanges most similar to `query`.

    The query is embedded with ``llm.embed``, every candidate embedding is
    loaded into a single matrix, and candidates are ranked by cosine
    similarity against the query vector.

    Args:
        query: Text to search for.
        k: Maximum number of results. Values <= 0 yield an empty list.
        session_id: When given, only that session's exchanges are searched;
            otherwise exchanges from all sessions are.

    Returns:
        Up to `k` exchanges ordered by descending similarity, each with
        ``score`` set to its cosine similarity with the query.
    """
    # A negative k would otherwise slice as [:-n] and return almost every
    # row; also skips the embedding call when nothing can be returned.
    if k <= 0:
        return []

    [q_vec] = llm.embed([query])
    q = np.asarray(q_vec, dtype=np.float32)

    conn = _connection()
    sql = "SELECT id, session_id, role, content, embedding, created_at FROM exchanges"
    params: tuple = ()
    if session_id is not None:
        sql += " WHERE session_id = ?"
        params = (session_id,)
    rows = conn.execute(sql, params).fetchall()
    if not rows:
        return []

    # NOTE: assumes every stored embedding has the same length as the query
    # vector; np.stack / the matrix product raise otherwise.
    matrix = np.stack([_from_blob(r["embedding"]) for r in rows])
    norms = np.linalg.norm(matrix, axis=1)
    # The small epsilon keeps all-zero vectors from dividing by zero.
    scores = (matrix @ q) / (norms * np.linalg.norm(q) + 1e-9)

    # Stable sort on the negated scores: best match first, ties keep row
    # order, and NaN scores (which NumPy sorts last) land at the bottom
    # rather than at the top as they would after reversing an ascending sort.
    top_idx = np.argsort(-scores, kind="stable")[:k]
    return [
        Exchange(
            id=rows[i]["id"],
            session_id=rows[i]["session_id"],
            role=rows[i]["role"],
            content=rows[i]["content"],
            created_at=rows[i]["created_at"],
            score=float(scores[i]),
        )
        for i in top_idx
    ]
|
||||
Reference in New Issue
Block a user