# cortex/neomem_client.py
import logging
import os
from typing import Any, Dict, List, Optional

import httpx

logger = logging.getLogger(__name__)


class NeoMemClient:
    """Simple async REST client for the NeoMem API (search/add/health).

    Connection settings come from the constructor arguments or, when those
    are omitted, from the environment:

    * ``NEOMEM_API`` -- base URL (default ``http://neomem-api:7077``).
    * ``NEOMEM_API_KEY`` -- optional bearer token sent as ``Authorization``.

    Each call opens a short-lived ``httpx.AsyncClient``; no connection state
    is kept on the instance.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """Build the client configuration.

        Args:
            base_url: NeoMem API root. Falls back to ``NEOMEM_API`` and then
                to ``http://neomem-api:7077``. A trailing ``/`` is stripped so
                endpoint paths join without producing ``//``.
            api_key: Bearer token. Falls back to ``NEOMEM_API_KEY``; when the
                resulting value is empty/``None`` no ``Authorization`` header
                is sent.
        """
        url = (
            base_url
            if base_url is not None
            else os.getenv("NEOMEM_API", "http://neomem-api:7077")
        )
        self.base_url = url.rstrip("/")
        self.api_key = (
            api_key if api_key is not None else os.getenv("NEOMEM_API_KEY", None)
        )
        self.headers = {"Content-Type": "application/json"}
        if self.api_key:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    def _url(self, path: str) -> str:
        """Join an endpoint *path* such as ``"/search"`` onto the base URL."""
        return f"{self.base_url}{path}"

    @staticmethod
    def _extract_hits(body: Any) -> List[Dict[str, Any]]:
        """Normalize a search body to a list of dict hits.

        Accepts either a bare list or a ``{"results": [...]}`` envelope;
        anything else yields ``[]``, and non-dict entries are dropped.
        """
        if isinstance(body, dict):
            body = body.get("results")
        if not isinstance(body, list):
            return []
        return [hit for hit in body if isinstance(hit, dict)]

    @staticmethod
    def _score(hit: Dict[str, Any]) -> float:
        """Return ``hit["score"]`` as a float; missing/None/non-numeric -> 0.0."""
        try:
            return float(hit.get("score") or 0)
        except (TypeError, ValueError):
            return 0.0

    async def health(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET /health``.

        Sends the same headers (including any bearer token) as the other
        endpoints so the check also works against an auth-protected server.

        Raises:
            httpx.HTTPStatusError: if the response status is not successful.
            httpx.RequestError: if the server cannot be reached or times out.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            r = await client.get(self._url("/health"), headers=self.headers)
            r.raise_for_status()
            return r.json()

    async def search(
        self,
        query: str,
        user_id: str,
        limit: int = 25,
        threshold: float = 0.82,
    ) -> List[Dict[str, Any]]:
        """Search *user_id*'s memories and keep hits scoring >= *threshold*.

        This call is best-effort: an unreachable server, a non-200 status, or
        a body that is not valid JSON is logged as a warning and yields
        ``[]`` instead of raising.

        Args:
            query: Free-text search query.
            user_id: Owner whose memories are searched.
            limit: Maximum number of hits requested from the server.
            threshold: Minimum ``score`` a hit must have to be returned. Hits
                with a missing or non-numeric score are treated as 0.

        Returns:
            The hits that pass the threshold, in the order the server sent.
        """
        payload = {"query": query, "user_id": user_id, "limit": limit}
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                r = await client.post(
                    self._url("/search"), headers=self.headers, json=payload
                )
        except httpx.RequestError as exc:
            # Network/timeout failures are treated like an HTTP error status:
            # callers get "no memories" rather than an exception.
            logger.warning("NeoMem search request failed: %s", exc)
            return []

        if r.status_code != 200:
            logger.warning("NeoMem search failed (%s): %s", r.status_code, r.text)
            return []

        try:
            body = r.json()
        except ValueError:
            logger.warning("NeoMem search returned invalid JSON: %s", r.text)
            return []

        filtered = [
            hit for hit in self._extract_hits(body) if self._score(hit) >= threshold
        ]
        logger.info(
            "NeoMem search returned %d results above %s", len(filtered), threshold
        )
        return filtered

    async def add(
        self,
        messages: List[Dict[str, Any]],
        user_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Store *messages* as memories for *user_id* via ``POST /memories``.

        Args:
            messages: Message dicts forwarded verbatim in the request body.
            user_id: Owner the memories are stored under.
            metadata: Optional extra metadata; sent as ``{}`` when omitted.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: if the response status is not successful.
            httpx.RequestError: if the server cannot be reached or times out.
        """
        payload = {"messages": messages, "user_id": user_id, "metadata": metadata or {}}
        async with httpx.AsyncClient(timeout=30) as client:
            r = await client.post(
                self._url("/memories"), headers=self.headers, json=payload
            )
            r.raise_for_status()
            return r.json()