feat: Implement Trillium notes executor for searching and creating notes via ETAPI
- Added `trillium.py` for searching and creating notes with Trillium's ETAPI.
- Implemented `search_notes` and `create_note` functions with appropriate error handling and validation.

feat: Add web search functionality using DuckDuckGo
- Introduced `web_search.py` for performing web searches without API keys.
- Implemented `search_web` function with result handling and validation.

feat: Create provider-agnostic function caller for iterative tool calling
- Developed `function_caller.py` to manage LLM interactions with tools.
- Implemented iterative calling logic with error handling and tool execution.

feat: Establish a tool registry for managing available tools
- Created `registry.py` to define and manage tool availability and execution.
- Integrated feature flags for enabling/disabling tools based on environment variables.

feat: Implement event streaming for tool calling processes
- Added `stream_events.py` to manage Server-Sent Events (SSE) for tool calling.
- Enabled real-time updates during tool execution for an enhanced user experience.

test: Add tests for tool calling system components
- Created `test_tools.py` to validate functionality of code execution, web search, and the tool registry.
- Implemented asynchronous tests to ensure proper execution and result handling.

chore: Add Dockerfile for sandbox environment setup
- Created `Dockerfile` to set up a Python environment with the dependencies needed for code execution.

chore: Add debug regex script for testing XML parsing
- Introduced `debug_regex.py` to validate regex patterns against XML tool calls.

chore: Add HTML template for displaying thinking stream events
- Created `test_thinking_stream.html` for visualizing tool calling events in a user-friendly format.

test: Add tests for OllamaAdapter XML parsing
- Developed `test_ollama_parser.py` to validate XML parsing with various test cases, including malformed XML.
cortex/autonomy/tools/adapters/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Provider adapters for tool calling."""

from .base import ToolAdapter
from .openai_adapter import OpenAIAdapter
from .ollama_adapter import OllamaAdapter
from .llamacpp_adapter import LlamaCppAdapter

__all__ = [
    "ToolAdapter",
    "OpenAIAdapter",
    "OllamaAdapter",
    "LlamaCppAdapter",
]

cortex/autonomy/tools/adapters/base.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""
Base adapter interface for provider-agnostic tool calling.

This module defines the abstract base class that all LLM provider adapters
must implement to support tool calling in Lyra.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class ToolAdapter(ABC):
    """Base class for provider-specific tool adapters.

    Each LLM provider (OpenAI, Ollama, llama.cpp, etc.) has its own
    way of handling tool calls. This adapter pattern allows Lyra to
    support tools across all providers with a unified interface.
    """

    @abstractmethod
    async def prepare_request(
        self,
        messages: List[Dict],
        tools: List[Dict],
        tool_choice: Optional[str] = None
    ) -> Dict:
        """Convert Lyra tool definitions to provider-specific format.

        Args:
            messages: Conversation history in OpenAI format
            tools: List of Lyra tool definitions (provider-agnostic)
            tool_choice: Optional tool forcing ("auto", "required", "none")

        Returns:
            dict: Provider-specific request payload ready to send to the LLM
        """
        pass

    @abstractmethod
    async def parse_response(self, response) -> Dict:
        """Extract tool calls from provider response.

        Args:
            response: Raw provider response (format varies by provider)

        Returns:
            dict: Standardized response in Lyra format:
                {
                    "content": str,      # Assistant's text response
                    "tool_calls": [      # List of tool calls or None
                        {
                            "id": str,         # Unique call ID
                            "name": str,       # Tool name
                            "arguments": dict  # Tool arguments
                        }
                    ] or None
                }
        """
        pass

    @abstractmethod
    def format_tool_result(
        self,
        tool_call_id: str,
        tool_name: str,
        result: Dict
    ) -> Dict:
        """Format tool execution result for next LLM call.

        Args:
            tool_call_id: ID from the original tool call
            tool_name: Name of the executed tool
            result: Tool execution result dictionary

        Returns:
            dict: Message object to append to conversation
                (format varies by provider)
        """
        pass

cortex/autonomy/tools/adapters/llamacpp_adapter.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""
llama.cpp adapter for tool calling.

Since llama.cpp has similar constraints to Ollama (no native function calling),
this adapter reuses the XML-based approach from OllamaAdapter.
"""

from .ollama_adapter import OllamaAdapter


class LlamaCppAdapter(OllamaAdapter):
    """llama.cpp adapter - uses same XML approach as Ollama.

    llama.cpp doesn't have native function calling support, so we use
    the same XML-based prompt engineering approach as Ollama.
    """
    pass

cortex/autonomy/tools/adapters/ollama_adapter.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""
Ollama adapter for tool calling using XML-structured prompts.

Since Ollama doesn't have native function calling, this adapter uses
XML-based prompts to instruct the model how to call tools.
"""

import json
import re
from typing import Dict, List, Optional
from .base import ToolAdapter


class OllamaAdapter(ToolAdapter):
    """Ollama adapter using XML-structured prompts for tool calling.

    This adapter injects tool descriptions into the system prompt and
    teaches the model to respond with XML when it wants to use a tool.
    """

    SYSTEM_PROMPT = """You have access to the following tools:

{tool_descriptions}

To use a tool, respond with XML in this exact format:
<tool_call>
<name>tool_name</name>
<arguments>
<arg_name>value</arg_name>
</arguments>
<reason>why you're using this tool</reason>
</tool_call>

You can call multiple tools by including multiple <tool_call> blocks.
If you don't need to use any tools, respond normally without XML.
After tools are executed, you'll receive results and can continue the conversation."""

    async def prepare_request(
        self,
        messages: List[Dict],
        tools: List[Dict],
        tool_choice: Optional[str] = None
    ) -> Dict:
        """Inject tool descriptions into system prompt.

        Args:
            messages: Conversation history
            tools: Lyra tool definitions
            tool_choice: Ignored for Ollama (no native support)

        Returns:
            dict: Request payload with modified messages
        """
        # Format tool descriptions
        tool_desc = "\n".join([
            f"- {t['name']}: {t['description']}\n  Parameters: {self._format_parameters(t['parameters'], t.get('required', []))}"
            for t in tools
        ])

        system_msg = self.SYSTEM_PROMPT.format(tool_descriptions=tool_desc)

        # Check if first message is already a system message
        modified_messages = messages.copy()
        if modified_messages and modified_messages[0].get("role") == "system":
            # Prepend tool instructions to existing system message
            modified_messages[0]["content"] = system_msg + "\n\n" + modified_messages[0]["content"]
        else:
            # Add new system message at the beginning
            modified_messages.insert(0, {"role": "system", "content": system_msg})

        return {"messages": modified_messages}

    def _format_parameters(self, parameters: Dict, required: List[str]) -> str:
        """Format parameters for tool description.

        Args:
            parameters: Parameter definitions
            required: List of required parameter names

        Returns:
            str: Human-readable parameter description
        """
        param_strs = []
        for name, spec in parameters.items():
            req_marker = "(required)" if name in required else "(optional)"
            param_strs.append(f"{name} {req_marker}: {spec.get('description', '')}")
        return ", ".join(param_strs)

    async def parse_response(self, response) -> Dict:
        """Extract tool calls from XML in response.

        Args:
            response: String response from Ollama

        Returns:
            dict: Standardized Lyra format with content and tool_calls
        """
        import logging
        logger = logging.getLogger(__name__)

        # Ollama returns a string
        if isinstance(response, dict):
            content = response.get("message", {}).get("content", "")
        else:
            content = str(response)

        logger.info(f"🔍 OllamaAdapter.parse_response: content length={len(content)}, has <tool_call>={('<tool_call>' in content)}")
        logger.debug(f"🔍 Content preview: {content[:500]}")

        # Parse XML tool calls
        tool_calls = []
        if "<tool_call>" in content:
            # Split content by <tool_call> to get each block
            blocks = content.split('<tool_call>')
            logger.info(f"🔍 Split into {len(blocks)} blocks")

            # First block is content before any tool calls
            clean_parts = [blocks[0]]

            for idx, block in enumerate(blocks[1:]):  # Skip first block (pre-tool content)
                # Extract tool name
                name_match = re.search(r'<name>(.*?)</name>', block)
                if not name_match:
                    logger.warning(f"Block {idx} has no <name> tag, skipping")
                    continue

                name = name_match.group(1).strip()
                arguments = {}

                # Extract arguments
                args_match = re.search(r'<arguments>(.*?)</arguments>', block, re.DOTALL)
                if args_match:
                    args_xml = args_match.group(1)
                    # Parse <key>value</key> pairs
                    arg_pairs = re.findall(r'<(\w+)>(.*?)</\1>', args_xml, re.DOTALL)
                    arguments = {k: v.strip() for k, v in arg_pairs}

                tool_calls.append({
                    "id": f"call_{idx}",
                    "name": name,
                    "arguments": arguments
                })

                # For clean content, find what comes AFTER the tool call block
                # Look for the last closing tag (</tool_call> or malformed </xxx>) and keep what's after
                # Split by any closing tag at the END of the tool block
                remaining = block
                # Remove everything up to and including a standalone closing tag
                # Pattern: find </something> that's not followed by more XML
                end_match = re.search(r'</[a-z_]+>\s*(.*)$', remaining, re.DOTALL)
                if end_match:
                    after_content = end_match.group(1).strip()
                    if after_content and not after_content.startswith('<'):
                        # Only keep if it's actual text content, not more XML
                        clean_parts.append(after_content)

            clean_content = ''.join(clean_parts).strip()
        else:
            clean_content = content

        return {
            "content": clean_content,
            "tool_calls": tool_calls if tool_calls else None
        }

    def format_tool_result(
        self,
        tool_call_id: str,
        tool_name: str,
        result: Dict
    ) -> Dict:
        """Format tool result as XML for next prompt.

        Args:
            tool_call_id: ID from the original tool call
            tool_name: Name of the executed tool
            result: Tool execution result

        Returns:
            dict: Message in user role with XML-formatted result
        """
        # Format result as XML
        result_xml = f"""<tool_result>
<tool>{tool_name}</tool>
<result>{json.dumps(result, ensure_ascii=False)}</result>
</tool_result>"""

        return {
            "role": "user",
            "content": result_xml
        }
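
A quick sanity sketch of the XML round trip parse_response performs; the sample model reply is hypothetical, and the import path assumes the repository root is on PYTHONPATH:

import asyncio
from cortex.autonomy.tools.adapters.ollama_adapter import OllamaAdapter

# A reply in the format the SYSTEM_PROMPT above asks for.
sample_reply = (
    "Let me check that for you.\n"
    "<tool_call>\n"
    "<name>search_web</name>\n"
    "<arguments>\n"
    "<query>current weather in Berlin</query>\n"
    "</arguments>\n"
    "<reason>need up-to-date information</reason>\n"
    "</tool_call>"
)

parsed = asyncio.run(OllamaAdapter().parse_response(sample_reply))
# parsed["content"]    == "Let me check that for you."
# parsed["tool_calls"] == [{"id": "call_0", "name": "search_web",
#                           "arguments": {"query": "current weather in Berlin"}}]
# Note: XML argument values always arrive as strings, unlike native JSON tool calls.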

cortex/autonomy/tools/adapters/openai_adapter.py (new file, 130 lines)
@@ -0,0 +1,130 @@
"""
OpenAI adapter for tool calling using native function calling API.

This adapter converts Lyra tool definitions to OpenAI's function calling
format and parses OpenAI responses back to Lyra's standardized format.
"""

import json
from typing import Dict, List, Optional
from .base import ToolAdapter


class OpenAIAdapter(ToolAdapter):
    """OpenAI-specific adapter using native function calling.

    OpenAI supports function calling natively through the 'tools' parameter
    in chat completions. This adapter leverages that capability.
    """

    async def prepare_request(
        self,
        messages: List[Dict],
        tools: List[Dict],
        tool_choice: Optional[str] = None
    ) -> Dict:
        """Convert Lyra tools to OpenAI function calling format.

        Args:
            messages: Conversation history
            tools: Lyra tool definitions
            tool_choice: "auto", "required", "none", or None

        Returns:
            dict: Request payload with OpenAI-formatted tools
        """
        # Convert Lyra tools → OpenAI function calling format
        openai_tools = []
        for tool in tools:
            openai_tools.append({
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": {
                        "type": "object",
                        "properties": tool["parameters"],
                        "required": tool.get("required", [])
                    }
                }
            })

        payload = {
            "messages": messages,
            "tools": openai_tools
        }

        # Add tool_choice if specified
        if tool_choice:
            if tool_choice == "required":
                payload["tool_choice"] = "required"
            elif tool_choice == "none":
                payload["tool_choice"] = "none"
            else:  # "auto" or default
                payload["tool_choice"] = "auto"

        return payload

    async def parse_response(self, response) -> Dict:
        """Extract tool calls from OpenAI response.

        Args:
            response: OpenAI ChatCompletion response object

        Returns:
            dict: Standardized Lyra format with content and tool_calls
        """
        message = response.choices[0].message
        content = message.content if message.content else ""
        tool_calls = []

        # Check if response contains tool calls
        if hasattr(message, 'tool_calls') and message.tool_calls:
            for tc in message.tool_calls:
                try:
                    # Parse arguments (may be JSON string)
                    args = tc.function.arguments
                    if isinstance(args, str):
                        args = json.loads(args)

                    tool_calls.append({
                        "id": tc.id,
                        "name": tc.function.name,
                        "arguments": args
                    })
                except json.JSONDecodeError as e:
                    # If arguments can't be parsed, include error
                    tool_calls.append({
                        "id": tc.id,
                        "name": tc.function.name,
                        "arguments": {},
                        "error": f"Failed to parse arguments: {str(e)}"
                    })

        return {
            "content": content,
            "tool_calls": tool_calls if tool_calls else None
        }

    def format_tool_result(
        self,
        tool_call_id: str,
        tool_name: str,
        result: Dict
    ) -> Dict:
        """Format tool result as OpenAI tool message.

        Args:
            tool_call_id: ID from the original tool call
            tool_name: Name of the executed tool
            result: Tool execution result

        Returns:
            dict: Message in OpenAI tool message format
        """
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": tool_name,
            "content": json.dumps(result, ensure_ascii=False)
        }
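
For comparison, a small sketch of the payload shape prepare_request produces from a Lyra-format definition (tool definition shortened; repository-root import assumed):

import asyncio
from cortex.autonomy.tools.adapters.openai_adapter import OpenAIAdapter

lyra_tool = {
    "name": "search_web",
    "description": "Search the internet using DuckDuckGo.",
    "parameters": {
        "query": {"type": "string", "description": "The search query"}
    },
    "required": ["query"],
}

payload = asyncio.run(OpenAIAdapter().prepare_request(
    messages=[{"role": "user", "content": "Any news on Python 3.12?"}],
    tools=[lyra_tool],
    tool_choice="auto",
))
# payload["tools"][0] == {"type": "function", "function": {
#     "name": "search_web",
#     "description": "Search the internet using DuckDuckGo.",
#     "parameters": {"type": "object",
#                    "properties": {"query": {"type": "string", ...}},
#                    "required": ["query"]}}}
# payload["tool_choice"] == "auto"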

cortex/autonomy/tools/executors/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""Tool executors for Lyra."""

from .code_executor import execute_code
from .web_search import search_web
from .trillium import search_notes, create_note

__all__ = [
    "execute_code",
    "search_web",
    "search_notes",
    "create_note",
]

cortex/autonomy/tools/executors/code_executor.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""
Code executor for running Python and bash code in a sandbox container.

This module provides secure code execution with timeout protection,
output limits, and forbidden pattern detection.
"""

import asyncio
import os
import tempfile
import re
from typing import Dict


# Forbidden patterns that pose security risks
FORBIDDEN_PATTERNS = [
    r'rm\s+-rf',                # Destructive file removal
    r':\(\)\{\s*:\|:&\s*\};:',  # Fork bomb
    r'mkfs',                    # Filesystem formatting
    r'/dev/sd[a-z]',            # Direct device access
    r'dd\s+if=',                # Low-level disk operations
    r'>\s*/dev/sd',             # Writing to devices
    r'curl.*\|.*sh',            # Pipe to shell (common attack vector)
    r'wget.*\|.*sh',            # Pipe to shell
]


async def execute_code(args: Dict) -> Dict:
    """Execute code in sandbox container.

    Args:
        args: Dictionary containing:
            - language (str): "python" or "bash"
            - code (str): The code to execute
            - reason (str): Why this code is being executed
            - timeout (int, optional): Execution timeout in seconds

    Returns:
        dict: Execution result containing:
            - stdout (str): Standard output
            - stderr (str): Standard error
            - exit_code (int): Process exit code
            - execution_time (float): Time taken in seconds
            OR
            - error (str): Error message if execution failed
    """
    language = args.get("language")
    code = args.get("code")
    reason = args.get("reason", "No reason provided")
    timeout = args.get("timeout", 30)

    # Validation
    if not language or language not in ["python", "bash"]:
        return {"error": "Invalid language. Must be 'python' or 'bash'"}

    if not code:
        return {"error": "No code provided"}

    # Security: Check for forbidden patterns
    for pattern in FORBIDDEN_PATTERNS:
        if re.search(pattern, code, re.IGNORECASE):
            return {"error": "Forbidden pattern detected for security reasons"}

    # Validate and cap timeout
    max_timeout = int(os.getenv("CODE_SANDBOX_MAX_TIMEOUT", "120"))
    timeout = min(max(timeout, 1), max_timeout)

    container = os.getenv("CODE_SANDBOX_CONTAINER", "lyra-code-sandbox")

    # Write code to temporary file
    suffix = ".py" if language == "python" else ".sh"
    try:
        with tempfile.NamedTemporaryFile(
            mode='w',
            suffix=suffix,
            delete=False,
            encoding='utf-8'
        ) as f:
            f.write(code)
            temp_file = f.name
    except Exception as e:
        return {"error": f"Failed to create temp file: {str(e)}"}

    try:
        # Copy file to container
        exec_path = f"/executions/{os.path.basename(temp_file)}"

        cp_proc = await asyncio.create_subprocess_exec(
            "docker", "cp", temp_file, f"{container}:{exec_path}",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        await cp_proc.communicate()

        if cp_proc.returncode != 0:
            return {"error": "Failed to copy code to sandbox container"}

        # Fix permissions so sandbox user can read the file (run as root)
        chown_proc = await asyncio.create_subprocess_exec(
            "docker", "exec", "-u", "root", container, "chown", "sandbox:sandbox", exec_path,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        await chown_proc.communicate()

        # Execute in container as sandbox user
        if language == "python":
            cmd = ["docker", "exec", "-u", "sandbox", container, "python3", exec_path]
        else:  # bash
            cmd = ["docker", "exec", "-u", "sandbox", container, "bash", exec_path]

        start_time = asyncio.get_event_loop().time()

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout
            )

            execution_time = asyncio.get_event_loop().time() - start_time

            # Truncate output to prevent memory issues
            max_output = 10 * 1024  # 10KB
            stdout_str = stdout[:max_output].decode('utf-8', errors='replace')
            stderr_str = stderr[:max_output].decode('utf-8', errors='replace')

            if len(stdout) > max_output:
                stdout_str += "\n... (output truncated)"
            if len(stderr) > max_output:
                stderr_str += "\n... (output truncated)"

            return {
                "stdout": stdout_str,
                "stderr": stderr_str,
                "exit_code": proc.returncode,
                "execution_time": round(execution_time, 2)
            }

        except asyncio.TimeoutError:
            # Kill the process
            try:
                proc.kill()
                await proc.wait()
            except Exception:
                pass
            return {"error": f"Execution timeout after {timeout}s"}

    except Exception as e:
        return {"error": f"Execution failed: {str(e)}"}

    finally:
        # Cleanup temporary file
        try:
            os.unlink(temp_file)
        except Exception:
            pass
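
A direct invocation sketch; it assumes docker is on PATH and the lyra-code-sandbox container is already running:

import asyncio
from cortex.autonomy.tools.executors.code_executor import execute_code

result = asyncio.run(execute_code({
    "language": "python",
    "code": "print(2 + 2)",
    "reason": "smoke test of the sandbox",
    "timeout": 10,
}))
# Success: {"stdout": "4\n", "stderr": "", "exit_code": 0, "execution_time": ...}
# Failure paths return {"error": ...} instead (forbidden pattern, timeout, copy failure).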

cortex/autonomy/tools/executors/trillium.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""
Trillium notes executor for searching and creating notes via ETAPI.

This module provides integration with Trillium notes through the ETAPI HTTP API.
"""

import os
import aiohttp
from typing import Dict


TRILLIUM_URL = os.getenv("TRILLIUM_URL", "http://localhost:8080")
TRILLIUM_TOKEN = os.getenv("TRILLIUM_ETAPI_TOKEN", "")


async def search_notes(args: Dict) -> Dict:
    """Search Trillium notes via ETAPI.

    Args:
        args: Dictionary containing:
            - query (str): Search query
            - limit (int, optional): Maximum notes to return (default: 5, max: 20)

    Returns:
        dict: Search results containing:
            - notes (list): List of notes with noteId, title, content, type
            - count (int): Number of notes returned
            OR
            - error (str): Error message if search failed
    """
    query = args.get("query")
    limit = args.get("limit", 5)

    # Validation
    if not query:
        return {"error": "No query provided"}

    if not TRILLIUM_TOKEN:
        return {"error": "TRILLIUM_ETAPI_TOKEN not configured in environment"}

    # Cap limit
    limit = min(max(limit, 1), 20)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{TRILLIUM_URL}/etapi/search-notes",
                params={"search": query, "limit": limit},
                headers={"Authorization": TRILLIUM_TOKEN}
            ) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    return {
                        "notes": data,
                        "count": len(data)
                    }
                elif resp.status == 401:
                    return {"error": "Authentication failed. Check TRILLIUM_ETAPI_TOKEN"}
                else:
                    error_text = await resp.text()
                    return {"error": f"HTTP {resp.status}: {error_text}"}

    except aiohttp.ClientConnectorError:
        return {"error": f"Cannot connect to Trillium at {TRILLIUM_URL}"}
    except Exception as e:
        return {"error": f"Search failed: {str(e)}"}


async def create_note(args: Dict) -> Dict:
    """Create a note in Trillium via ETAPI.

    Args:
        args: Dictionary containing:
            - title (str): Note title
            - content (str): Note content in markdown or HTML
            - parent_note_id (str, optional): Parent note ID to nest under

    Returns:
        dict: Creation result containing:
            - noteId (str): ID of created note
            - title (str): Title of created note
            - success (bool): True if created successfully
            OR
            - error (str): Error message if creation failed
    """
    title = args.get("title")
    content = args.get("content")
    parent_note_id = args.get("parent_note_id")

    # Validation
    if not title:
        return {"error": "No title provided"}

    if not content:
        return {"error": "No content provided"}

    if not TRILLIUM_TOKEN:
        return {"error": "TRILLIUM_ETAPI_TOKEN not configured in environment"}

    # Prepare payload
    payload = {
        "title": title,
        "content": content,
        "type": "text",
        "mime": "text/html"
    }

    if parent_note_id:
        payload["parentNoteId"] = parent_note_id

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{TRILLIUM_URL}/etapi/create-note",
                json=payload,
                headers={"Authorization": TRILLIUM_TOKEN}
            ) as resp:
                if resp.status in [200, 201]:
                    data = await resp.json()
                    return {
                        "noteId": data.get("noteId"),
                        "title": title,
                        "success": True
                    }
                elif resp.status == 401:
                    return {"error": "Authentication failed. Check TRILLIUM_ETAPI_TOKEN"}
                else:
                    error_text = await resp.text()
                    return {"error": f"HTTP {resp.status}: {error_text}"}

    except aiohttp.ClientConnectorError:
        return {"error": f"Cannot connect to Trillium at {TRILLIUM_URL}"}
    except Exception as e:
        return {"error": f"Note creation failed: {str(e)}"}
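
A usage sketch for the two executors; note that TRILLIUM_URL and TRILLIUM_ETAPI_TOKEN are read at import time, so they must be set in the environment before this module is first imported:

import asyncio
from cortex.autonomy.tools.executors.trillium import create_note, search_notes

async def demo():
    # Both executors report failures as {"error": ...} instead of raising.
    created = await create_note({
        "title": "Meeting notes",
        "content": "<p>Discussed the Q3 roadmap.</p>",
    })
    found = await search_notes({"query": "roadmap", "limit": 3})
    return created, found

created, found = asyncio.run(demo())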

cortex/autonomy/tools/executors/web_search.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
Web search executor using DuckDuckGo.

This module provides web search capabilities without requiring API keys.
"""

from typing import Dict
from duckduckgo_search import DDGS


async def search_web(args: Dict) -> Dict:
    """Search the web using DuckDuckGo.

    Args:
        args: Dictionary containing:
            - query (str): The search query
            - max_results (int, optional): Maximum results to return (default: 5, max: 10)

    Returns:
        dict: Search results containing:
            - results (list): List of search results with title, url, snippet
            - count (int): Number of results returned
            OR
            - error (str): Error message if search failed
    """
    query = args.get("query")
    max_results = args.get("max_results", 5)

    # Validation
    if not query:
        return {"error": "No query provided"}

    # Cap max_results
    max_results = min(max(max_results, 1), 10)

    try:
        # DuckDuckGo search is synchronous, but we wrap it for consistency
        with DDGS() as ddgs:
            results = []

            # Perform text search
            for result in ddgs.text(query, max_results=max_results):
                results.append({
                    "title": result.get("title", ""),
                    "url": result.get("href", ""),
                    "snippet": result.get("body", "")
                })

        return {
            "results": results,
            "count": len(results)
        }

    except Exception as e:
        return {"error": f"Search failed: {str(e)}"}

cortex/autonomy/tools/function_caller.py (new file, 235 lines)
@@ -0,0 +1,235 @@
"""
Provider-agnostic function caller with iterative tool calling loop.

This module implements the iterative loop that allows LLMs to call tools
multiple times until they have the information they need to answer the user.
"""

import os
import logging
from typing import Dict, List, Optional
from llm.llm_router import call_llm, TOOL_ADAPTERS, BACKENDS
from .registry import get_registry
from .stream_events import get_stream_manager


logger = logging.getLogger(__name__)


class FunctionCaller:
    """Provider-agnostic iterative tool calling loop.

    This class orchestrates the back-and-forth between the LLM and tools:
    1. Call LLM with tools available
    2. If LLM requests tool calls, execute them
    3. Add results to conversation
    4. Repeat until LLM is done or max iterations reached
    """

    def __init__(self, backend: str, temperature: float = 0.7):
        """Initialize function caller.

        Args:
            backend: LLM backend to use ("OPENAI", "OLLAMA", etc.)
            temperature: Temperature for LLM calls
        """
        self.backend = backend
        self.temperature = temperature
        self.registry = get_registry()
        self.max_iterations = int(os.getenv("MAX_TOOL_ITERATIONS", "5"))

        # Resolve adapter for this backend
        self.adapter = self._get_adapter()

    def _get_adapter(self):
        """Get the appropriate adapter for this backend."""
        adapter = TOOL_ADAPTERS.get(self.backend)

        # For PRIMARY/SECONDARY/FALLBACK, determine adapter based on provider
        if adapter is None and self.backend in ["PRIMARY", "SECONDARY", "FALLBACK"]:
            cfg = BACKENDS.get(self.backend, {})
            provider = cfg.get("provider", "").lower()

            if provider == "openai":
                adapter = TOOL_ADAPTERS["OPENAI"]
            elif provider == "ollama":
                adapter = TOOL_ADAPTERS["OLLAMA"]
            elif provider == "mi50":
                adapter = TOOL_ADAPTERS["MI50"]

        return adapter

    async def call_with_tools(
        self,
        messages: List[Dict],
        max_tokens: int = 2048,
        session_id: Optional[str] = None
    ) -> Dict:
        """Execute LLM with iterative tool calling.

        Args:
            messages: Conversation history
            max_tokens: Maximum tokens for LLM response
            session_id: Optional session ID for streaming events

        Returns:
            dict: {
                "content": str,       # Final response
                "iterations": int,    # Number of iterations
                "tool_calls": list,   # All tool calls made
                "messages": list,     # Full conversation history
                "truncated": bool (optional)  # True if max iterations reached
            }
        """
        logger.info(f"🔍 FunctionCaller.call_with_tools() invoked with {len(messages)} messages")
        tools = self.registry.get_tool_definitions()
        logger.info(f"🔍 Got {len(tools or [])} tool definitions from registry")

        # Get stream manager for emitting events
        stream_manager = get_stream_manager()
        should_stream = session_id and stream_manager.has_subscribers(session_id)

        # If no tools are enabled, just call LLM directly
        if not tools:
            logger.warning("FunctionCaller invoked but no tools are enabled")
            response = await call_llm(
                messages=messages,
                backend=self.backend,
                temperature=self.temperature,
                max_tokens=max_tokens
            )
            return {
                "content": response,
                "iterations": 1,
                "tool_calls": [],
                "messages": messages + [{"role": "assistant", "content": response}]
            }

        conversation = messages.copy()
        all_tool_calls = []

        for iteration in range(self.max_iterations):
            logger.info(f"Tool calling iteration {iteration + 1}/{self.max_iterations}")

            # Emit thinking event
            if should_stream:
                await stream_manager.emit(session_id, "thinking", {
                    "message": f"🤔 Thinking... (iteration {iteration + 1}/{self.max_iterations})"
                })

            # Call LLM with tools
            try:
                response = await call_llm(
                    messages=conversation,
                    backend=self.backend,
                    temperature=self.temperature,
                    max_tokens=max_tokens,
                    tools=tools,
                    tool_choice="auto",
                    return_adapter_response=True
                )
            except Exception as e:
                logger.error(f"LLM call failed: {str(e)}")
                if should_stream:
                    await stream_manager.emit(session_id, "error", {
                        "message": f"❌ Error: {str(e)}"
                    })
                return {
                    "content": f"Error calling LLM: {str(e)}",
                    "iterations": iteration + 1,
                    "tool_calls": all_tool_calls,
                    "messages": conversation,
                    "error": True
                }

            # Add assistant message to conversation
            if response.get("content"):
                conversation.append({
                    "role": "assistant",
                    "content": response["content"]
                })

            # Check for tool calls
            tool_calls = response.get("tool_calls")
            logger.debug(f"Response from LLM: content_length={len(response.get('content', ''))}, tool_calls={tool_calls}")
            if not tool_calls:
                # No more tool calls - LLM is done
                logger.info(f"Tool calling complete after {iteration + 1} iterations")
                if should_stream:
                    await stream_manager.emit(session_id, "done", {
                        "message": "✅ Complete!",
                        "final_answer": response["content"]
                    })
                return {
                    "content": response["content"],
                    "iterations": iteration + 1,
                    "tool_calls": all_tool_calls,
                    "messages": conversation
                }

            # Execute each tool call
            logger.info(f"Executing {len(tool_calls)} tool call(s)")
            for tool_call in tool_calls:
                all_tool_calls.append(tool_call)

                tool_name = tool_call.get("name")
                tool_args = tool_call.get("arguments", {})
                tool_id = tool_call.get("id", "unknown")

                logger.info(f"Calling tool: {tool_name} with args: {tool_args}")

                # Emit tool call event
                if should_stream:
                    await stream_manager.emit(session_id, "tool_call", {
                        "tool": tool_name,
                        "args": tool_args,
                        "message": f"🔧 Using tool: {tool_name}"
                    })

                try:
                    # Execute tool
                    result = await self.registry.execute_tool(tool_name, tool_args)
                    logger.info(f"Tool {tool_name} executed successfully")

                    # Emit tool result event
                    if should_stream:
                        # Format result preview
                        result_preview = str(result)
                        if len(result_preview) > 200:
                            result_preview = result_preview[:200] + "..."

                        await stream_manager.emit(session_id, "tool_result", {
                            "tool": tool_name,
                            "result": result,
                            "message": f"📊 Result: {result_preview}"
                        })

                except Exception as e:
                    logger.error(f"Tool {tool_name} execution failed: {str(e)}")
                    result = {"error": f"Tool execution failed: {str(e)}"}

                # Format result using adapter
                if not self.adapter:
                    logger.warning(f"No adapter available for backend {self.backend}, using fallback format")
                    result_msg = {
                        "role": "user",
                        "content": f"Tool {tool_name} result: {result}"
                    }
                else:
                    result_msg = self.adapter.format_tool_result(
                        tool_id,
                        tool_name,
                        result
                    )

                conversation.append(result_msg)

        # Max iterations reached without completion
        logger.warning(f"Tool calling truncated after {self.max_iterations} iterations")
        return {
            "content": response.get("content", ""),
            "iterations": self.max_iterations,
            "tool_calls": all_tool_calls,
            "messages": conversation,
            "truncated": True
        }
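
Putting the pieces together, a minimal driving sketch; the backend name is illustrative, and which adapter it resolves to depends on the llm_router configuration imported above:

import asyncio
from cortex.autonomy.tools.function_caller import FunctionCaller

caller = FunctionCaller(backend="OPENAI", temperature=0.3)
outcome = asyncio.run(caller.call_with_tools(
    messages=[{"role": "user", "content": "What's 37 * 41? Verify with code."}],
    max_tokens=1024,
))
# outcome["content"] holds the final answer; outcome["tool_calls"] lists every
# call made; outcome.get("truncated") flags hitting MAX_TOOL_ITERATIONS.
print(outcome["content"])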

cortex/autonomy/tools/registry.py (new file, 196 lines)
@@ -0,0 +1,196 @@
"""
Provider-agnostic Tool Registry for Lyra.

This module provides a central registry for all available tools with
Lyra-native definitions (not provider-specific).
"""

import os
from typing import Dict, List, Optional
from .executors import execute_code, search_web, search_notes, create_note


class ToolRegistry:
    """Registry for managing available tools and their definitions.

    Tools are defined in Lyra's own format (provider-agnostic), and
    adapters convert them to provider-specific formats (OpenAI function
    calling, Ollama XML prompts, etc.).
    """

    def __init__(self):
        """Initialize the tool registry with feature flags from environment."""
        self.tools = {}
        self.executors = {}

        # Feature flags from environment
        self.code_execution_enabled = os.getenv("ENABLE_CODE_EXECUTION", "true").lower() == "true"
        self.web_search_enabled = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true"
        self.trillium_enabled = os.getenv("ENABLE_TRILLIUM", "false").lower() == "true"

        self._register_tools()
        self._register_executors()

    def _register_executors(self):
        """Register executor functions for each tool."""
        if self.code_execution_enabled:
            self.executors["execute_code"] = execute_code

        if self.web_search_enabled:
            self.executors["search_web"] = search_web

        if self.trillium_enabled:
            self.executors["search_notes"] = search_notes
            self.executors["create_note"] = create_note

    def _register_tools(self):
        """Register all available tools based on feature flags."""

        if self.code_execution_enabled:
            self.tools["execute_code"] = {
                "name": "execute_code",
                "description": "Execute Python or bash code in a secure sandbox environment. Use this to perform calculations, data processing, file operations, or any programmatic tasks. The sandbox is persistent across calls within a session and has common Python packages (numpy, pandas, requests, matplotlib, scipy) pre-installed.",
                "parameters": {
                    "language": {
                        "type": "string",
                        "enum": ["python", "bash"],
                        "description": "The programming language to execute (python or bash)"
                    },
                    "code": {
                        "type": "string",
                        "description": "The code to execute. For multi-line code, use proper indentation. For Python, use standard Python 3.11 syntax."
                    },
                    "reason": {
                        "type": "string",
                        "description": "Brief explanation of why you're executing this code and what you expect to achieve"
                    }
                },
                "required": ["language", "code", "reason"]
            }

        if self.web_search_enabled:
            self.tools["search_web"] = {
                "name": "search_web",
                "description": "Search the internet using DuckDuckGo to find current information, facts, news, or answers to questions. Returns a list of search results with titles, snippets, and URLs. Use this when you need up-to-date information or facts not in your training data.",
                "parameters": {
                    "query": {
                        "type": "string",
                        "description": "The search query to look up on the internet"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results to return (default: 5, max: 10)"
                    }
                },
                "required": ["query"]
            }

        if self.trillium_enabled:
            self.tools["search_notes"] = {
                "name": "search_notes",
                "description": "Search through Trillium notes to find relevant information. Use this to retrieve knowledge, context, or information previously stored in the user's notes.",
                "parameters": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find matching notes"
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of notes to return (default: 5, max: 20)"
                    }
                },
                "required": ["query"]
            }

            self.tools["create_note"] = {
                "name": "create_note",
                "description": "Create a new note in Trillium. Use this to store important information, insights, or knowledge for future reference. Notes are stored in the user's Trillium knowledge base.",
                "parameters": {
                    "title": {
                        "type": "string",
                        "description": "The title of the note"
                    },
                    "content": {
                        "type": "string",
                        "description": "The content of the note in markdown or HTML format"
                    },
                    "parent_note_id": {
                        "type": "string",
                        "description": "Optional ID of the parent note to nest this note under"
                    }
                },
                "required": ["title", "content"]
            }

    def get_tool_definitions(self) -> Optional[List[Dict]]:
        """Get list of all enabled tool definitions in Lyra format.

        Returns:
            list: List of tool definition dicts, or None if no tools enabled
        """
        if not self.tools:
            return None
        return list(self.tools.values())

    def get_tool_names(self) -> List[str]:
        """Get list of all enabled tool names.

        Returns:
            list: List of tool name strings
        """
        return list(self.tools.keys())

    def is_tool_enabled(self, tool_name: str) -> bool:
        """Check if a specific tool is enabled.

        Args:
            tool_name: Name of the tool to check

        Returns:
            bool: True if tool is enabled, False otherwise
        """
        return tool_name in self.tools

    def register_executor(self, tool_name: str, executor_func):
        """Register an executor function for a tool.

        Args:
            tool_name: Name of the tool
            executor_func: Async function that executes the tool
        """
        self.executors[tool_name] = executor_func

    async def execute_tool(self, name: str, arguments: dict) -> dict:
        """Execute a tool by name.

        Args:
            name: Tool name
            arguments: Tool arguments dict

        Returns:
            dict: Tool execution result
        """
        if name not in self.executors:
            return {"error": f"Unknown tool: {name}"}

        executor = self.executors[name]
        try:
            return await executor(arguments)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}


# Global registry instance (singleton pattern)
_registry = None


def get_registry() -> ToolRegistry:
    """Get the global ToolRegistry instance.

    Returns:
        ToolRegistry: The global registry instance
    """
    global _registry
    if _registry is None:
        _registry = ToolRegistry()
    return _registry
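
Runtime extension sketch: register_executor covers the executor side, but this commit ships no matching public helper for the definition side, so adding a custom tool means writing to registry.tools directly (shown for illustration only; get_time is a hypothetical tool):

from datetime import datetime, timezone
from cortex.autonomy.tools.registry import get_registry

async def get_time(args: dict) -> dict:
    # Executors are async and return plain dicts.
    return {"utc": datetime.now(timezone.utc).isoformat()}

registry = get_registry()
registry.register_executor("get_time", get_time)
registry.tools["get_time"] = {  # no public setter for definitions in this commit
    "name": "get_time",
    "description": "Return the current UTC time.",
    "parameters": {},
    "required": [],
}
assert registry.is_tool_enabled("get_time")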

cortex/autonomy/tools/stream_events.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""
Event streaming for tool calling "show your work" feature.

This module manages Server-Sent Events (SSE) for broadcasting the internal
thinking process during tool calling operations.
"""

import asyncio
from typing import Dict, Optional
from collections import defaultdict
import json
import logging

logger = logging.getLogger(__name__)


class ToolStreamManager:
    """Manages SSE streams for tool calling events."""

    def __init__(self):
        # session_id -> list of queues (one per connected client)
        self._subscribers: Dict[str, list] = defaultdict(list)

    def subscribe(self, session_id: str) -> asyncio.Queue:
        """Subscribe to events for a session.

        Returns:
            Queue that will receive events for this session
        """
        queue = asyncio.Queue()
        self._subscribers[session_id].append(queue)
        logger.info(f"New subscriber for session {session_id}, total: {len(self._subscribers[session_id])}")
        return queue

    def unsubscribe(self, session_id: str, queue: asyncio.Queue):
        """Unsubscribe from events for a session."""
        if session_id in self._subscribers:
            try:
                self._subscribers[session_id].remove(queue)
                logger.info(f"Removed subscriber for session {session_id}, remaining: {len(self._subscribers[session_id])}")

                # Clean up empty lists
                if not self._subscribers[session_id]:
                    del self._subscribers[session_id]
            except ValueError:
                pass

    async def emit(self, session_id: str, event_type: str, data: dict):
        """Emit an event to all subscribers of a session.

        Args:
            session_id: Session to emit to
            event_type: Type of event (thinking, tool_call, tool_result, done)
            data: Event data
        """
        if session_id not in self._subscribers:
            return

        event = {
            "type": event_type,
            "data": data
        }

        # Send to all subscribers
        dead_queues = []
        for queue in self._subscribers[session_id]:
            try:
                await queue.put(event)
            except Exception as e:
                logger.error(f"Failed to emit event to queue: {e}")
                dead_queues.append(queue)

        # Clean up dead queues
        for queue in dead_queues:
            self.unsubscribe(session_id, queue)

    def has_subscribers(self, session_id: str) -> bool:
        """Check if a session has any active subscribers."""
        return session_id in self._subscribers and len(self._subscribers[session_id]) > 0


# Global stream manager instance
_stream_manager: Optional[ToolStreamManager] = None


def get_stream_manager() -> ToolStreamManager:
    """Get the global stream manager instance."""
    global _stream_manager
    if _stream_manager is None:
        _stream_manager = ToolStreamManager()
    return _stream_manager
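
On the consuming end, a subscriber drains its queue until a terminal event arrives. A rough sketch of how an SSE endpoint generator might use the manager (the web-framework wiring is assumed, not part of this commit):

import json
from cortex.autonomy.tools.stream_events import get_stream_manager

async def event_stream(session_id: str):
    manager = get_stream_manager()
    queue = manager.subscribe(session_id)
    try:
        while True:
            event = await queue.get()  # {"type": ..., "data": {...}}
            yield f"data: {json.dumps(event)}\n\n"
            if event["type"] in ("done", "error"):
                break
    finally:
        manager.unsubscribe(session_id, queue)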