perf: concurrent summarize-all (parallel LLM, serial DB)
Refactor summarize_all to run LLM summarization across a thread pool (default 8 workers) while keeping all SQLite reads/writes on the main thread (the single connection is never shared across threads). Extract _summarize_transcript (transcript -> gist, no DB) for the worker. The MI50 proved far too slow for the large-transcript backfill (~29 summaries in 9h due to gfx906 prefill); on cloud gpt-4o-mini with concurrency this runs at ~30 summaries/minute (~17 min for the full backfill, ~$2). MI50 stays the chat backend where small prompts make it snappy. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+50
-29
@@ -11,6 +11,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from lyra import config, llm, logbus, memory
|
||||
from lyra.llm import Backend, Message
|
||||
@@ -64,25 +65,23 @@ def _summarize_text(text: str, backend: Backend) -> str:
|
||||
raise RuntimeError("unreachable")
|
||||
|
||||
|
||||
def _summarize_transcript(transcript: str, backend: Backend) -> str:
|
||||
"""Transcript -> gist (LLM only, no DB). Chunks + merges if oversized."""
|
||||
if len(transcript) <= MAX_TRANSCRIPT_CHARS:
|
||||
return _summarize_text(transcript, backend)
|
||||
partials = [_summarize_text(c, backend) for c in _chunk(transcript, MAX_TRANSCRIPT_CHARS)]
|
||||
return _summarize_text("Partial summaries to merge:\n\n" + "\n\n".join(partials), backend)
|
||||
|
||||
|
||||
def summarize_session(session_id: str, backend: Backend | None = None) -> str | None:
|
||||
"""(Re)generate and store the gist for a session. Returns the summary text."""
|
||||
exchanges = memory.history(session_id)
|
||||
if not exchanges:
|
||||
return None
|
||||
|
||||
backend = backend or config.load().summary_backend
|
||||
transcript = _transcript(exchanges)
|
||||
if len(transcript) <= MAX_TRANSCRIPT_CHARS:
|
||||
gist = _summarize_text(transcript, backend)
|
||||
else:
|
||||
partials = [_summarize_text(c, backend) for c in _chunk(transcript, MAX_TRANSCRIPT_CHARS)]
|
||||
gist = _summarize_text("Partial summaries to merge:\n\n" + "\n\n".join(partials), backend)
|
||||
|
||||
gist = _summarize_transcript(_transcript(exchanges), backend)
|
||||
memory.store_summary(session_id, gist, exchanges[-1].id)
|
||||
logbus.log(
|
||||
"info", "summarized session", session=session_id,
|
||||
exchanges=len(exchanges), backend=backend,
|
||||
)
|
||||
logbus.log("info", "summarized session", session=session_id, exchanges=len(exchanges))
|
||||
return gist
|
||||
|
||||
|
||||
@@ -92,31 +91,53 @@ def maybe_summarize(session_id: str, backend: Backend | None = None) -> None:
|
||||
summarize_session(session_id, backend=backend)
|
||||
|
||||
|
||||
def summarize_all(backend: Backend | None = None, limit: int | None = None) -> dict:
|
||||
"""Summarize every session that needs it. Idempotent and resumable: sessions
|
||||
with an up-to-date summary are skipped, so re-running continues where it left off.
|
||||
def summarize_all(
|
||||
backend: Backend | None = None, limit: int | None = None, workers: int = 8
|
||||
) -> dict:
|
||||
"""Summarize every session that needs it. Idempotent and resumable.
|
||||
|
||||
LLM summarization runs concurrently across `workers` threads (great for a
|
||||
cloud backend). DB reads (loading transcripts) and writes (store_summary,
|
||||
which also embeds) happen on the main thread, so the single SQLite
|
||||
connection is never touched from multiple threads.
|
||||
"""
|
||||
sessions = memory.list_sessions()
|
||||
done, skipped, failed = 0, 0, 0
|
||||
for s in sessions:
|
||||
backend = backend or config.load().summary_backend
|
||||
|
||||
# Main thread: collect the work (transcripts) for sessions needing a summary.
|
||||
todo: list[tuple[str, str, int]] = []
|
||||
for s in memory.list_sessions():
|
||||
sid = s["id"]
|
||||
if memory.get_summary(sid) and memory.unsummarized_count(sid) == 0:
|
||||
skipped += 1
|
||||
continue
|
||||
exchanges = memory.history(sid)
|
||||
if not exchanges:
|
||||
continue
|
||||
todo.append((sid, _transcript(exchanges), exchanges[-1].id))
|
||||
if limit is not None and len(todo) >= limit:
|
||||
break
|
||||
|
||||
done, failed = 0, 0
|
||||
logbus.log("info", "summarize-all starting", todo=len(todo), backend=backend, workers=workers)
|
||||
|
||||
def work(item: tuple[str, str, int]) -> tuple[str, str, int]:
|
||||
sid, transcript, last_id = item
|
||||
return sid, _summarize_transcript(transcript, backend), last_id
|
||||
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {pool.submit(work, item): item for item in todo}
|
||||
for fut in as_completed(futures):
|
||||
sid = futures[fut][0]
|
||||
try:
|
||||
summarize_session(sid, backend=backend)
|
||||
_, gist, last_id = fut.result()
|
||||
memory.store_summary(sid, gist, last_id) # main thread: embed + write
|
||||
done += 1
|
||||
except Exception as exc:
|
||||
# Don't let one bad session kill the batch; log and move on (it'll
|
||||
# be retried on the next run, since it stays unsummarized).
|
||||
failed += 1
|
||||
logbus.log("error", "summarize failed", session=sid, error=str(exc)[:120])
|
||||
continue
|
||||
done += 1
|
||||
if done % 25 == 0:
|
||||
logbus.log("info", "summarize-all progress", summarized=done, skipped=skipped, failed=failed)
|
||||
if limit is not None and done >= limit:
|
||||
break
|
||||
report = {"summarized": done, "skipped": skipped, "failed": failed, "total": len(sessions)}
|
||||
if (done + failed) % 25 == 0:
|
||||
logbus.log("info", "summarize-all progress", done=done, failed=failed, total=len(todo))
|
||||
|
||||
report = {"summarized": done, "failed": failed, "total": len(todo)}
|
||||
logbus.log("info", "summarize-all complete", **report)
|
||||
return report
|
||||
|
||||
|
||||
Reference in New Issue
Block a user