Files
project-lyra/neomem/neomem/vector_stores/neptune_analytics.py
2025-11-16 03:17:32 -05:00

468 lines
15 KiB
Python

import logging
import time
import uuid
from typing import Dict, List, Optional
from pydantic import BaseModel
try:
from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
raise ImportError("langchain_aws is not installed. Please install it using pip install langchain_aws")
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """A single vector-store result: memory id, similarity score, and metadata.

    All fields default to ``None`` — pydantic v2 treats ``Optional`` fields
    without an explicit default as required, which would break callers that
    construct partial results (e.g. ``get`` paths with no score).
    """

    id: Optional[str] = None  # memory id (the node's `~id`)
    score: Optional[float] = None  # similarity/distance score; set by search only
    payload: Optional[Dict] = None  # node metadata (properties minus the label)
class NeptuneAnalyticsVector(VectorStoreBase):
    """
    Neptune Analytics vector store implementation for Mem0.

    Provides vector storage and similarity search capabilities using Amazon
    Neptune Analytics, a serverless graph analytics service that supports
    vector operations. Each collection maps to a node label of the form
    ``MEM0_VECTOR_<name>``; embeddings are attached to the labeled nodes via
    the ``neptune.algo.vectors`` openCypher procedures.
    """

    _COLLECTION_PREFIX = "MEM0_VECTOR_"
    # Keys used when unpacking openCypher query responses.
    _FIELD_N = "n"
    _FIELD_ID = "~id"
    _FIELD_PROP = "~properties"
    _FIELD_SCORE = "score"
    _FIELD_LABEL = "label"
    _TIMEZONE = "UTC"

    def __init__(
        self,
        endpoint: str,
        collection_name: str,
    ):
        """
        Initialize the Neptune Analytics vector store.

        Args:
            endpoint (str): Neptune Analytics endpoint in format
                'neptune-graph://<graphid>'.
            collection_name (str): Name of the collection to store vectors.

        Raises:
            ValueError: If endpoint format is invalid.
        """
        if not endpoint.startswith("neptune-graph://"):
            raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
        graph_id = endpoint.replace("neptune-graph://", "")
        self.graph = NeptuneAnalyticsGraph(graph_id)
        self.collection_name = self._COLLECTION_PREFIX + collection_name

    def create_col(self, name, vector_size, distance):
        """
        Create a collection (no-op for Neptune Analytics).

        Neptune Analytics supports dynamic indices that are created implicitly
        when vectors are inserted, so this method performs no operation.

        Args:
            name: Collection name (unused).
            vector_size: Vector dimension (unused).
            distance: Distance metric (unused).
        """
        pass

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors into the collection.

        Creates or updates nodes in Neptune Analytics with vector embeddings
        and metadata. Uses MERGE to handle both creation and updates, then a
        second pass attaches the embeddings.

        Args:
            vectors (List[list]): List of embedding vectors to insert.
            payloads (Optional[List[Dict]]): Optional metadata for each vector.
            ids (Optional[List[str]]): Optional IDs for vectors. Generated if not provided.
        """
        updated_at = str(int(time.time()))
        para_list = []
        for index, data_vector in enumerate(vectors):
            # Copy the caller's payload instead of mutating it in place, and
            # always stamp the label property: search() filters on it, so a
            # node stored without it would never be returned by search().
            payload = dict(payloads[index]) if payloads else {}
            payload[self._FIELD_LABEL] = self.collection_name
            payload["updated_at"] = updated_at
            para_list.append(dict(
                node_id=ids[index] if ids else str(uuid.uuid4()),
                properties=payload,
                embedding=data_vector,
            ))
        para_map_to_insert = {"rows": para_list}
        # Pass 1: create/update the nodes and their metadata properties.
        query_string = f"""
            UNWIND $rows AS row
            MERGE (n :{self.collection_name} {{`~id`: row.node_id}})
            ON CREATE SET n = row.properties
            ON MATCH SET n += row.properties
            """
        self.execute_query(query_string, para_map_to_insert)
        # Pass 2: attach the embeddings via the vector upsert procedure.
        query_string_vector = f"""
            UNWIND $rows AS row
            MATCH (n :{self.collection_name} {{`~id`: row.node_id}})
            WITH n, row.embedding AS embedding
            CALL neptune.algo.vectors.upsert(n, embedding)
            YIELD success
            RETURN success
            """
        result = self.execute_query(query_string_vector, para_map_to_insert)
        self._process_success_message(result, "Vector store - Insert")

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors using embedding similarity.

        Performs vector similarity search using Neptune Analytics'
        topKByEmbeddingWithFiltering algorithm to find the most similar vectors.

        Args:
            query (str): Search query text (unused in vector search).
            vectors (List[float]): Query embedding vector.
            limit (int, optional): Maximum number of results to return. Defaults to 5.
            filters (Optional[Dict]): Optional filters to apply to search results.

        Returns:
            List[OutputData]: List of similar vectors with scores and metadata.
        """
        # Copy so the caller's filter dict is not mutated; always restrict
        # results to this collection via the label property.
        filters = dict(filters) if filters else {}
        filters[self._FIELD_LABEL] = self.collection_name
        filter_clause = self._get_node_filter_clause(filters)
        query_string = f"""
            CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{
            topK: {limit},
            embedding: {vectors}
            {filter_clause}
            }}
            )
            YIELD node, score
            RETURN node as n, score
            """
        query_response = self.execute_query(query_string)
        if len(query_response) > 0:
            return self._parse_query_responses(query_response, with_score=True)
        return []

    def delete(self, vector_id: str):
        """
        Delete a vector by its ID.

        Removes the node and all its relationships from the Neptune Analytics graph.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        params = dict(node_id=vector_id)
        query_string = f"""
            MATCH (n :{self.collection_name})
            WHERE id(n) = $node_id
            DETACH DELETE n
            """
        self.execute_query(query_string, params)

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector's embedding and/or metadata.

        Updates the node properties and/or vector embedding for an existing
        vector. Can update either the payload, the vector, or both.

        Args:
            vector_id (str): ID of the vector to update.
            vector (Optional[List[float]]): New embedding vector.
            payload (Optional[Dict]): New metadata to replace existing payload.
        """
        if payload:
            # Replace the whole payload (copied so the caller's dict stays untouched).
            properties = dict(payload)
            properties[self._FIELD_LABEL] = self.collection_name
            properties["updated_at"] = str(int(time.time()))
            para_payload = {
                "properties": properties,
                "vector_id": vector_id,
            }
            query_string_payload = f"""
                MATCH (n :{self.collection_name})
                WHERE id(n) = $vector_id
                SET n = $properties
                """
            self.execute_query(query_string_payload, para_payload)
        if vector:
            para_embedding = {
                "embedding": vector,
                "vector_id": vector_id,
            }
            query_string_embedding = f"""
                MATCH (n :{self.collection_name})
                WHERE id(n) = $vector_id
                WITH $embedding as embedding, n as n
                CALL neptune.algo.vectors.upsert(n, embedding)
                YIELD success
                RETURN success
                """
            self.execute_query(query_string_embedding, para_embedding)

    def get(self, vector_id: str):
        """
        Retrieve a vector by its ID.

        Fetches the node data including metadata for the specified vector ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Vector data with metadata, or None (implicitly) if not found.
        """
        params = dict(node_id=vector_id)
        query_string = f"""
            MATCH (n :{self.collection_name})
            WHERE id(n) = $node_id
            RETURN n
            """
        result = self.execute_query(query_string, params)
        if len(result) != 0:
            return self._parse_query_responses(result)[0]

    def list_cols(self):
        """
        List collections known to the graph schema.

        Queries the Neptune Analytics schema for node labels matching this
        store's collection name.

        NOTE(review): the filter uses ``self.collection_name`` (the full
        prefixed name), so at most this one collection is returned; listing
        *all* Mem0 collections would require ``_COLLECTION_PREFIX`` instead —
        confirm intended behavior before changing.

        Returns:
            List[str]: List of collection names.
        """
        query_string = f"""
            CALL neptune.graph.pg_schema()
            YIELD schema
            RETURN [ label IN schema.nodeLabels WHERE label STARTS WITH '{self.collection_name}'] AS result
            """
        result = self.execute_query(query_string)
        if len(result) == 1 and "result" in result[0]:
            return result[0]["result"]
        return []

    def delete_col(self):
        """
        Delete the entire collection.

        Removes all nodes with the collection label and their relationships
        from the Neptune Analytics graph.
        """
        self.execute_query(f"MATCH (n :{self.collection_name}) DETACH DELETE n")

    def col_info(self):
        """
        Get collection information (no-op for Neptune Analytics).

        Collections are created dynamically in Neptune Analytics, so no
        collection-specific metadata is available.
        """
        pass

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in the collection with optional filtering.

        Retrieves vectors from the collection, optionally filtered by metadata
        properties.

        Args:
            filters (Optional[Dict]): Optional filters to apply based on metadata.
            limit (int, optional): Maximum number of vectors to return. Defaults to 100.

        Returns:
            List of vectors with their metadata. NOTE: the result is wrapped in
            an outer list (``[results]``), matching the mem0 vector-store
            convention despite the ``List[OutputData]`` annotation.
        """
        where_clause = self._get_where_clause(filters) if filters else ""
        para = {
            "limit": limit,
        }
        query_string = f"""
            MATCH (n :{self.collection_name})
            {where_clause}
            RETURN n
            LIMIT $limit
            """
        query_response = self.execute_query(query_string, para)
        if len(query_response) > 0:
            return [self._parse_query_responses(query_response)]
        # No match: keep the nested-list shape for consistency.
        return [[]]

    def reset(self):
        """
        Reset the collection by deleting all vectors.

        Removes all vectors from the collection, effectively resetting it to
        an empty state (re-creation is implicit on next insert).
        """
        self.delete_col()

    def _parse_query_responses(self, response: list, with_score: bool = False) -> List[OutputData]:
        """
        Parse Neptune Analytics query responses into OutputData objects.

        Args:
            response (list): Raw query response rows from Neptune Analytics,
                each containing an ``n`` node entry (and optionally a score).
            with_score (bool, optional): Whether to include similarity scores.
                Defaults to False.

        Returns:
            List[OutputData]: Parsed response data.
        """
        results = []
        for item in response:
            node = item[self._FIELD_N]
            properties = node[self._FIELD_PROP]
            # The label property is an internal routing detail; hide it from callers.
            properties.pop(self._FIELD_LABEL, None)
            score = item[self._FIELD_SCORE] if with_score else None
            results.append(OutputData(
                id=node[self._FIELD_ID],
                score=score,
                payload=properties,
            ))
        return results

    def execute_query(self, query_string: str, params=None):
        """
        Execute an openCypher query on Neptune Analytics.

        This is a wrapper method around the Neptune Analytics graph query
        execution that provides debug logging for query monitoring and
        troubleshooting.

        Args:
            query_string (str): The openCypher query string to execute.
            params (dict): Parameters to bind to the query.

        Returns:
            Query result from Neptune Analytics graph execution.
        """
        if params is None:
            params = {}
        logger.debug(f"Executing openCypher query:[{query_string}], with parameters:[{params}].")
        return self.graph.query(query_string, params)

    @staticmethod
    def _escape_literal(value) -> str:
        """Escape backslashes and single quotes so *value* can be embedded
        inside a single-quoted openCypher string literal."""
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    @classmethod
    def _get_where_clause(cls, filters: dict):
        """
        Build a WHERE clause for Cypher queries from filters.

        Values are escaped before interpolation; filter values cannot be bound
        as query parameters here because the clause is spliced into the query
        text.

        Args:
            filters (dict): Filter conditions as key-value pairs.

        Returns:
            str: Formatted WHERE clause for a Cypher query ("" if no filters).
        """
        conditions = [f"n.{k} = '{cls._escape_literal(v)}'" for k, v in filters.items()]
        if not conditions:
            return ""
        return "WHERE " + " AND ".join(conditions) + " "

    @classmethod
    def _get_node_filter_clause(cls, filters: dict):
        """
        Build a node filter clause for vector search operations.

        Creates filter conditions for Neptune Analytics vector search using
        the ``nodeFilter`` parameter format; values are escaped before being
        embedded in the query text.

        Args:
            filters (dict): Filter conditions as key-value pairs.

        Returns:
            str: Formatted node filter clause for vector search.
        """
        conditions = [
            f"{{equals:{{property: '{cls._escape_literal(k)}', value: '{cls._escape_literal(v)}'}}}}"
            for k, v in filters.items()
        ]
        if len(conditions) == 1:
            return f", nodeFilter: {conditions[0]}"
        return f"""
            , nodeFilter: {{andAll: [ {", ".join(conditions)} ]}}
            """

    @staticmethod
    def _process_success_message(response, context):
        """
        Process and validate success messages from Neptune Analytics operations.

        Checks the response from vector operations (insert/update) to ensure
        they completed successfully. Logs errors if operations fail.

        Args:
            response: Response rows from a Neptune Analytics vector operation.
            context (str): Context description for logging (e.g., "Vector store - Insert").
        """
        for success_message in response:
            if "success" not in success_message:
                logger.error(f"Query execution status is absent on action: [{context}]")
                break
            if success_message["success"] is not True:
                logger.error(f"Abnormal response status on action: [{context}] with message: [{success_message['success']}] ")
                break