feat: tiered, compacting memory (phase 1.5)
Older sessions fade to a general idea; details stay retrievable.
- memory: summaries table (one compacted gist per session, embedded), plus
store_summary/get_summary/recall_summaries and unsummarized_count (tracks
exchanges newer than the current summary)
- lyra/summary.py: summarize_session compacts a session's raw turns into a
third-person gist (default SUMMARY_BACKEND=local, so compaction is free);
maybe_summarize re-summarizes once SUMMARIZE_AFTER new turns accumulate
- chat.build_messages now layers context in tiers: persona -> gists of other
sessions -> a few sharp raw cross-session details -> current session raw
turns -> new message; respond() compacts the session after each turn
- web: POST /sessions/{id}/summarize to compact on demand
- summarization activity surfaces in the live log
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+44
-22
@@ -1,43 +1,62 @@
|
||||
"""The chat turn loop: persona + recalled memory + recent context -> reply.
|
||||
"""The chat turn loop: persona + tiered memory + recent context -> reply.
|
||||
|
||||
Each turn assembles the persona system prompt, semantically-relevant memories
|
||||
recalled from across all past sessions, and the recent turns of the current
|
||||
session, then asks the model for a reply and persists both sides.
|
||||
Context is assembled in tiers (oldest/most-compacted first):
|
||||
1. persona
|
||||
2. long-term gist — relevant *summaries* of other sessions
|
||||
3. sharp details — a few raw cross-session exchanges (so specifics survive)
|
||||
4. recent raw turns of the current session (full fidelity)
|
||||
5. the new user message
|
||||
After replying, the session is compacted if enough new turns have accumulated.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from lyra import config, llm, logbus, memory, persona
|
||||
from lyra import config, llm, logbus, memory, persona, summary
|
||||
from lyra.llm import Backend, Message
|
||||
|
||||
RECALL_K = 5
|
||||
RECENT_N = 10
|
||||
RECALL_K = 3 # raw cross-session "sharp detail" hits
|
||||
RECENT_N = 10 # raw turns of the current session
|
||||
SUMMARY_K = 3 # other-session gists
|
||||
|
||||
|
||||
def _memory_note(exchanges: list[memory.Exchange]) -> Message:
|
||||
"""Format recalled memories as a system note Lyra can draw on."""
|
||||
lines = []
|
||||
for ex in exchanges:
|
||||
when = ex.created_at[:10] # YYYY-MM-DD
|
||||
lines.append(f"- ({when}, {ex.role}) {ex.content}")
|
||||
body = "Relevant things you remember from past conversations:\n" + "\n".join(lines)
|
||||
def _summary_note(summaries: list[memory.Summary]) -> Message:
|
||||
lines = [f"- ({s.created_at[:10]}) {s.content}" for s in summaries]
|
||||
body = "Gist of earlier sessions (compacted — ask if you need specifics):\n" + "\n".join(lines)
|
||||
return {"role": "system", "content": body}
|
||||
|
||||
|
||||
def _detail_note(exchanges: list[memory.Exchange]) -> Message:
|
||||
lines = [f"- ({ex.created_at[:10]}, {ex.role}) {ex.content}" for ex in exchanges]
|
||||
body = "Specific things you recall from past conversations:\n" + "\n".join(lines)
|
||||
return {"role": "system", "content": body}
|
||||
|
||||
|
||||
def build_messages(session_id: str, user_msg: str) -> list[Message]:
|
||||
"""Assemble the full message list for one turn."""
|
||||
"""Assemble the full, tiered message list for one turn."""
|
||||
messages: list[Message] = [{"role": "system", "content": persona.system_prompt()}]
|
||||
|
||||
recent = memory.recent(session_id, n=RECENT_N)
|
||||
recent_ids = {ex.id for ex in recent}
|
||||
|
||||
# Cross-session recall, minus anything already shown in the recent window.
|
||||
recalled = [
|
||||
ex for ex in memory.recall(user_msg, k=RECALL_K) if ex.id not in recent_ids
|
||||
]
|
||||
logbus.log("debug", "context built", recent=len(recent), recalled=len(recalled))
|
||||
if recalled:
|
||||
messages.append(_memory_note(recalled))
|
||||
# Tier 1: compacted gists of *other* sessions (long-term, general idea).
|
||||
summaries = memory.recall_summaries(user_msg, k=SUMMARY_K, exclude_session=session_id)
|
||||
if summaries:
|
||||
messages.append(_summary_note(summaries))
|
||||
|
||||
# Tier 2: a few sharp raw details from other sessions (so specifics survive
|
||||
# compaction). Skip the current session (its raw turns are in `recent`).
|
||||
recalled = [
|
||||
ex for ex in memory.recall(user_msg, k=RECALL_K)
|
||||
if ex.id not in recent_ids and ex.session_id != session_id
|
||||
]
|
||||
if recalled:
|
||||
messages.append(_detail_note(recalled))
|
||||
|
||||
logbus.log(
|
||||
"debug", "context built",
|
||||
recent=len(recent), summaries=len(summaries), details=len(recalled),
|
||||
)
|
||||
|
||||
# Tier 3: current session, full fidelity.
|
||||
for ex in recent:
|
||||
messages.append({"role": ex.role, "content": ex.content})
|
||||
|
||||
@@ -60,4 +79,7 @@ def respond(session_id: str, user_msg: str, backend: Backend = "cloud") -> str:
|
||||
|
||||
memory.remember(session_id, "user", user_msg)
|
||||
memory.remember(session_id, "assistant", reply)
|
||||
|
||||
# Compact this session once enough new turns have piled up.
|
||||
summary.maybe_summarize(session_id)
|
||||
return reply
|
||||
|
||||
@@ -19,6 +19,7 @@ class Config:
|
||||
embed_backend: str # "cloud" (OpenAI) or "local" (Ollama)
|
||||
embed_model: str # OpenAI embedding model
|
||||
local_embed_model: str # Ollama embedding model
|
||||
summary_backend: str # "local" or "cloud" — backend used to compact memory
|
||||
db_path: Path
|
||||
|
||||
|
||||
@@ -31,5 +32,6 @@ def load() -> Config:
|
||||
embed_backend=os.getenv("EMBED_BACKEND", "cloud").lower(),
|
||||
embed_model=os.getenv("EMBED_MODEL", "text-embedding-3-small"),
|
||||
local_embed_model=os.getenv("LOCAL_EMBED_MODEL", "nomic-embed-text"),
|
||||
summary_backend=os.getenv("SUMMARY_BACKEND", "local").lower(),
|
||||
db_path=Path(os.getenv("LYRA_DB_PATH", "data/lyra.db")),
|
||||
)
|
||||
|
||||
+100
@@ -33,6 +33,16 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
name TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- One compacted "gist" per session. last_exchange_id marks how far the summary
|
||||
-- covers, so we know when enough new turns have accumulated to re-summarize.
|
||||
CREATE TABLE IF NOT EXISTS summaries (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
last_exchange_id INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
_conn: sqlite3.Connection | None = None
|
||||
@@ -67,6 +77,15 @@ class Exchange:
|
||||
score: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Summary:
|
||||
session_id: str
|
||||
content: str
|
||||
last_exchange_id: int
|
||||
created_at: str
|
||||
score: float | None = None
|
||||
|
||||
|
||||
def _to_blob(vec: list[float]) -> bytes:
|
||||
return np.asarray(vec, dtype=np.float32).tobytes()
|
||||
|
||||
@@ -171,6 +190,7 @@ def delete_session(session_id: str) -> None:
|
||||
with conn:
|
||||
conn.execute("DELETE FROM exchanges WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM summaries WHERE session_id = ?", (session_id,))
|
||||
|
||||
|
||||
def recall(query: str, k: int = 5, session_id: str | None = None) -> list[Exchange]:
|
||||
@@ -204,3 +224,83 @@ def recall(query: str, k: int = 5, session_id: str | None = None) -> list[Exchan
|
||||
)
|
||||
for i in top_idx
|
||||
]
|
||||
|
||||
|
||||
# --- Summary tier (compacted per-session gists) ---
|
||||
|
||||
|
||||
def store_summary(session_id: str, content: str, last_exchange_id: int) -> None:
|
||||
"""Embed and persist the gist of a session, replacing any prior summary."""
|
||||
[embedding] = llm.embed([content])
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn = _connection()
|
||||
with conn:
|
||||
conn.execute(
|
||||
"INSERT INTO summaries (session_id, content, embedding, last_exchange_id, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?) "
|
||||
"ON CONFLICT(session_id) DO UPDATE SET "
|
||||
"content=excluded.content, embedding=excluded.embedding, "
|
||||
"last_exchange_id=excluded.last_exchange_id, created_at=excluded.created_at",
|
||||
(session_id, content, _to_blob(embedding), last_exchange_id, now),
|
||||
)
|
||||
|
||||
|
||||
def get_summary(session_id: str) -> Summary | None:
|
||||
conn = _connection()
|
||||
r = conn.execute(
|
||||
"SELECT session_id, content, last_exchange_id, created_at FROM summaries "
|
||||
"WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
if r is None:
|
||||
return None
|
||||
return Summary(
|
||||
session_id=r["session_id"],
|
||||
content=r["content"],
|
||||
last_exchange_id=r["last_exchange_id"],
|
||||
created_at=r["created_at"],
|
||||
)
|
||||
|
||||
|
||||
def unsummarized_count(session_id: str) -> int:
|
||||
"""How many exchanges in this session are newer than its current summary."""
|
||||
conn = _connection()
|
||||
summary = get_summary(session_id)
|
||||
cutoff = summary.last_exchange_id if summary else 0
|
||||
r = conn.execute(
|
||||
"SELECT COUNT(*) AS n FROM exchanges WHERE session_id = ? AND id > ?",
|
||||
(session_id, cutoff),
|
||||
).fetchone()
|
||||
return int(r["n"])
|
||||
|
||||
|
||||
def recall_summaries(query: str, k: int = 3, exclude_session: str | None = None) -> list[Summary]:
|
||||
"""Top-k session summaries most similar to `query` (the long-term gist tier)."""
|
||||
[q_vec] = llm.embed([query])
|
||||
q = np.asarray(q_vec, dtype=np.float32)
|
||||
|
||||
conn = _connection()
|
||||
sql = "SELECT session_id, content, embedding, last_exchange_id, created_at FROM summaries"
|
||||
params: tuple = ()
|
||||
if exclude_session is not None:
|
||||
sql += " WHERE session_id != ?"
|
||||
params = (exclude_session,)
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
matrix = np.stack([_from_blob(r["embedding"]) for r in rows])
|
||||
norms = np.linalg.norm(matrix, axis=1)
|
||||
scores = (matrix @ q) / (norms * np.linalg.norm(q) + 1e-9)
|
||||
|
||||
top_idx = np.argsort(scores)[::-1][:k]
|
||||
return [
|
||||
Summary(
|
||||
session_id=rows[i]["session_id"],
|
||||
content=rows[i]["content"],
|
||||
last_exchange_id=rows[i]["last_exchange_id"],
|
||||
created_at=rows[i]["created_at"],
|
||||
score=float(scores[i]),
|
||||
)
|
||||
for i in top_idx
|
||||
]
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Session summarization: compact a session's raw exchanges into a stored gist.
|
||||
|
||||
This is the compaction half of the tiered memory. Raw exchanges stay for detail
|
||||
recall; the summary is what surfaces when an *older* session is recalled later —
|
||||
"a month ago is a general idea," per the design.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from lyra import config, llm, logbus, memory
|
||||
from lyra.llm import Backend
|
||||
|
||||
# Re-summarize a session once it has accumulated this many new raw exchanges
|
||||
# beyond what its current summary covers.
|
||||
SUMMARIZE_AFTER = 20
|
||||
|
||||
_PROMPT = """You are compacting a conversation into a long-term memory record \
|
||||
(not replying to anyone). Write a concise gist of the session below: what was \
|
||||
discussed, key decisions or outcomes, concrete specifics worth keeping (names, \
|
||||
places, numbers, hands), and the user's apparent mood/state. Third person, \
|
||||
referring to the user as "Brian". 4-8 sentences. No preamble."""
|
||||
|
||||
|
||||
def _transcript(exchanges: list[memory.Exchange]) -> str:
|
||||
return "\n".join(f"{ex.role}: {ex.content}" for ex in exchanges)
|
||||
|
||||
|
||||
def summarize_session(session_id: str, backend: Backend | None = None) -> str | None:
|
||||
"""(Re)generate and store the gist for a session. Returns the summary text.
|
||||
|
||||
Returns None if the session has no exchanges. The summarizer defaults to the
|
||||
local backend so routine compaction stays free.
|
||||
"""
|
||||
exchanges = memory.history(session_id)
|
||||
if not exchanges:
|
||||
return None
|
||||
|
||||
backend = backend or config.load().summary_backend
|
||||
messages = [
|
||||
{"role": "system", "content": _PROMPT},
|
||||
{"role": "user", "content": _transcript(exchanges)},
|
||||
]
|
||||
gist = llm.complete(messages, backend=backend)
|
||||
|
||||
last_id = exchanges[-1].id
|
||||
memory.store_summary(session_id, gist, last_id)
|
||||
logbus.log(
|
||||
"info", "summarized session", session=session_id,
|
||||
exchanges=len(exchanges), backend=backend,
|
||||
)
|
||||
return gist
|
||||
|
||||
|
||||
def maybe_summarize(session_id: str, backend: Backend | None = None) -> None:
|
||||
"""Summarize the session if enough new turns have accumulated since last time."""
|
||||
if memory.unsummarized_count(session_id) >= SUMMARIZE_AFTER:
|
||||
summarize_session(session_id, backend=backend)
|
||||
+6
-1
@@ -18,7 +18,7 @@ from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from lyra import chat, logbus, memory
|
||||
from lyra import chat, logbus, memory, summary
|
||||
from lyra.llm import Backend
|
||||
|
||||
|
||||
@@ -77,6 +77,11 @@ def create_app() -> FastAPI:
|
||||
memory.delete_session(session_id)
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/sessions/{session_id}/summarize")
|
||||
async def summarize(session_id: str) -> dict:
|
||||
gist = await asyncio.to_thread(summary.summarize_session, session_id)
|
||||
return {"ok": gist is not None, "summary": gist}
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request) -> dict:
|
||||
body = await request.json()
|
||||
|
||||
Reference in New Issue
Block a user