d7e2fce694
Refactor summarize_all to run LLM summarization across a thread pool (default 8 workers) while keeping all SQLite reads/writes on the main thread (the single connection is never shared across threads). Extract _summarize_transcript (transcript -> gist, no DB) for the worker. The MI50 proved far too slow for the large-transcript backfill (~29 summaries in 9h due to gfx906 prefill); on cloud gpt-4o-mini with concurrency this runs at ~30 summaries/minute (~17 min for the full backfill, ~$2). MI50 stays the chat backend where small prompts make it snappy. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
153 lines
6.0 KiB
Python
153 lines
6.0 KiB
Python
"""Session summarization: compact a session's raw exchanges into a stored gist.

This is the first consolidation stage. Raw exchanges are kept for detailed
recall; the summary is what surfaces when an *older* session is recalled, and
it is the input to the profile (semantic memory) and era-rollup tiers.

Long sessions are summarized in chunks and the partial gists are then merged,
so a big imported conversation doesn't overflow the summary model's context
window.
"""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
from lyra import config, llm, logbus, memory
|
|
from lyra.llm import Backend, Message
|
|
|
|
_RETRIES = 4
|
|
|
|
# Re-summarize a session once it has accumulated this many new raw exchanges.
|
|
SUMMARIZE_AFTER = 20
|
|
# Transcript budget per LLM call; longer sessions are chunked + merged.
|
|
MAX_TRANSCRIPT_CHARS = 24000
|
|
|
|
_PROMPT = """You are compacting a conversation into a long-term memory record \
|
|
(not replying to anyone). Write a concise gist of the session below: what was \
|
|
discussed, key decisions or outcomes, concrete specifics worth keeping (names, \
|
|
places, numbers, hands), and the user's apparent mood/state. Third person, \
|
|
referring to the user as "Brian". 4-8 sentences. No preamble."""
|
|
|
|
|
|
def _transcript(exchanges: list[memory.Exchange]) -> str:
|
|
return "\n".join(f"{ex.role}: {ex.content}" for ex in exchanges)
|
|
|
|
|
|
def _chunk(text: str, budget: int) -> list[str]:
|
|
"""Split on line boundaries into pieces under `budget` chars."""
|
|
chunks, buf, size = [], [], 0
|
|
for line in text.splitlines(keepends=True):
|
|
if size + len(line) > budget and buf:
|
|
chunks.append("".join(buf))
|
|
buf, size = [], 0
|
|
buf.append(line)
|
|
size += len(line)
|
|
if buf:
|
|
chunks.append("".join(buf))
|
|
return chunks
|
|
|
|
|
|
def _summarize_text(text: str, backend: Backend) -> str:
    """Run one summarization call: ``_PROMPT`` as system message, ``text`` as user message.

    Any exception from ``llm.complete`` (e.g. the GPU server restarting) is
    retried, for at most ``_RETRIES`` attempts in total. Between attempts it
    logs a debug "summary retry" event and sleeps 5s, 10s, 15s, ... The
    exception from the final attempt propagates to the caller.
    """
    prompt: list[Message] = [
        {"role": "system", "content": _PROMPT},
        {"role": "user", "content": text},
    ]
    attempt = 0
    while attempt < _RETRIES:
        try:
            return llm.complete(prompt, backend=backend)
        except Exception as exc:
            attempt += 1
            if attempt >= _RETRIES:
                raise
            logbus.log("debug", "summary retry", attempt=attempt, error=str(exc)[:80])
            time.sleep(5 * attempt)
    # Only reached when _RETRIES < 1 (no attempt is made at all).
    raise RuntimeError("unreachable")
|
|
|
|
|
|
def _summarize_transcript(transcript: str, backend: Backend) -> str:
    """Transcript -> gist (LLM only, no DB). Chunks + merges if oversized.

    A transcript of at most ``MAX_TRANSCRIPT_CHARS`` characters is summarized
    in a single call. A longer one is split with ``_chunk``, and each piece is
    summarized. The partial gists are then merged in one more call under a
    "Partial summaries to merge:" header.

    If the partials together are too long for one merge call, they are first
    merged in consecutive groups that each fit the budget. This repeats until
    the remainder fits, so the merge prompt does not grow without bound for
    very long sessions. If no two neighbouring partials fit together, it falls
    back to a single best-effort merge of everything.

    Only LLM calls are made here, with no database access, which is why
    ``summarize_all`` runs it on worker threads. Exceptions from
    ``_summarize_text`` (after its retries) propagate.
    """
    if len(transcript) <= MAX_TRANSCRIPT_CHARS:
        return _summarize_text(transcript, backend)

    header = "Partial summaries to merge:\n\n"
    sep = "\n\n"

    def fits(group: list[str]) -> bool:
        # Would a merge prompt for this group stay within the per-call budget?
        return len(header) + len(sep.join(group)) <= MAX_TRANSCRIPT_CHARS

    partials = [_summarize_text(c, backend) for c in _chunk(transcript, MAX_TRANSCRIPT_CHARS)]

    while len(partials) > 1 and not fits(partials):
        # Greedily pack consecutive partials into groups that each fit.
        groups: list[list[str]] = [[]]
        for part in partials:
            if groups[-1] and not fits(groups[-1] + [part]):
                groups.append([])
            groups[-1].append(part)
        if len(groups) == len(partials):
            # Every group is a single partial: another round cannot shrink
            # the list, so stop and do a best-effort final merge.
            break
        # Each round strictly reduces the partial count, so this terminates.
        partials = [
            group[0] if len(group) == 1 else _summarize_text(header + sep.join(group), backend)
            for group in groups
        ]

    return _summarize_text(header + sep.join(partials), backend)
