tool improvement

This commit is contained in:
serversdwn
2025-12-31 22:36:24 -05:00
parent 6716245a99
commit b700ac3808
10 changed files with 598 additions and 80 deletions

View File

@@ -10,6 +10,14 @@ import os
import tempfile
import re
from typing import Dict
import docker
from docker.errors import (
DockerException,
APIError,
ContainerError,
ImageNotFound,
NotFound
)
# Forbidden patterns that pose security risks
@@ -67,6 +75,27 @@ async def execute_code(args: Dict) -> Dict:
container = os.getenv("CODE_SANDBOX_CONTAINER", "lyra-code-sandbox")
# Validate container exists and is running
try:
docker_client = docker.from_env()
container_obj = docker_client.containers.get(container)
if container_obj.status != "running":
return {
"error": f"Sandbox container '{container}' is not running (status: {container_obj.status})",
"hint": "Start the container with: docker start " + container
}
except NotFound:
return {
"error": f"Sandbox container '{container}' not found",
"hint": "Ensure the container exists and is running"
}
except DockerException as e:
return {
"error": f"Docker daemon error: {str(e)}",
"hint": "Check Docker connectivity and permissions"
}
# Write code to temporary file
suffix = ".py" if language == "python" else ".sh"
try:
@@ -125,15 +154,15 @@ async def execute_code(args: Dict) -> Dict:
execution_time = asyncio.get_event_loop().time() - start_time
# Truncate output to prevent memory issues
max_output = 10 * 1024 # 10KB
# Truncate output to prevent memory issues (configurable)
max_output = int(os.getenv("CODE_SANDBOX_MAX_OUTPUT", "10240")) # 10KB default
stdout_str = stdout[:max_output].decode('utf-8', errors='replace')
stderr_str = stderr[:max_output].decode('utf-8', errors='replace')
if len(stdout) > max_output:
stdout_str += "\n... (output truncated)"
stdout_str += f"\n... (output truncated, {len(stdout)} bytes total)"
if len(stderr) > max_output:
stderr_str += "\n... (output truncated)"
stderr_str += f"\n... (output truncated, {len(stderr)} bytes total)"
return {
"stdout": stdout_str,
@@ -151,12 +180,39 @@ async def execute_code(args: Dict) -> Dict:
pass
return {"error": f"Execution timeout after {timeout}s"}
except APIError as e:
return {
"error": f"Docker API error: {e.explanation}",
"status_code": e.status_code
}
except ContainerError as e:
return {
"error": f"Container execution error: {str(e)}",
"exit_code": e.exit_status
}
except DockerException as e:
return {
"error": f"Docker error: {str(e)}",
"hint": "Check Docker daemon connectivity and permissions"
}
except Exception as e:
return {"error": f"Execution failed: {str(e)}"}
finally:
# Cleanup temporary file
try:
os.unlink(temp_file)
except:
if 'temp_file' in locals():
os.unlink(temp_file)
except Exception as cleanup_error:
# Log but don't fail on cleanup errors
pass
# Optional: Clean up file from container (best effort)
try:
if 'exec_path' in locals() and 'container_obj' in locals():
container_obj.exec_run(
f"rm -f {exec_path}",
user="sandbox"
)
except:
pass # Best effort cleanup

View File

@@ -0,0 +1,13 @@
"""Web search provider implementations."""
from .base import SearchProvider, SearchResult, SearchResponse
from .brave import BraveSearchProvider
from .duckduckgo import DuckDuckGoProvider
__all__ = [
"SearchProvider",
"SearchResult",
"SearchResponse",
"BraveSearchProvider",
"DuckDuckGoProvider",
]

View File

@@ -0,0 +1,49 @@
"""Base interface for web search providers."""
from abc import ABC, abstractmethod
from typing import List, Optional
from dataclasses import dataclass
@dataclass
class SearchResult:
    """Standardized search result format.

    Fields are normalized across providers so callers never need
    provider-specific parsing.
    """
    title: str  # human-readable result title
    url: str  # link to the result page
    snippet: str  # short excerpt/description of the page
    # Relevance score when the provider supplies one (e.g. Brave); otherwise None.
    score: Optional[float] = None
@dataclass
class SearchResponse:
    """Standardized search response.

    Either ``results`` is populated (success) or ``error`` explains the
    failure; ``provider`` records which backend produced the response.
    """
    results: List[SearchResult]  # individual hits; may be empty
    count: int  # number of entries in results
    provider: str  # name of the provider that produced this response
    query: str  # the query string that was searched
    error: Optional[str] = None  # human-readable failure reason; None on success
class SearchProvider(ABC):
    """Abstract base class for search providers.

    Concrete implementations must supply an async ``search``, an async
    ``health_check``, and a ``name`` property.  All providers return the
    standardized :class:`SearchResponse`, so callers can swap or chain
    providers without provider-specific handling.
    """

    @abstractmethod
    async def search(
        self,
        query: str,
        max_results: int = 5,
        **kwargs
    ) -> SearchResponse:
        """Execute search and return standardized results.

        Args:
            query: Free-text search query.
            max_results: Upper bound on the number of results requested.
            **kwargs: Provider-specific options, if any.

        Returns:
            SearchResponse: Results on success; ``error`` set on failure.
        """
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """Check if provider is healthy and reachable.

        Returns:
            bool: True when the provider can currently serve queries.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name."""
        pass

View File

@@ -0,0 +1,123 @@
"""Brave Search API provider implementation."""
import os
import asyncio
import aiohttp
from .base import SearchProvider, SearchResponse, SearchResult
from ..utils.resilience import async_retry
class BraveSearchProvider(SearchProvider):
    """Brave Search API implementation.

    Configuration is read from the environment:
      - BRAVE_SEARCH_API_KEY: subscription token (required)
      - BRAVE_SEARCH_URL: API base URL (default: official endpoint)
      - BRAVE_SEARCH_TIMEOUT: total request timeout in seconds (default 10.0)
    """

    def __init__(self):
        self.api_key = os.getenv("BRAVE_SEARCH_API_KEY", "")
        self.base_url = os.getenv(
            "BRAVE_SEARCH_URL",
            "https://api.search.brave.com/res/v1"
        )
        self.timeout = float(os.getenv("BRAVE_SEARCH_TIMEOUT", "10.0"))

    @property
    def name(self) -> str:
        return "brave"

    async def search(
        self,
        query: str,
        max_results: int = 5,
        **kwargs
    ) -> SearchResponse:
        """Execute Brave search with retry logic.

        Args:
            query: Free-text search query.
            max_results: Desired result count (capped at Brave's max of 20).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            SearchResponse: Results on success, or an error description
            after retries are exhausted.
        """
        if not self.api_key:
            return SearchResponse(
                results=[],
                count=0,
                provider=self.name,
                query=query,
                error="BRAVE_SEARCH_API_KEY not configured"
            )

        # BUGFIX: the retry decorator used to wrap this method, but the
        # method itself caught ClientConnectorError/TimeoutError and returned
        # an error response — so those transient failures were never retried.
        # The retryable network call now lives in _request(), which raises;
        # we only translate to an error response once retries are exhausted.
        try:
            return await self._request(query, max_results)
        except aiohttp.ClientConnectorError as e:
            return SearchResponse(
                results=[],
                count=0,
                provider=self.name,
                query=query,
                error=f"Cannot connect to Brave Search API: {str(e)}"
            )
        except asyncio.TimeoutError:
            return SearchResponse(
                results=[],
                count=0,
                provider=self.name,
                query=query,
                error=f"Search timeout after {self.timeout}s"
            )

    @async_retry(
        max_attempts=3,
        exceptions=(aiohttp.ClientError, asyncio.TimeoutError)
    )
    async def _request(self, query: str, max_results: int) -> SearchResponse:
        """Perform one HTTP round-trip to Brave; raises on network failure.

        Raising (rather than returning an error response) lets the
        @async_retry decorator actually observe and retry transient errors.
        """
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key
        }
        params = {
            "q": query,
            "count": min(max_results, 20)  # Brave max is 20
        }
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.base_url}/web/search",
                headers=headers,
                params=params,
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            ) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    results = []
                    for item in data.get("web", {}).get("results", []):
                        results.append(SearchResult(
                            title=item.get("title", ""),
                            url=item.get("url", ""),
                            snippet=item.get("description", ""),
                            score=item.get("score")
                        ))
                    return SearchResponse(
                        results=results,
                        count=len(results),
                        provider=self.name,
                        query=query
                    )
                elif resp.status == 401:
                    error = "Authentication failed. Check BRAVE_SEARCH_API_KEY"
                elif resp.status == 429:
                    error = f"Rate limit exceeded. Status: {resp.status}"
                else:
                    error_text = await resp.text()
                    error = f"HTTP {resp.status}: {error_text}"
                return SearchResponse(
                    results=[],
                    count=0,
                    provider=self.name,
                    query=query,
                    error=error
                )

    async def health_check(self) -> bool:
        """Check if Brave API is reachable (issues a one-result probe query)."""
        if not self.api_key:
            return False
        try:
            response = await self.search("test", max_results=1)
            return response.error is None
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            return False

