"""
Event streaming for tool calling "show your work" feature.

This module manages Server-Sent Events (SSE) for broadcasting the internal
thinking process during tool calling operations.
"""

import asyncio
from typing import Dict, List, Optional
from collections import defaultdict
import json
import logging

logger = logging.getLogger(__name__)


class ToolStreamManager:
    """Manages SSE streams for tool calling events.

    Each subscriber gets its own unbounded ``asyncio.Queue``. ``emit`` puts
    the same event dict on every queue registered for a session. A consumer
    that never drains its queue will keep accumulating events.
    """

    def __init__(self):
        # session_id -> list of queues (one per connected client)
        self._subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list)

    def subscribe(self, session_id: str) -> asyncio.Queue:
        """Subscribe to events for a session.

        Args:
            session_id: Session whose events the caller wants to receive.

        Returns:
            Queue that will receive events for this session
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers[session_id].append(queue)
        logger.info(
            "New subscriber for session %s, total: %s",
            session_id,
            len(self._subscribers[session_id]),
        )
        return queue

    def unsubscribe(self, session_id: str, queue: asyncio.Queue) -> None:
        """Unsubscribe from events for a session.

        Unknown sessions and queues that are not registered are ignored.
        The session entry is removed once its last queue is gone.
        """
        queues = self._subscribers.get(session_id)
        if queues is None:
            return
        try:
            queues.remove(queue)
        except ValueError:
            # Queue was already removed (e.g. by emit's dead-queue cleanup).
            return
        logger.info(
            "Removed subscriber for session %s, remaining: %s",
            session_id,
            len(queues),
        )
        # Clean up empty lists so has_subscribers/emit see the session as gone.
        if not queues:
            del self._subscribers[session_id]

    async def emit(self, session_id: str, event_type: str, data: dict) -> None:
        """Emit an event to all subscribers of a session.

        Args:
            session_id: Session to emit to
            event_type: Type of event (thinking, tool_call, tool_result, done)
            data: Event data

        Queues whose ``put`` raises are logged and unsubscribed afterwards.
        """
        queues = self._subscribers.get(session_id)
        if not queues:
            return

        event = {
            "type": event_type,
            "data": data,
        }

        # Iterate over a snapshot: every ``await`` is a point where other code
        # may subscribe/unsubscribe and mutate the live list, which would make
        # a direct iteration skip queues.
        dead_queues = []
        for queue in list(queues):
            try:
                await queue.put(event)
            except Exception as e:
                logger.error("Failed to emit event to queue: %s", e)
                dead_queues.append(queue)

        # Clean up dead queues after the loop so the fan-out is not disturbed.
        for queue in dead_queues:
            self.unsubscribe(session_id, queue)

    def has_subscribers(self, session_id: str) -> bool:
        """Check if a session has any active subscribers."""
        # .get() avoids creating an empty entry in the defaultdict.
        return bool(self._subscribers.get(session_id))


# Global stream manager instance
_stream_manager: Optional[ToolStreamManager] = None


def get_stream_manager() -> ToolStreamManager:
    """Get the global stream manager instance (created on first call)."""
    global _stream_manager
    if _stream_manager is None:
        _stream_manager = ToolStreamManager()
    return _stream_manager