|
|
|
|
|
|
def summarize_session(session_id: str, backend: Backend | None = None) -> str | None:
    """Summarize one session from its full history and store the result.

    Returns ``None`` and stores nothing when the session has no exchanges.
    Otherwise it stores the freshly generated gist, keyed to the id of the
    session's last exchange, and returns it. When ``backend`` is falsy
    (e.g. ``None``), the configured ``summary_backend`` is used.
    """
    turns = memory.history(session_id)
    if not turns:
        return None
    chosen = backend or config.load().summary_backend
    gist_text = _summarize_transcript(_transcript(turns), chosen)
    newest_id = turns[-1].id
    memory.store_summary(session_id, gist_text, newest_id)
    logbus.log("info", "summarized session", session=session_id, exchanges=len(turns))
    return gist_text
|
|
|
|
|
|
def maybe_summarize(session_id: str, backend: Backend | None = None) -> None:
    """Re-summarize ``session_id`` once enough unsummarized exchanges have piled up.

    Below ``SUMMARIZE_AFTER`` new exchanges this does nothing. At or above the
    threshold it delegates to ``summarize_session`` with the same ``backend``.
    """
    backlog = memory.unsummarized_count(session_id)
    if backlog < SUMMARIZE_AFTER:
        return
    summarize_session(session_id, backend=backend)
|
|
|
|
|
|
def _pending_work(limit: int | None) -> list[tuple[str, str, int]]:
    """Collect ``(session_id, transcript, last_exchange_id)`` for sessions needing a summary.

    A session is skipped when it already has a summary and no unsummarized
    exchanges, or when it has no exchanges at all. At most ``limit`` items are
    returned: ``None`` means no cap, and ``0`` or less means nothing. This
    reads the database, so call it only from the thread that owns the
    connection.
    """
    todo: list[tuple[str, str, int]] = []
    for s in memory.list_sessions():
        # Check the cap before any DB work, so limit=0 really processes nothing.
        if limit is not None and len(todo) >= limit:
            break
        sid = s["id"]
        if memory.get_summary(sid) and memory.unsummarized_count(sid) == 0:
            continue
        exchanges = memory.history(sid)
        if not exchanges:
            continue
        todo.append((sid, _transcript(exchanges), exchanges[-1].id))
    return todo


def summarize_all(
    backend: Backend | None = None, limit: int | None = None, workers: int = 8
) -> dict:
    """Summarize every session that needs it. Idempotent and resumable.

    LLM summarization runs concurrently across `workers` threads (great for a
    cloud backend). DB reads (loading transcripts) and writes (store_summary,
    which also embeds) happen on the main thread, so the single SQLite
    connection is never touched from multiple threads.

    Args:
        backend: LLM backend; a falsy value falls back to
            ``config.load().summary_backend``.
        limit: process at most this many sessions (``None`` = all, ``0`` = none).
        workers: thread-pool size for the LLM calls (``ThreadPoolExecutor``
            rejects values below 1).

    Returns:
        ``{"summarized": int, "failed": int, "total": int}``.

    A per-session failure is logged and counted, not raised. If the loop is
    interrupted (e.g. Ctrl-C), sessions still queued are cancelled, so the
    pool's shutdown only waits for calls already in flight. Summaries stored
    before the interruption are skipped on the next run.
    """
    backend = backend or config.load().summary_backend

    # Main thread: collect the work (transcripts) for sessions needing a summary.
    todo = _pending_work(limit)

    done, failed = 0, 0
    logbus.log("info", "summarize-all starting", todo=len(todo), backend=backend, workers=workers)

    def work(item: tuple[str, str, int]) -> tuple[str, str, int]:
        # Worker thread: LLM calls only -- must not touch the DB.
        sid, transcript, last_id = item
        return sid, _summarize_transcript(transcript, backend), last_id

    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(work, item): item for item in todo}
        try:
            for fut in as_completed(futures):
                sid = futures[fut][0]
                try:
                    _, gist, last_id = fut.result()
                    memory.store_summary(sid, gist, last_id)  # main thread: embed + write
                    done += 1
                except Exception as exc:
                    failed += 1
                    logbus.log("error", "summarize failed", session=sid, error=str(exc)[:120])
                if (done + failed) % 25 == 0:
                    logbus.log("info", "summarize-all progress", done=done, failed=failed, total=len(todo))
        except BaseException:
            # Interrupted: without this, leaving the `with` block would wait for
            # (and pay for) every queued LLM call, only to discard the results.
            for pending in futures:
                pending.cancel()
            raise

    report = {"summarized": done, "failed": failed, "total": len(todo)}
    logbus.log("info", "summarize-all complete", **report)
    return report
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: summarize pending sessions and print the report dict.

    Usage: ``[LIMIT] [--workers N]``.
    - LIMIT caps how many sessions are processed; omit it to process all.
    - ``--workers`` sets the LLM thread-pool size (default 8, matching
      ``summarize_all``).

    Invalid arguments make argparse print usage and exit with status 2,
    instead of crashing with a traceback.

    Args:
        argv: arguments without the program name. ``None`` makes argparse read
            ``sys.argv[1:]``, so existing ``main()`` calls behave as before.

    Returns:
        Process exit status (0 on completion).
    """
    import argparse

    def positive_int(raw: str) -> int:
        value = int(raw)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Summarize every session that needs it.")
    parser.add_argument(
        "limit", nargs="?", type=int, default=None,
        help="process at most this many sessions (default: all)",
    )
    parser.add_argument(
        "--workers", type=positive_int, default=8,
        help="concurrent LLM summarization threads (default: 8)",
    )
    args = parser.parse_args(argv)
    print(summarize_all(limit=args.limit, workers=args.workers))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status.
    sys.exit(main())
|