View File

@@ -0,0 +1,60 @@
"""DuckDuckGo search provider with retry logic (legacy fallback)."""
from duckduckgo_search import DDGS
from .base import SearchProvider, SearchResponse, SearchResult
from ..utils.resilience import async_retry
class DuckDuckGoProvider(SearchProvider):
    """DuckDuckGo search implementation with retry logic."""

    @property
    def name(self) -> str:
        return "duckduckgo"

    async def search(
        self,
        query: str,
        max_results: int = 5,
        **kwargs
    ) -> SearchResponse:
        """Execute DuckDuckGo search with retry logic.

        Args:
            query: Free-text search query.
            max_results: Maximum number of results to request.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            SearchResponse: Results on success, or an error description
            after retries are exhausted.
        """
        # BUGFIX: the retry decorator used to wrap this method, but the
        # method's own `except Exception` swallowed every failure and
        # returned an error response — so no retry ever happened.  The
        # retryable call now lives in _run_query(), which raises; we only
        # translate to an error response once retries are exhausted.
        try:
            return await self._run_query(query, max_results)
        except Exception as e:
            return SearchResponse(
                results=[],
                count=0,
                provider=self.name,
                query=query,
                error=f"Search failed: {str(e)}"
            )

    @async_retry(
        max_attempts=3,
        exceptions=(Exception,)  # DDG throws generic exceptions
    )
    async def _run_query(self, query: str, max_results: int) -> SearchResponse:
        """Perform one DDG search; raises on failure so retries can fire.

        NOTE(review): DDGS.text is a blocking (synchronous) call executed on
        the event loop; consider offloading via asyncio.to_thread if this
        ever runs under load — confirm minimum Python version first.
        """
        with DDGS() as ddgs:
            results = []
            for result in ddgs.text(query, max_results=max_results):
                results.append(SearchResult(
                    title=result.get("title", ""),
                    url=result.get("href", ""),
                    snippet=result.get("body", "")
                ))
        return SearchResponse(
            results=results,
            count=len(results),
            provider=self.name,
            query=query
        )

    async def health_check(self) -> bool:
        """Basic health check for DDG (issues a one-result probe query)."""
        try:
            response = await self.search("test", max_results=1)
            return response.error is None
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            return False

