44 lines
2.1 KiB
Python
44 lines
2.1 KiB
Python
# cortex/neomem_client.py
|
|
import os, httpx, logging
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
logger = logging.getLogger(__name__)


class NeoMemClient:
    """Simple REST client for the NeoMem API (search/add/health).

    Connection settings come from the constructor arguments, falling back to
    the environment:

    * ``NEOMEM_API`` -- service base URL (default ``http://neomem-api:7077``).
    * ``NEOMEM_API_KEY`` -- optional key, sent as ``Authorization: Bearer <key>``.

    Every request opens its own short-lived ``httpx.AsyncClient``, so an
    instance keeps no connections open between calls.
    """

    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        """Configure the client.

        Args:
            base_url: Service root URL; defaults to ``$NEOMEM_API`` or
                ``http://neomem-api:7077``.
            api_key: Bearer token; defaults to ``$NEOMEM_API_KEY``. No
                ``Authorization`` header is sent when it is empty or unset.
        """
        if base_url is None:
            base_url = os.getenv("NEOMEM_API", "http://neomem-api:7077")
        # Drop trailing slashes so f"{self.base_url}/search" never yields "//search".
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key if api_key is not None else os.getenv("NEOMEM_API_KEY")
        self.headers = {"Content-Type": "application/json"}
        if self.api_key:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    async def health(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /health``.

        Raises:
            httpx.HTTPStatusError: when ``raise_for_status`` rejects the
                response status.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            # Send the same headers as the other endpoints so the probe also
            # works when the API is behind bearer-token auth.
            r = await client.get(f"{self.base_url}/health", headers=self.headers)
            r.raise_for_status()
            return r.json()

    @staticmethod
    def _filter_hits(body: Any, threshold: float) -> Optional[List[Dict[str, Any]]]:
        """Return the hits in *body* scoring >= *threshold*, or None if *body* is malformed.

        *body* may be a bare list of hits or a dict wrapping them under
        ``"results"``. A missing, null or empty ``score`` counts as 0, so such
        hits survive only a threshold <= 0. Entries that are not dicts, or
        whose score cannot be converted with ``float``, are skipped.
        """
        if isinstance(body, dict) and "results" in body:
            body = body["results"]
        if not isinstance(body, list):
            return None
        kept: List[Dict[str, Any]] = []
        for hit in body:
            if not isinstance(hit, dict):
                continue  # not a memory record; cannot be scored
            try:
                score = float(hit.get("score") or 0)
            except (TypeError, ValueError):
                continue  # unparseable score
            if score >= threshold:
                kept.append(hit)
        return kept

    async def search(
        self,
        query: str,
        user_id: str,
        limit: int = 25,
        threshold: float = 0.82,
    ) -> List[Dict[str, Any]]:
        """Search *user_id*'s memories via ``POST /search`` and keep relevant hits.

        Args:
            query: Free-text query.
            user_id: Whose memories to search.
            limit: Maximum number of hits requested from the service.
            threshold: Minimum ``score`` a hit needs to be returned; the
                filtering happens client-side (it is not sent to the service).

        Returns:
            The hits scoring >= *threshold*. ``[]`` is returned (and a warning
            logged) when the status is not 200, the body is not JSON, or the
            body is neither a list of hits nor ``{"results": [...]}``.
            Transport errors raised by httpx are not caught.
        """
        payload = {"query": query, "user_id": user_id, "limit": limit}
        async with httpx.AsyncClient(timeout=30) as client:
            r = await client.post(f"{self.base_url}/search", headers=self.headers, json=payload)
            if r.status_code != 200:
                logger.warning("NeoMem search failed (%s): %s", r.status_code, r.text)
                return []
            try:
                body = r.json()
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                logger.warning("NeoMem search returned a non-JSON body: %s", r.text)
                return []

        filtered = self._filter_hits(body, threshold)
        if filtered is None:
            logger.warning("NeoMem search returned an unexpected payload: %r", body)
            return []
        logger.info("NeoMem search returned %d results above %s", len(filtered), threshold)
        return filtered

    async def add(
        self,
        messages: List[Dict[str, Any]],
        user_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Store *messages* for *user_id* via ``POST /memories``.

        Args:
            messages: Message dicts forwarded verbatim in the request body.
            user_id: Owner of the new memories.
            metadata: Extra metadata; ``None`` is sent as ``{}``.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: when ``raise_for_status`` rejects the
                response status.
        """
        payload = {"messages": messages, "user_id": user_id, "metadata": metadata or {}}
        async with httpx.AsyncClient(timeout=30) as client:
            r = await client.post(f"{self.base_url}/memories", headers=self.headers, json=payload)
            r.raise_for_status()
            return r.json()
|