- Added `trillium.py` for searching and creating notes with Trilium's ETAPI.
- Implemented `search_notes` and `create_note` functions with appropriate error handling and validation.

feat: Add web search functionality using DuckDuckGo
- Introduced `web_search.py` for performing web searches without API keys.
- Implemented `search_web` function with result handling and validation.

feat: Create provider-agnostic function caller for iterative tool calling
- Developed `function_caller.py` to manage LLM interactions with tools.
- Implemented iterative calling logic with error handling and tool execution.

feat: Establish a tool registry for managing available tools
- Created `registry.py` to define and manage tool availability and execution.
- Integrated feature flags for enabling/disabling tools based on environment variables.

feat: Implement event streaming for tool-calling processes
- Added `stream_events.py` to manage Server-Sent Events (SSE) for tool calling.
- Enabled real-time updates during tool execution for an enhanced user experience.

test: Add tests for tool-calling system components
- Created `test_tools.py` to validate the functionality of code execution, web search, and the tool registry.
- Implemented asynchronous tests to ensure proper execution and result handling.

chore: Add Dockerfile for sandbox environment setup
- Created `Dockerfile` to set up a Python environment with the dependencies needed for code execution.

chore: Add debug regex script for testing XML parsing
- Introduced `debug_regex.py` to validate regex patterns against XML tool calls.

chore: Add HTML template for displaying thinking stream events
- Created `test_thinking_stream.html` for visualizing tool-calling events in a user-friendly format.

test: Add tests for OllamaAdapter XML parsing
- Developed `test_ollama_parser.py` to validate XML parsing with various test cases, including malformed XML.
92 lines · 2.9 KiB · Python
"""
Event streaming for the tool-calling "show your work" feature.

This module manages Server-Sent Events (SSE) for broadcasting the internal
thinking process during tool-calling operations.
"""
|
|
|
|
import asyncio
|
|
from typing import Dict, Optional
|
|
from collections import defaultdict
|
|
import json
|
|
import logging
|
|
|
|
# Module-level logger named after this module (stdlib logging convention),
# used by ToolStreamManager to report subscribe/unsubscribe/emit activity.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolStreamManager:
    """Manages SSE streams for tool calling events.

    Each connected client calls :meth:`subscribe` to obtain its own
    ``asyncio.Queue``; :meth:`emit` then puts every event for that session on
    all of the session's queues. Presumably an SSE endpoint elsewhere drains
    these queues and writes them to the client -- confirm against the caller.
    """

    def __init__(self):
        # session_id -> list of queues (one per connected client).
        # defaultdict lets subscribe() append without a membership check.
        self._subscribers: Dict[str, list] = defaultdict(list)

    def subscribe(self, session_id: str) -> asyncio.Queue:
        """Subscribe to events for a session.

        Args:
            session_id: Session whose events the caller wants to receive.

        Returns:
            A new unbounded queue that will receive every event emitted for
            this session until it is passed to :meth:`unsubscribe`.
        """
        queue: asyncio.Queue = asyncio.Queue()
        subscribers = self._subscribers[session_id]
        subscribers.append(queue)
        logger.info(
            "New subscriber for session %s, total: %d", session_id, len(subscribers)
        )
        return queue

    def unsubscribe(self, session_id: str, queue: asyncio.Queue) -> None:
        """Unsubscribe from events for a session.

        Unknown sessions and queues that are not (or no longer) registered are
        ignored, so calling this more than once for the same queue is harmless.

        Args:
            session_id: Session the queue was subscribed to.
            queue: Queue previously returned by :meth:`subscribe`.
        """
        # .get() does not invoke the defaultdict factory, so unknown sessions
        # are not materialised as empty entries.
        subscribers = self._subscribers.get(session_id)
        if subscribers is None:
            return
        try:
            subscribers.remove(queue)
        except ValueError:
            # Already removed (e.g. by emit()'s dead-queue cleanup).
            return
        logger.info(
            "Removed subscriber for session %s, remaining: %d",
            session_id,
            len(subscribers),
        )
        # Clean up empty lists so the session no longer counts as subscribed
        # and the dict does not accumulate stale session ids.
        if not subscribers:
            del self._subscribers[session_id]

    async def emit(self, session_id: str, event_type: str, data: dict) -> None:
        """Emit an event to all subscribers of a session.

        The event is delivered as ``{"type": event_type, "data": data}``; the
        same dict object is put on every subscriber's queue. Sessions without
        subscribers are skipped silently. Queues whose ``put`` raises are
        logged and unsubscribed.

        Args:
            session_id: Session to emit to
            event_type: Type of event (thinking, tool_call, tool_result, done)
            data: Event data
        """
        subscribers = self._subscribers.get(session_id)
        if not subscribers:
            return

        event = {
            "type": event_type,
            "data": data,
        }

        # Iterate over a snapshot: ``queue.put`` may suspend (e.g. on a full
        # bounded queue), and a concurrent unsubscribe() would otherwise
        # mutate the list under the loop and make it skip a subscriber.
        dead_queues = []
        for queue in list(subscribers):
            try:
                await queue.put(event)
            except Exception as exc:
                logger.error(
                    "Failed to emit event to queue: %s", exc, exc_info=True
                )
                dead_queues.append(queue)

        # Clean up queues that failed to accept the event.
        for queue in dead_queues:
            self.unsubscribe(session_id, queue)

    def has_subscribers(self, session_id: str) -> bool:
        """Check if a session has any active subscribers."""
        # .get() avoids creating an empty entry via the defaultdict factory.
        return bool(self._subscribers.get(session_id))
|
|
|
|
|
|
# Process-wide ToolStreamManager, created lazily by get_stream_manager().
_stream_manager: Optional[ToolStreamManager] = None


def get_stream_manager() -> ToolStreamManager:
    """Return the shared ToolStreamManager, creating it on first use.

    Every call returns the same module-level instance, so code that emits
    events and code that subscribes to them share one subscriber registry.
    """
    global _stream_manager
    if _stream_manager is not None:
        return _stream_manager
    _stream_manager = ToolStreamManager()
    return _stream_manager
|