View File

@@ -1,20 +1,42 @@
"""
Trilium notes executor for searching and creating notes via ETAPI.
This module provides integration with Trilium notes through the ETAPI HTTP API.
This module provides integration with Trilium notes through the ETAPI HTTP API
with improved resilience: timeout configuration, retry logic, and connection pooling.
"""
import os
import asyncio
import aiohttp
from typing import Dict
from typing import Dict, Optional
from ..utils.resilience import async_retry
TRILIUM_URL = os.getenv("TRILIUM_URL", "http://localhost:8080")
TRILIUM_TOKEN = os.getenv("TRILIUM_ETAPI_TOKEN", "")
# Module-level session for connection pooling
_session: Optional[aiohttp.ClientSession] = None
def get_session() -> aiohttp.ClientSession:
"""Get or create shared aiohttp session for connection pooling."""
global _session
if _session is None or _session.closed:
timeout = aiohttp.ClientTimeout(
total=float(os.getenv("TRILIUM_TIMEOUT", "30.0")),
connect=float(os.getenv("TRILIUM_CONNECT_TIMEOUT", "10.0"))
)
_session = aiohttp.ClientSession(timeout=timeout)
return _session
@async_retry(
max_attempts=3,
exceptions=(aiohttp.ClientError, asyncio.TimeoutError)
)
async def search_notes(args: Dict) -> Dict:
"""Search Trilium notes via ETAPI.
"""Search Trilium notes via ETAPI with retry logic.
Args:
args: Dictionary containing:
@@ -36,40 +58,72 @@ async def search_notes(args: Dict) -> Dict:
return {"error": "No query provided"}
if not TRILIUM_TOKEN:
return {"error": "TRILIUM_ETAPI_TOKEN not configured in environment"}
return {
"error": "TRILIUM_ETAPI_TOKEN not configured in environment",
"hint": "Set TRILIUM_ETAPI_TOKEN in .env file"
}
# Cap limit
limit = min(max(limit, 1), 20)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{TRILIUM_URL}/etapi/notes",
params={"search": query, "limit": limit},
headers={"Authorization": TRILIUM_TOKEN}
) as resp:
if resp.status == 200:
data = await resp.json()
# ETAPI returns {"results": [...]} format
results = data.get("results", [])
return {
"notes": results,
"count": len(results)
}
elif resp.status == 401:
return {"error": "Authentication failed. Check TRILIUM_ETAPI_TOKEN"}
else:
error_text = await resp.text()
return {"error": f"HTTP {resp.status}: {error_text}"}
session = get_session()
async with session.get(
f"{TRILIUM_URL}/etapi/notes",
params={"search": query, "limit": limit},
headers={"Authorization": TRILIUM_TOKEN}
) as resp:
if resp.status == 200:
data = await resp.json()
# ETAPI returns {"results": [...]} format
results = data.get("results", [])
return {
"notes": results,
"count": len(results)
}
elif resp.status == 401:
return {
"error": "Authentication failed. Check TRILIUM_ETAPI_TOKEN",
"status": 401
}
elif resp.status == 404:
return {
"error": "Trilium API endpoint not found. Check TRILIUM_URL",
"status": 404,
"url": TRILIUM_URL
}
else:
error_text = await resp.text()
return {
"error": f"HTTP {resp.status}: {error_text}",
"status": resp.status
}
except aiohttp.ClientConnectorError:
return {"error": f"Cannot connect to Trilium at {TRILIUM_URL}"}
except aiohttp.ClientConnectorError as e:
return {
"error": f"Cannot connect to Trilium at {TRILIUM_URL}",
"hint": "Check if Trilium is running and URL is correct",
"details": str(e)
}
except asyncio.TimeoutError:
timeout = os.getenv("TRILIUM_TIMEOUT", "30.0")
return {
"error": f"Trilium request timeout after {timeout}s",
"hint": "Trilium may be slow or unresponsive"
}
except Exception as e:
return {"error": f"Search failed: {str(e)}"}
return {
"error": f"Search failed: {str(e)}",
"type": type(e).__name__
}
@async_retry(
max_attempts=3,
exceptions=(aiohttp.ClientError, asyncio.TimeoutError)
)
async def create_note(args: Dict) -> Dict:
"""Create a note in Trilium via ETAPI.
"""Create a note in Trilium via ETAPI with retry logic.
Args:
args: Dictionary containing:
@@ -97,7 +151,10 @@ async def create_note(args: Dict) -> Dict:
return {"error": "No content provided"}
if not TRILIUM_TOKEN:
return {"error": "TRILIUM_ETAPI_TOKEN not configured in environment"}
return {
"error": "TRILIUM_ETAPI_TOKEN not configured in environment",
"hint": "Set TRILIUM_ETAPI_TOKEN in .env file"
}
# Prepare payload
payload = {
@@ -109,26 +166,51 @@ async def create_note(args: Dict) -> Dict:
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{TRILIUM_URL}/etapi/create-note",
json=payload,
headers={"Authorization": TRILIUM_TOKEN}
) as resp:
if resp.status in [200, 201]:
data = await resp.json()
return {
"noteId": data.get("noteId"),
"title": title,
"success": True
}
elif resp.status == 401:
return {"error": "Authentication failed. Check TRILIUM_ETAPI_TOKEN"}
else:
error_text = await resp.text()
return {"error": f"HTTP {resp.status}: {error_text}"}
session = get_session()
async with session.post(
f"{TRILIUM_URL}/etapi/create-note",
json=payload,
headers={"Authorization": TRILIUM_TOKEN}
) as resp:
if resp.status in [200, 201]:
data = await resp.json()
return {
"noteId": data.get("noteId"),
"title": title,
"success": True
}
elif resp.status == 401:
return {
"error": "Authentication failed. Check TRILIUM_ETAPI_TOKEN",
"status": 401
}
elif resp.status == 404:
return {
"error": "Trilium API endpoint not found. Check TRILIUM_URL",
"status": 404,
"url": TRILIUM_URL
}
else:
error_text = await resp.text()
return {
"error": f"HTTP {resp.status}: {error_text}",
"status": resp.status
}
except aiohttp.ClientConnectorError:
return {"error": f"Cannot connect to Trilium at {TRILIUM_URL}"}
except aiohttp.ClientConnectorError as e:
return {
"error": f"Cannot connect to Trilium at {TRILIUM_URL}",
"hint": "Check if Trilium is running and URL is correct",
"details": str(e)
}
except asyncio.TimeoutError:
timeout = os.getenv("TRILIUM_TIMEOUT", "30.0")
return {
"error": f"Trilium request timeout after {timeout}s",
"hint": "Trilium may be slow or unresponsive"
}
except Exception as e:
return {"error": f"Note creation failed: {str(e)}"}
return {
"error": f"Note creation failed: {str(e)}",
"type": type(e).__name__
}

