468 lines
15 KiB
Python
468 lines
15 KiB
Python
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Dict, List, Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
try:
|
|
from langchain_aws import NeptuneAnalyticsGraph
|
|
except ImportError:
|
|
raise ImportError("langchain_aws is not installed. Please install it using pip install langchain_aws")
|
|
|
|
from mem0.vector_stores.base import VectorStoreBase
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class OutputData(BaseModel):
|
|
id: Optional[str] # memory id
|
|
score: Optional[float] # distance
|
|
payload: Optional[Dict] # metadata
|
|
|
|
|
|
class NeptuneAnalyticsVector(VectorStoreBase):
|
|
"""
|
|
Neptune Analytics vector store implementation for Mem0.
|
|
|
|
Provides vector storage and similarity search capabilities using Amazon Neptune Analytics,
|
|
a serverless graph analytics service that supports vector operations.
|
|
"""
|
|
|
|
_COLLECTION_PREFIX = "MEM0_VECTOR_"
|
|
_FIELD_N = 'n'
|
|
_FIELD_ID = '~id'
|
|
_FIELD_PROP = '~properties'
|
|
_FIELD_SCORE = 'score'
|
|
_FIELD_LABEL = 'label'
|
|
_TIMEZONE = "UTC"
|
|
|
|
def __init__(
|
|
self,
|
|
endpoint: str,
|
|
collection_name: str,
|
|
):
|
|
"""
|
|
Initialize the Neptune Analytics vector store.
|
|
|
|
Args:
|
|
endpoint (str): Neptune Analytics endpoint in format 'neptune-graph://<graphid>'.
|
|
collection_name (str): Name of the collection to store vectors.
|
|
|
|
Raises:
|
|
ValueError: If endpoint format is invalid.
|
|
ImportError: If langchain_aws is not installed.
|
|
"""
|
|
|
|
if not endpoint.startswith("neptune-graph://"):
|
|
raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
|
|
|
|
graph_id = endpoint.replace("neptune-graph://", "")
|
|
self.graph = NeptuneAnalyticsGraph(graph_id)
|
|
self.collection_name = self._COLLECTION_PREFIX + collection_name
|
|
|
|
|
|
def create_col(self, name, vector_size, distance):
|
|
"""
|
|
Create a collection (no-op for Neptune Analytics).
|
|
|
|
Neptune Analytics supports dynamic indices that are created implicitly
|
|
when vectors are inserted, so this method performs no operation.
|
|
|
|
Args:
|
|
name: Collection name (unused).
|
|
vector_size: Vector dimension (unused).
|
|
distance: Distance metric (unused).
|
|
"""
|
|
pass
|
|
|
|
|
|
def insert(self, vectors: List[list],
|
|
payloads: Optional[List[Dict]] = None,
|
|
ids: Optional[List[str]] = None):
|
|
"""
|
|
Insert vectors into the collection.
|
|
|
|
Creates or updates nodes in Neptune Analytics with vector embeddings and metadata.
|
|
Uses MERGE operation to handle both creation and updates.
|
|
|
|
Args:
|
|
vectors (List[list]): List of embedding vectors to insert.
|
|
payloads (Optional[List[Dict]]): Optional metadata for each vector.
|
|
ids (Optional[List[str]]): Optional IDs for vectors. Generated if not provided.
|
|
"""
|
|
|
|
para_list = []
|
|
for index, data_vector in enumerate(vectors):
|
|
if payloads:
|
|
payload = payloads[index]
|
|
payload[self._FIELD_LABEL] = self.collection_name
|
|
payload["updated_at"] = str(int(time.time()))
|
|
else:
|
|
payload = {}
|
|
para_list.append(dict(
|
|
node_id=ids[index] if ids else str(uuid.uuid4()),
|
|
properties=payload,
|
|
embedding=data_vector,
|
|
))
|
|
|
|
para_map_to_insert = {"rows": para_list}
|
|
|
|
query_string = (f"""
|
|
UNWIND $rows AS row
|
|
MERGE (n :{self.collection_name} {{`~id`: row.node_id}})
|
|
ON CREATE SET n = row.properties
|
|
ON MATCH SET n += row.properties
|
|
"""
|
|
)
|
|
self.execute_query(query_string, para_map_to_insert)
|
|
|
|
|
|
query_string_vector = (f"""
|
|
UNWIND $rows AS row
|
|
MATCH (n
|
|
:{self.collection_name}
|
|
{{`~id`: row.node_id}})
|
|
WITH n, row.embedding AS embedding
|
|
CALL neptune.algo.vectors.upsert(n, embedding)
|
|
YIELD success
|
|
RETURN success
|
|
"""
|
|
)
|
|
result = self.execute_query(query_string_vector, para_map_to_insert)
|
|
self._process_success_message(result, "Vector store - Insert")
|
|
|
|
|
|
def search(
|
|
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
|
) -> List[OutputData]:
|
|
"""
|
|
Search for similar vectors using embedding similarity.
|
|
|
|
Performs vector similarity search using Neptune Analytics' topKByEmbeddingWithFiltering
|
|
algorithm to find the most similar vectors.
|
|
|
|
Args:
|
|
query (str): Search query text (unused in vector search).
|
|
vectors (List[float]): Query embedding vector.
|
|
limit (int, optional): Maximum number of results to return. Defaults to 5.
|
|
filters (Optional[Dict]): Optional filters to apply to search results.
|
|
|
|
Returns:
|
|
List[OutputData]: List of similar vectors with scores and metadata.
|
|
"""
|
|
|
|
if not filters:
|
|
filters = {}
|
|
filters[self._FIELD_LABEL] = self.collection_name
|
|
|
|
filter_clause = self._get_node_filter_clause(filters)
|
|
|
|
query_string = f"""
|
|
CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{
|
|
topK: {limit},
|
|
embedding: {vectors}
|
|
{filter_clause}
|
|
}}
|
|
)
|
|
YIELD node, score
|
|
RETURN node as n, score
|
|
"""
|
|
query_response = self.execute_query(query_string)
|
|
if len(query_response) > 0:
|
|
return self._parse_query_responses(query_response, with_score=True)
|
|
else :
|
|
return []
|
|
|
|
|
|
def delete(self, vector_id: str):
|
|
"""
|
|
Delete a vector by its ID.
|
|
|
|
Removes the node and all its relationships from the Neptune Analytics graph.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to delete.
|
|
"""
|
|
params = dict(node_id=vector_id)
|
|
query_string = f"""
|
|
MATCH (n :{self.collection_name})
|
|
WHERE id(n) = $node_id
|
|
DETACH DELETE n
|
|
"""
|
|
self.execute_query(query_string, params)
|
|
|
|
def update(
|
|
self,
|
|
vector_id: str,
|
|
vector: Optional[List[float]] = None,
|
|
payload: Optional[Dict] = None,
|
|
):
|
|
"""
|
|
Update a vector's embedding and/or metadata.
|
|
|
|
Updates the node properties and/or vector embedding for an existing vector.
|
|
Can update either the payload, the vector, or both.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to update.
|
|
vector (Optional[List[float]]): New embedding vector.
|
|
payload (Optional[Dict]): New metadata to replace existing payload.
|
|
"""
|
|
|
|
if payload:
|
|
# Replace payload
|
|
payload[self._FIELD_LABEL] = self.collection_name
|
|
payload["updated_at"] = str(int(time.time()))
|
|
para_payload = {
|
|
"properties": payload,
|
|
"vector_id": vector_id
|
|
}
|
|
query_string_embedding = f"""
|
|
MATCH (n :{self.collection_name})
|
|
WHERE id(n) = $vector_id
|
|
SET n = $properties
|
|
"""
|
|
self.execute_query(query_string_embedding, para_payload)
|
|
|
|
if vector:
|
|
para_embedding = {
|
|
"embedding": vector,
|
|
"vector_id": vector_id
|
|
}
|
|
query_string_embedding = f"""
|
|
MATCH (n :{self.collection_name})
|
|
WHERE id(n) = $vector_id
|
|
WITH $embedding as embedding, n as n
|
|
CALL neptune.algo.vectors.upsert(n, embedding)
|
|
YIELD success
|
|
RETURN success
|
|
"""
|
|
self.execute_query(query_string_embedding, para_embedding)
|
|
|
|
|
|
|
|
def get(self, vector_id: str):
|
|
"""
|
|
Retrieve a vector by its ID.
|
|
|
|
Fetches the node data including metadata for the specified vector ID.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to retrieve.
|
|
|
|
Returns:
|
|
OutputData: Vector data with metadata, or None if not found.
|
|
"""
|
|
params = dict(node_id=vector_id)
|
|
query_string = f"""
|
|
MATCH (n :{self.collection_name})
|
|
WHERE id(n) = $node_id
|
|
RETURN n
|
|
"""
|
|
|
|
# Composite the query
|
|
result = self.execute_query(query_string, params)
|
|
|
|
if len(result) != 0:
|
|
return self._parse_query_responses(result)[0]
|
|
|
|
|
|
def list_cols(self):
|
|
"""
|
|
List all collections with the Mem0 prefix.
|
|
|
|
Queries the Neptune Analytics schema to find all node labels that start
|
|
with the Mem0 collection prefix.
|
|
|
|
Returns:
|
|
List[str]: List of collection names.
|
|
"""
|
|
query_string = f"""
|
|
CALL neptune.graph.pg_schema()
|
|
YIELD schema
|
|
RETURN [ label IN schema.nodeLabels WHERE label STARTS WITH '{self.collection_name}'] AS result
|
|
"""
|
|
result = self.execute_query(query_string)
|
|
if len(result) == 1 and "result" in result[0]:
|
|
return result[0]["result"]
|
|
else:
|
|
return []
|
|
|
|
|
|
def delete_col(self):
|
|
"""
|
|
Delete the entire collection.
|
|
|
|
Removes all nodes with the collection label and their relationships
|
|
from the Neptune Analytics graph.
|
|
"""
|
|
self.execute_query(f"MATCH (n :{self.collection_name}) DETACH DELETE n")
|
|
|
|
|
|
def col_info(self):
|
|
"""
|
|
Get collection information (no-op for Neptune Analytics).
|
|
|
|
Collections are created dynamically in Neptune Analytics, so no
|
|
collection-specific metadata is available.
|
|
"""
|
|
pass
|
|
|
|
|
|
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
|
"""
|
|
List all vectors in the collection with optional filtering.
|
|
|
|
Retrieves vectors from the collection, optionally filtered by metadata properties.
|
|
|
|
Args:
|
|
filters (Optional[Dict]): Optional filters to apply based on metadata.
|
|
limit (int, optional): Maximum number of vectors to return. Defaults to 100.
|
|
|
|
Returns:
|
|
List[OutputData]: List of vectors with their metadata.
|
|
"""
|
|
where_clause = self._get_where_clause(filters) if filters else ""
|
|
|
|
para = {
|
|
"limit": limit,
|
|
}
|
|
query_string = f"""
|
|
MATCH (n :{self.collection_name})
|
|
{where_clause}
|
|
RETURN n
|
|
LIMIT $limit
|
|
"""
|
|
query_response = self.execute_query(query_string, para)
|
|
|
|
if len(query_response) > 0:
|
|
# Handle if there is no match.
|
|
return [self._parse_query_responses(query_response)]
|
|
return [[]]
|
|
|
|
|
|
def reset(self):
|
|
"""
|
|
Reset the collection by deleting all vectors.
|
|
|
|
Removes all vectors from the collection, effectively resetting it to empty state.
|
|
"""
|
|
self.delete_col()
|
|
|
|
|
|
def _parse_query_responses(self, response: dict, with_score: bool = False):
|
|
"""
|
|
Parse Neptune Analytics query responses into OutputData objects.
|
|
|
|
Args:
|
|
response (dict): Raw query response from Neptune Analytics.
|
|
with_score (bool, optional): Whether to include similarity scores. Defaults to False.
|
|
|
|
Returns:
|
|
List[OutputData]: Parsed response data.
|
|
"""
|
|
result = []
|
|
# Handle if there is no match.
|
|
for item in response:
|
|
id = item[self._FIELD_N][self._FIELD_ID]
|
|
properties = item[self._FIELD_N][self._FIELD_PROP]
|
|
properties.pop("label", None)
|
|
if with_score:
|
|
score = item[self._FIELD_SCORE]
|
|
else:
|
|
score = None
|
|
result.append(OutputData(
|
|
id=id,
|
|
score=score,
|
|
payload=properties,
|
|
))
|
|
return result
|
|
|
|
|
|
def execute_query(self, query_string: str, params=None):
|
|
"""
|
|
Execute an openCypher query on Neptune Analytics.
|
|
|
|
This is a wrapper method around the Neptune Analytics graph query execution
|
|
that provides debug logging for query monitoring and troubleshooting.
|
|
|
|
Args:
|
|
query_string (str): The openCypher query string to execute.
|
|
params (dict): Parameters to bind to the query.
|
|
|
|
Returns:
|
|
Query result from Neptune Analytics graph execution.
|
|
"""
|
|
if params is None:
|
|
params = {}
|
|
logger.debug(f"Executing openCypher query:[{query_string}], with parameters:[{params}].")
|
|
return self.graph.query(query_string, params)
|
|
|
|
|
|
@staticmethod
|
|
def _get_where_clause(filters: dict):
|
|
"""
|
|
Build WHERE clause for Cypher queries from filters.
|
|
|
|
Args:
|
|
filters (dict): Filter conditions as key-value pairs.
|
|
|
|
Returns:
|
|
str: Formatted WHERE clause for Cypher query.
|
|
"""
|
|
where_clause = ""
|
|
for i, (k, v) in enumerate(filters.items()):
|
|
if i == 0:
|
|
where_clause += f"WHERE n.{k} = '{v}' "
|
|
else:
|
|
where_clause += f"AND n.{k} = '{v}' "
|
|
return where_clause
|
|
|
|
@staticmethod
|
|
def _get_node_filter_clause(filters: dict):
|
|
"""
|
|
Build node filter clause for vector search operations.
|
|
|
|
Creates filter conditions for Neptune Analytics vector search operations
|
|
using the nodeFilter parameter format.
|
|
|
|
Args:
|
|
filters (dict): Filter conditions as key-value pairs.
|
|
|
|
Returns:
|
|
str: Formatted node filter clause for vector search.
|
|
"""
|
|
conditions = []
|
|
for k, v in filters.items():
|
|
conditions.append(f"{{equals:{{property: '{k}', value: '{v}'}}}}")
|
|
|
|
if len(conditions) == 1:
|
|
filter_clause = f", nodeFilter: {conditions[0]}"
|
|
else:
|
|
filter_clause = f"""
|
|
, nodeFilter: {{andAll: [ {", ".join(conditions)} ]}}
|
|
"""
|
|
return filter_clause
|
|
|
|
|
|
@staticmethod
|
|
def _process_success_message(response, context):
|
|
"""
|
|
Process and validate success messages from Neptune Analytics operations.
|
|
|
|
Checks the response from vector operations (insert/update) to ensure they
|
|
completed successfully. Logs errors if operations fail.
|
|
|
|
Args:
|
|
response: Response from Neptune Analytics vector operation.
|
|
context (str): Context description for logging (e.g., "Vector store - Insert").
|
|
"""
|
|
for success_message in response:
|
|
if "success" not in success_message:
|
|
logger.error(f"Query execution status is absent on action: [{context}]")
|
|
break
|
|
|
|
if success_message["success"] is not True:
|
|
logger.error(f"Abnormal response status on action: [{context}] with message: [{success_message['success']}] ")
|
|
break
|