diff --git a/lyra/summary.py b/lyra/summary.py index 46c3dfa..39506d8 100644 --- a/lyra/summary.py +++ b/lyra/summary.py @@ -11,6 +11,7 @@ from __future__ import annotations import sys import time +from concurrent.futures import ThreadPoolExecutor, as_completed from lyra import config, llm, logbus, memory from lyra.llm import Backend, Message @@ -64,25 +65,23 @@ def _summarize_text(text: str, backend: Backend) -> str: raise RuntimeError("unreachable") +def _summarize_transcript(transcript: str, backend: Backend) -> str: + """Transcript -> gist (LLM only, no DB). Chunks + merges if oversized.""" + if len(transcript) <= MAX_TRANSCRIPT_CHARS: + return _summarize_text(transcript, backend) + partials = [_summarize_text(c, backend) for c in _chunk(transcript, MAX_TRANSCRIPT_CHARS)] + return _summarize_text("Partial summaries to merge:\n\n" + "\n\n".join(partials), backend) + + def summarize_session(session_id: str, backend: Backend | None = None) -> str | None: """(Re)generate and store the gist for a session. Returns the summary text.""" exchanges = memory.history(session_id) if not exchanges: return None - backend = backend or config.load().summary_backend - transcript = _transcript(exchanges) - if len(transcript) <= MAX_TRANSCRIPT_CHARS: - gist = _summarize_text(transcript, backend) - else: - partials = [_summarize_text(c, backend) for c in _chunk(transcript, MAX_TRANSCRIPT_CHARS)] - gist = _summarize_text("Partial summaries to merge:\n\n" + "\n\n".join(partials), backend) - + gist = _summarize_transcript(_transcript(exchanges), backend) memory.store_summary(session_id, gist, exchanges[-1].id) - logbus.log( - "info", "summarized session", session=session_id, - exchanges=len(exchanges), backend=backend, - ) + logbus.log("info", "summarized session", session=session_id, exchanges=len(exchanges)) return gist @@ -92,31 +91,53 @@ def maybe_summarize(session_id: str, backend: Backend | None = None) -> None: summarize_session(session_id, backend=backend) -def summarize_all(backend: Backend | None = None, limit: int | None = None) -> dict: - """Summarize every session that needs it. Idempotent and resumable: sessions - with an up-to-date summary are skipped, so re-running continues where it left off. +def summarize_all( + backend: Backend | None = None, limit: int | None = None, workers: int = 8 +) -> dict: + """Summarize every session that needs it. Idempotent and resumable. + + LLM summarization runs concurrently across `workers` threads (great for a + cloud backend). DB reads (loading transcripts) and writes (store_summary, + which also embeds) happen on the main thread, so the single SQLite + connection is never touched from multiple threads. """ - sessions = memory.list_sessions() - done, skipped, failed = 0, 0, 0 - for s in sessions: + backend = backend or config.load().summary_backend + + # Main thread: collect the work (transcripts) for sessions needing a summary. + todo: list[tuple[str, str, int]] = [] + for s in memory.list_sessions(): sid = s["id"] if memory.get_summary(sid) and memory.unsummarized_count(sid) == 0: - skipped += 1 continue - try: - summarize_session(sid, backend=backend) - except Exception as exc: - # Don't let one bad session kill the batch; log and move on (it'll - # be retried on the next run, since it stays unsummarized). - failed += 1 - logbus.log("error", "summarize failed", session=sid, error=str(exc)[:120]) + exchanges = memory.history(sid) + if not exchanges: continue - done += 1 - if done % 25 == 0: - logbus.log("info", "summarize-all progress", summarized=done, skipped=skipped, failed=failed) - if limit is not None and done >= limit: + todo.append((sid, _transcript(exchanges), exchanges[-1].id)) + if limit is not None and len(todo) >= limit: break - report = {"summarized": done, "skipped": skipped, "failed": failed, "total": len(sessions)} + + done, failed = 0, 0 + logbus.log("info", "summarize-all starting", todo=len(todo), backend=backend, workers=workers) + + def work(item: tuple[str, str, int]) -> tuple[str, str, int]: + sid, transcript, last_id = item + return sid, _summarize_transcript(transcript, backend), last_id + + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = {pool.submit(work, item): item for item in todo} + for fut in as_completed(futures): + sid = futures[fut][0] + try: + _, gist, last_id = fut.result() + memory.store_summary(sid, gist, last_id) # main thread: embed + write + done += 1 + except Exception as exc: + failed += 1 + logbus.log("error", "summarize failed", session=sid, error=str(exc)[:120]) + if (done + failed) % 25 == 0: + logbus.log("info", "summarize-all progress", done=done, failed=failed, total=len(todo)) + + report = {"summarized": done, "failed": failed, "total": len(todo)} logbus.log("info", "summarize-all complete", **report) return report