View File

@@ -1,55 +1,113 @@
"""
Web search executor using DuckDuckGo.
Web search executor with pluggable provider support.
This module provides web search capabilities without requiring API keys.
Supports multiple providers with automatic fallback:
- Brave Search API (recommended, configurable)
- DuckDuckGo (legacy fallback)
"""
from typing import Dict
from duckduckgo_search import DDGS
import os
from typing import Dict, Optional
from .search_providers.base import SearchProvider
from .search_providers.brave import BraveSearchProvider
from .search_providers.duckduckgo import DuckDuckGoProvider
# Provider registry
PROVIDERS = {
"brave": BraveSearchProvider,
"duckduckgo": DuckDuckGoProvider,
}
# Singleton provider instances
_provider_instances: Dict[str, SearchProvider] = {}
def get_provider(name: str) -> Optional[SearchProvider]:
    """Return the cached provider instance for *name*, creating it lazily.

    Unknown provider names yield ``None``.
    """
    cached = _provider_instances.get(name)
    if cached is not None:
        return cached
    factory = PROVIDERS.get(name)
    if factory is None:
        return None
    instance = factory()
    _provider_instances[name] = instance
    return instance
async def search_web(args: Dict) -> Dict:
"""Search the web using DuckDuckGo.
"""Search the web using configured provider with automatic fallback.
Args:
args: Dictionary containing:
- query (str): The search query
- max_results (int, optional): Maximum results to return (default: 5, max: 10)
- max_results (int, optional): Maximum results to return (default: 5, max: 20)
- provider (str, optional): Force specific provider
Returns:
dict: Search results containing:
- results (list): List of search results with title, url, snippet
- count (int): Number of results returned
- provider (str): Provider that returned results
OR
- error (str): Error message if search failed
- error (str): Error message if all providers failed
"""
query = args.get("query")
max_results = args.get("max_results", 5)
forced_provider = args.get("provider")
# Validation
if not query:
return {"error": "No query provided"}
# Cap max_results
max_results = min(max(max_results, 1), 10)
max_results = min(max(max_results, 1), 20)
try:
# DuckDuckGo search is synchronous, but we wrap it for consistency
with DDGS() as ddgs:
results = []
# Get provider preference from environment
primary_provider = os.getenv("WEB_SEARCH_PROVIDER", "duckduckgo")
fallback_providers = os.getenv(
"WEB_SEARCH_FALLBACK",
"duckduckgo"
).split(",")
# Perform text search
for result in ddgs.text(query, max_results=max_results):
results.append({
"title": result.get("title", ""),
"url": result.get("href", ""),
"snippet": result.get("body", "")
})
# Build provider list
if forced_provider:
providers_to_try = [forced_provider]
else:
providers_to_try = [primary_provider] + [
p.strip() for p in fallback_providers if p.strip() != primary_provider
]
return {
"results": results,
"count": len(results)
}
# Try providers in order
last_error = None
for provider_name in providers_to_try:
provider = get_provider(provider_name)
if not provider:
last_error = f"Unknown provider: {provider_name}"
continue
except Exception as e:
return {"error": f"Search failed: {str(e)}"}
try:
response = await provider.search(query, max_results)
# If successful, return results
if response.error is None and response.count > 0:
return {
"results": [
{
"title": r.title,
"url": r.url,
"snippet": r.snippet,
}
for r in response.results
],
"count": response.count,
"provider": provider_name
}
last_error = response.error or "No results returned"
except Exception as e:
last_error = f"{provider_name} failed: {str(e)}"
continue
# All providers failed
return {
"error": f"All search providers failed. Last error: {last_error}",
"providers_tried": providers_to_try
}

View File

@@ -0,0 +1,5 @@
"""Utility modules for tool executors."""
from .resilience import async_retry, async_timeout_wrapper
__all__ = ["async_retry", "async_timeout_wrapper"]

View File

@@ -0,0 +1,70 @@
"""Common resilience utilities for tool executors."""
import asyncio
import functools
import logging
from typing import Optional, Callable, Any, TypeVar
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log
)
logger = logging.getLogger(__name__)
# Type variable for generic decorators
T = TypeVar('T')
def async_retry(
    max_attempts: int = 3,
    exceptions: tuple = (Exception,),
    **kwargs
):
    """Build an async-compatible retry decorator with exponential backoff.

    Args:
        max_attempts: Maximum retry attempts.
        exceptions: Exception types to retry on.
        **kwargs: Additional tenacity configuration.

    Example:
        @async_retry(max_attempts=3, exceptions=(aiohttp.ClientError,))
        async def fetch_data():
            ...
    """
    # Exponential backoff: 1s minimum, capped at 10s between attempts.
    backoff = wait_exponential(multiplier=1, min=1, max=10)
    decorator = retry(
        stop=stop_after_attempt(max_attempts),
        wait=backoff,
        retry=retry_if_exception_type(exceptions),
        reraise=True,  # surface the final exception to the caller
        before_sleep=before_sleep_log(logger, logging.WARNING),
        **kwargs,
    )
    return decorator
async def async_timeout_wrapper(
    coro: Callable[..., T],
    timeout: float,
    *args,
    **kwargs
) -> T:
    """Invoke an async callable with its arguments under a deadline.

    Args:
        coro: Async function to wrap.
        timeout: Timeout in seconds.
        *args, **kwargs: Arguments forwarded to the function.

    Returns:
        Whatever the wrapped function returns.

    Raises:
        asyncio.TimeoutError: If timeout exceeded.

    Example:
        result = await async_timeout_wrapper(some_async_func, 5.0, arg1, arg2)
    """
    pending = coro(*args, **kwargs)
    return await asyncio.wait_for(pending, timeout=timeout)