Initial clean commit - unified Lyra stack
0  neomem/neomem/vector_stores/__init__.py  Normal file
396  neomem/neomem/vector_stores/azure_ai_search.py  Normal file
@@ -0,0 +1,396 @@
import json
import logging
import re
from typing import List, Optional

from pydantic import BaseModel

from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase

try:
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import ResourceNotFoundError
    from azure.identity import DefaultAzureCredential
    from azure.search.documents import SearchClient
    from azure.search.documents.indexes import SearchIndexClient
    from azure.search.documents.indexes.models import (
        BinaryQuantizationCompression,
        HnswAlgorithmConfiguration,
        ScalarQuantizationCompression,
        SearchField,
        SearchFieldDataType,
        SearchIndex,
        SimpleField,
        VectorSearch,
        VectorSearchProfile,
    )
    from azure.search.documents.models import VectorizedQuery
except ImportError:
    raise ImportError(
        "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
    )

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class AzureAISearch(VectorStoreBase):
    def __init__(
        self,
        service_name,
        collection_name,
        api_key,
        embedding_model_dims,
        compression_type: Optional[str] = None,
        use_float16: bool = False,
        hybrid_search: bool = False,
        vector_filter_mode: Optional[str] = None,
    ):
        """
        Initialize the Azure AI Search vector store.

        Args:
            service_name (str): Azure AI Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure AI Search service.
            embedding_model_dims (int): Dimension of the embedding vector.
            compression_type (Optional[str]): Specifies the type of quantization to use.
                Allowed values are None (no quantization), "scalar", or "binary".
            use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
                (Note: This flag is preserved from the initial implementation per feedback.)
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
        """
        self.service_name = service_name
        self.api_key = api_key
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        # If compression_type is None, treat it as "none".
        self.compression_type = (compression_type or "none").lower()
        self.use_float16 = use_float16
        self.hybrid_search = hybrid_search
        self.vector_filter_mode = vector_filter_mode

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
            credential = DefaultAzureCredential()
            self.api_key = None
        else:
            credential = AzureKeyCredential(self.api_key)

        self.search_client = SearchClient(
            endpoint=f"https://{service_name}.search.windows.net",
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=f"https://{service_name}.search.windows.net",
            credential=credential,
        )

        self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
        self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col()

    def create_col(self):
        """Create a new index in Azure AI Search."""
        # Determine vector type based on use_float16 setting.
        if self.use_float16:
            vector_type = "Collection(Edm.Half)"
        else:
            vector_type = "Collection(Edm.Single)"

        # Configure compression settings based on the specified compression_type.
        compression_configurations = []
        compression_name = None
        if self.compression_type == "scalar":
            compression_name = "myCompression"
            # For SQ, rescoring defaults to True and oversampling defaults to 4.
            compression_configurations = [
                ScalarQuantizationCompression(
                    compression_name=compression_name
                    # rescoring defaults to True and oversampling defaults to 4
                )
            ]
        elif self.compression_type == "binary":
            compression_name = "myCompression"
            # For BQ, rescoring defaults to True and oversampling defaults to 10.
            compression_configurations = [
                BinaryQuantizationCompression(
                    compression_name=compression_name
                    # rescoring defaults to True and oversampling defaults to 10
                )
            ]
        # If no compression is desired, compression_configurations remains empty.
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                    compression_name=compression_name if self.compression_type != "none" else None,
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    def _generate_document(self, vector, payload, id):
        document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
        # Extract additional fields if they exist.
        for field in ["user_id", "run_id", "agent_id"]:
            if field in payload:
                document[field] = payload[field]
        return document

    # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
        documents = [
            self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
        ]
        response = self.search_client.upload_documents(documents)
        for doc in response:
            if hasattr(doc, "succeeded") and not doc.succeeded:
                raise Exception(f"Insert failed for document {getattr(doc, 'key', None)}: {doc}")
        return response

    def _sanitize_key(self, key: str) -> str:
        return re.sub(r"[^\w]", "", key)

    def _build_filter_expression(self, filters):
        filter_conditions = []
        for key, value in filters.items():
            safe_key = self._sanitize_key(key)
            if isinstance(value, str):
                safe_value = value.replace("'", "''")
                condition = f"{safe_key} eq '{safe_value}'"
            else:
                condition = f"{safe_key} eq {value}"
            filter_conditions.append(condition)
        filter_expression = " and ".join(filter_conditions)
        return filter_expression

    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            query (str): Query.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
        if self.hybrid_search:
            search_results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
                search_fields=["payload"],
            )
        else:
            search_results = self.search_client.search(
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
            )

        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        response = self.search_client.delete_documents(documents=[{"id": vector_id}])
        for doc in response:
            if hasattr(doc, "succeeded") and not doc.succeeded:
                raise Exception(f"Delete failed for document {vector_id}: {doc}")
        logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
        return response

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        document = {"id": vector_id}
        if vector:
            document["vector"] = vector
        if payload:
            json_payload = json.dumps(payload)
            document["payload"] = json_payload
            for field in ["user_id", "run_id", "agent_id"]:
                document[field] = payload.get(field)
        response = self.search_client.merge_or_upload_documents(documents=[document])
        for doc in response:
            if hasattr(doc, "succeeded") and not doc.succeeded:
                raise Exception(f"Update failed for document {vector_id}: {doc}")
        return response

    def get(self, vector_id) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        payload = json.loads(extract_json(result["payload"]))
        return OutputData(id=result["id"], score=None, payload=payload)

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        try:
            names = self.index_client.list_index_names()
        except AttributeError:
            names = [index.name for index in self.index_client.list_indexes()]
        return names

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """
        Get information about the index.

        Returns:
            dict: Index information.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """
        List all vectors in the index.

        Args:
            filters (dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return [results]

    def __del__(self):
        """Close the search client when the object is deleted."""
        self.search_client.close()
        self.index_client.close()

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.index_name}...")

        try:
            # Close the existing clients
            self.search_client.close()
            self.index_client.close()

            # Delete the collection
            self.delete_col()

            # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
            if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
                credential = DefaultAzureCredential()
                self.api_key = None
            else:
                credential = AzureKeyCredential(self.api_key)

            # Reinitialize the clients
            service_endpoint = f"https://{self.service_name}.search.windows.net"
            self.search_client = SearchClient(
                endpoint=service_endpoint,
                index_name=self.index_name,
                credential=credential,
            )
            self.index_client = SearchIndexClient(
                endpoint=service_endpoint,
                credential=credential,
            )

            # Add user agent
            self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
            self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

            # Create the collection
            self.create_col()
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise
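For reference, a minimal usage sketch of the AzureAISearch store above, assuming an existing Azure AI Search service and a 1536-dimensional embedder; the service name, key, vectors, and IDs below are placeholders, not values from this commit:

# Hypothetical example (not part of the commit): wire up the store and run a filtered search.
store = AzureAISearch(
    service_name="my-search-service",  # placeholder; falls back to DefaultAzureCredential without a real key
    collection_name="mem0",
    api_key="your-api-key",
    embedding_model_dims=1536,
    compression_type="scalar",  # or None / "binary"
    hybrid_search=False,
)
store.insert(vectors=[[0.1] * 1536], payloads=[{"user_id": "u1", "data": "hello"}], ids=["mem-1"])
hits = store.search(query="hello", vectors=[0.1] * 1536, limit=5, filters={"user_id": "u1"})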
463  neomem/neomem/vector_stores/azure_mysql.py  Normal file
@@ -0,0 +1,463 @@
import json
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

try:
    import pymysql
    from pymysql.cursors import DictCursor
    from dbutils.pooled_db import PooledDB
except ImportError:
    raise ImportError(
        "Azure MySQL vector store requires PyMySQL and DBUtils. "
        "Please install them using 'pip install pymysql dbutils'"
    )

try:
    from azure.identity import DefaultAzureCredential

    AZURE_IDENTITY_AVAILABLE = True
except ImportError:
    AZURE_IDENTITY_AVAILABLE = False

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class AzureMySQL(VectorStoreBase):
    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: Optional[str],
        database: str,
        collection_name: str,
        embedding_model_dims: int,
        use_azure_credential: bool = False,
        ssl_ca: Optional[str] = None,
        ssl_disabled: bool = False,
        minconn: int = 1,
        maxconn: int = 5,
        connection_pool: Optional[Any] = None,
    ):
        """
        Initialize the Azure MySQL vector store.

        Args:
            host (str): MySQL server host
            port (int): MySQL server port
            user (str): Database user
            password (str, optional): Database password (not required if using Azure credential)
            database (str): Database name
            collection_name (str): Collection/table name
            embedding_model_dims (int): Dimension of the embedding vector
            use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
            ssl_ca (str, optional): Path to SSL CA certificate
            ssl_disabled (bool): Disable SSL connection
            minconn (int): Minimum number of connections in the pool
            maxconn (int): Maximum number of connections in the pool
            connection_pool (Any, optional): Pre-configured connection pool
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.use_azure_credential = use_azure_credential
        self.ssl_ca = ssl_ca
        self.ssl_disabled = ssl_disabled
        self.connection_pool = connection_pool

        # Handle Azure authentication
        if use_azure_credential:
            if not AZURE_IDENTITY_AVAILABLE:
                raise ImportError(
                    "Azure Identity is required for Azure credential authentication. "
                    "Please install it using 'pip install azure-identity'"
                )
            self._setup_azure_auth()

        # Setup connection pool
        if self.connection_pool is None:
            self._setup_connection_pool(minconn, maxconn)

        # Create collection if it doesn't exist
        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")

    def _setup_azure_auth(self):
        """Setup Azure authentication using DefaultAzureCredential."""
        try:
            credential = DefaultAzureCredential()
            # Get access token for Azure Database for MySQL
            token = credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
            # Use token as password
            self.password = token.token
            logger.info("Successfully authenticated using Azure DefaultAzureCredential")
        except Exception as e:
            logger.error(f"Failed to authenticate with Azure: {e}")
            raise

    def _setup_connection_pool(self, minconn: int, maxconn: int):
        """Setup MySQL connection pool."""
        connect_kwargs = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": False,
        }

        # SSL configuration
        if not self.ssl_disabled:
            ssl_config = {"ssl_verify_cert": True}
            if self.ssl_ca:
                ssl_config["ssl_ca"] = self.ssl_ca
            connect_kwargs["ssl"] = ssl_config

        try:
            self.connection_pool = PooledDB(
                creator=pymysql,
                mincached=minconn,
                maxcached=maxconn,
                maxconnections=maxconn,
                blocking=True,
                **connect_kwargs,
            )
            logger.info("Successfully created MySQL connection pool")
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    @contextmanager
    def _get_cursor(self, commit: bool = False):
        """
        Context manager to get a cursor from the connection pool.
        Auto-commits or rolls back based on exception.
        """
        conn = self.connection_pool.connection()
        cur = conn.cursor()
        try:
            yield cur
            if commit:
                conn.commit()
        except Exception as exc:
            conn.rollback()
            logger.error(f"Database error: {exc}", exc_info=True)
            raise
        finally:
            cur.close()
            conn.close()

    def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
        """
        Create a new collection (table in MySQL).
        Enables vector extension and creates appropriate indexes.

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not provided)
            distance (str): Distance metric (cosine, euclidean, dot_product)
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims

        with self._get_cursor(commit=True) as cur:
            # Create table with vector column
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS `{table_name}` (
                    id VARCHAR(255) PRIMARY KEY,
                    vector JSON,
                    payload JSON,
                    INDEX idx_payload_keys ((CAST(payload AS CHAR(255)) ARRAY))
                )
            """)
        logger.info(f"Created collection '{table_name}' with vector dimension {dims}")

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors
            ids (List[str], optional): List of IDs corresponding to vectors
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if payloads is None:
            payloads = [{}] * len(vectors)
        if ids is None:
            import uuid

            ids = [str(uuid.uuid4()) for _ in range(len(vectors))]

        data = []
        for vector, payload, vec_id in zip(vectors, payloads, ids):
            data.append((vec_id, json.dumps(vector), json.dumps(payload)))

        with self._get_cursor(commit=True) as cur:
            cur.executemany(
                f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
                f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
                data,
            )

    def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
        """Generate SQL for cosine distance calculation."""
        # For MySQL, we need to calculate cosine similarity manually
        # This is a simplified version - in production, you'd use stored procedures or UDFs
        return """
            1 - (
                (SELECT SUM(a.val * b.val) /
                    (SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
                FROM (
                    SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(vector)
                ) a,
                (
                    SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(%s)
                ) b
                WHERE a.idx = b.idx
            )
        """

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors using cosine similarity.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Number of results to return
            filters (Dict, optional): Filters to apply to the search

        Returns:
            List[OutputData]: Search results
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
                filter_params.extend([f"$.{k}", json.dumps(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        # For simplicity, we'll compute cosine similarity in Python
        # In production, you'd want to use MySQL stored procedures or UDFs
        with self._get_cursor() as cur:
            query_sql = f"""
                SELECT id, vector, payload
                FROM `{self.collection_name}`
                {filter_clause}
            """
            cur.execute(query_sql, filter_params)
            results = cur.fetchall()

        # Calculate cosine similarity in Python
        import numpy as np

        query_vec = np.array(vectors)
        scored_results = []

        for row in results:
            vec = np.array(json.loads(row['vector']))
            # Cosine similarity
            similarity = np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec))
            distance = 1 - similarity
            scored_results.append((row['id'], distance, row['payload']))

        # Sort by distance and limit
        scored_results.sort(key=lambda x: x[1])
        scored_results = scored_results[:limit]

        return [
            OutputData(id=r[0], score=float(r[1]), payload=json.loads(r[2]) if isinstance(r[2], str) else r[2])
            for r in scored_results
        ]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete
        """
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DELETE FROM `{self.collection_name}` WHERE id = %s", (vector_id,))

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload
        """
        with self._get_cursor(commit=True) as cur:
            if vector is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET vector = %s WHERE id = %s",
                    (json.dumps(vector), vector_id),
                )
            if payload is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET payload = %s WHERE id = %s",
                    (json.dumps(payload), vector_id),
                )

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: Retrieved vector or None if not found
        """
        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
                (vector_id,),
            )
            result = cur.fetchone()
            if not result:
                return None
            return OutputData(
                id=result['id'],
                score=None,
                payload=json.loads(result['payload']) if isinstance(result['payload'], str) else result['payload'],
            )

    def list_cols(self) -> List[str]:
        """
        List all collections (tables).

        Returns:
            List[str]: List of collection names
        """
        with self._get_cursor() as cur:
            cur.execute("SHOW TABLES")
            return [row[f"Tables_in_{self.database}"] for row in cur.fetchall()]

    def delete_col(self):
        """Delete the collection (table)."""
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
        logger.info(f"Deleted collection '{self.collection_name}'")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Collection information
        """
        with self._get_cursor() as cur:
            cur.execute("""
                SELECT
                    TABLE_NAME as name,
                    TABLE_ROWS as count,
                    ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
            """, (self.database, self.collection_name))
            result = cur.fetchone()

            if result:
                return {
                    "name": result['name'],
                    "count": result['count'],
                    "size": f"{result['size_mb']} MB",
                }
            return {}

    def list(
        self,
        filters: Optional[Dict] = None,
        limit: int = 100,
    ) -> List[List[OutputData]]:
        """
        List all vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply
            limit (int): Number of vectors to return

        Returns:
            List[List[OutputData]]: List of vectors
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
                filter_params.extend([f"$.{k}", json.dumps(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        with self._get_cursor() as cur:
            cur.execute(
                f"""
                SELECT id, vector, payload
                FROM `{self.collection_name}`
                {filter_clause}
                LIMIT %s
                """,
                (*filter_params, limit),
            )
            results = cur.fetchall()

        return [[
            OutputData(
                id=r['id'],
                score=None,
                payload=json.loads(r['payload']) if isinstance(r['payload'], str) else r['payload'],
            ) for r in results
        ]]

    def reset(self):
        """Reset the collection by deleting and recreating it."""
        logger.warning(f"Resetting collection {self.collection_name}...")
        self.delete_col()
        self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)

    def __del__(self):
        """Close the connection pool when the object is deleted."""
        try:
            if hasattr(self, 'connection_pool') and self.connection_pool:
                self.connection_pool.close()
        except Exception:
            pass
368  neomem/neomem/vector_stores/baidu.py  Normal file
@@ -0,0 +1,368 @@
import logging
import time
from typing import Dict, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import pymochow
    from pymochow.auth.bce_credentials import BceCredentials
    from pymochow.configuration import Configuration
    from pymochow.exception import ServerError
    from pymochow.model.enum import (
        FieldType,
        IndexType,
        MetricType,
        ServerErrCode,
        TableState,
    )
    from pymochow.model.schema import (
        AutoBuildRowCountIncrement,
        Field,
        FilteringIndex,
        HNSWParams,
        Schema,
        VectorIndex,
    )
    from pymochow.model.table import (
        FloatVector,
        Partition,
        Row,
        VectorSearchConfig,
        VectorTopkSearchRequest,
    )
except ImportError:
    raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class BaiduDB(VectorStoreBase):
    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search.
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Create database if it doesn't exist."""
        try:
            # Check if database exists
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def create_col(self, name, vector_size, distance):
        """Create a new table.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str): Metric type for similarity search.
        """
        # Check if table already exists
        try:
            tables = self._database.list_table()
            table_exists = any(table.table_name == name for table in tables)
            if table_exists:
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            # Convert distance string to MetricType enum
            metric_type = None
            for k, v in MetricType.__members__.items():
                if k == distance:
                    metric_type = v
            if metric_type is None:
                raise ValueError(f"Unsupported metric_type: {distance}")

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # Create vector index
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]

            schema = Schema(fields=fields, indexes=indexes)

            # Create table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for table to be ready
            while True:
                time.sleep(2)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        # Prepare data for insertion
        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query string.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        # Add filters if provided
        search_filter = None
        if filters:
            search_filter = self._create_filter(filters)

        # Create AnnSearch for vector search
        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform search
        projections = ["id", "metadata"]
        res = self._table.vector_search(request=request, projections=projections)

        # Parse results
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output_data = OutputData(
                id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {})
            )
            output.append(output_data)

        return output

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        projections = ["id", "metadata"]
        result = self._table.query(primary_key={"id": vector_id}, projections=projections)
        row = result.row
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """
        List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        tables = self._database.list_table()
        return [table.table_name for table in tables]

    def delete_col(self):
        """Delete the table."""
        try:
            tables = self._database.list_table()

            # skip drop table if table not exists
            table_exists = any(table.table_name == self.table_name for table in tables)
            if not table_exists:
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait for table to be completely deleted
            while True:
                time.sleep(2)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """
        Get information about the table.

        Returns:
            Dict[str, Any]: Table information.
        """
        return self._table.stats()

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List all vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        projections = ["id", "metadata"]
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=projections, limit=limit)

        memories = []
        for row in result.rows:
            obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
            memories.append(obj)

        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: dict) -> str:
        """
        Create filter expression for queries.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression.
        """
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)
58  neomem/neomem/vector_stores/base.py  Normal file
@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod


class VectorStoreBase(ABC):
    @abstractmethod
    def create_col(self, name, vector_size, distance):
        """Create a new collection."""
        pass

    @abstractmethod
    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into a collection."""
        pass

    @abstractmethod
    def search(self, query, vectors, limit=5, filters=None):
        """Search for similar vectors."""
        pass

    @abstractmethod
    def delete(self, vector_id):
        """Delete a vector by ID."""
        pass

    @abstractmethod
    def update(self, vector_id, vector=None, payload=None):
        """Update a vector and its payload."""
        pass

    @abstractmethod
    def get(self, vector_id):
        """Retrieve a vector by ID."""
        pass

    @abstractmethod
    def list_cols(self):
        """List all collections."""
        pass

    @abstractmethod
    def delete_col(self):
        """Delete a collection."""
        pass

    @abstractmethod
    def col_info(self):
        """Get information about a collection."""
        pass

    @abstractmethod
    def list(self, filters=None, limit=None):
        """List all memories."""
        pass

    @abstractmethod
    def reset(self):
        """Reset by deleting the collection and recreating it."""
        pass
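Every backend in this commit implements the VectorStoreBase contract above. As a minimal sketch (not part of the commit, purely illustrative), an in-memory store that satisfies the interface could look like the following; it exists only to show the expected argument and return shapes:

import numpy as np


class InMemoryStore(VectorStoreBase):
    """Illustrative only: keeps vectors and payloads in a plain dict."""

    def __init__(self):
        self.rows = {}  # id -> (vector, payload)

    def create_col(self, name, vector_size, distance):
        pass  # nothing to provision for an in-memory dict

    def insert(self, vectors, payloads=None, ids=None):
        payloads = payloads or [{} for _ in vectors]
        for i, v, p in zip(ids, vectors, payloads):
            self.rows[i] = (np.asarray(v, dtype=float), p)

    def search(self, query, vectors, limit=5, filters=None):
        # Rank stored vectors by cosine similarity to the query vector.
        q = np.asarray(vectors, dtype=float)
        scored = []
        for i, (v, p) in self.rows.items():
            score = float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v)))
            scored.append((i, score, p))
        return sorted(scored, key=lambda t: -t[1])[:limit]

    def delete(self, vector_id):
        self.rows.pop(vector_id, None)

    def update(self, vector_id, vector=None, payload=None):
        v, p = self.rows[vector_id]
        self.rows[vector_id] = (np.asarray(vector, dtype=float) if vector is not None else v, payload or p)

    def get(self, vector_id):
        return self.rows.get(vector_id)

    def list_cols(self):
        return ["default"]

    def delete_col(self):
        self.rows.clear()

    def col_info(self):
        return {"name": "default", "count": len(self.rows)}

    def list(self, filters=None, limit=None):
        return [list(self.rows.items())[: limit or len(self.rows)]]

    def reset(self):
        self.rows.clear()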
267  neomem/neomem/vector_stores/chroma.py  Normal file
@@ -0,0 +1,267 @@
import logging
from typing import Dict, List, Optional

from pydantic import BaseModel

try:
    import chromadb
    from chromadb.config import Settings
except ImportError:
    raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class ChromaDB(VectorStoreBase):
    def __init__(
        self,
        collection_name: str,
        client: Optional[chromadb.Client] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        path: Optional[str] = None,
        api_key: Optional[str] = None,
        tenant: Optional[str] = None,
    ):
        """
        Initialize the Chromadb vector store.

        Args:
            collection_name (str): Name of the collection.
            client (chromadb.Client, optional): Existing chromadb client instance. Defaults to None.
            host (str, optional): Host address for chromadb server. Defaults to None.
            port (int, optional): Port for chromadb server. Defaults to None.
            path (str, optional): Path for local chromadb database. Defaults to None.
            api_key (str, optional): ChromaDB Cloud API key. Defaults to None.
            tenant (str, optional): ChromaDB Cloud tenant ID. Defaults to None.
        """
        if client:
            self.client = client
        elif api_key and tenant:
            # Initialize ChromaDB Cloud client
            logger.info("Initializing ChromaDB Cloud client")
            self.client = chromadb.CloudClient(
                api_key=api_key,
                tenant=tenant,
                database="mem0",  # Use fixed database name for cloud
            )
        else:
            # Initialize local or server client
            self.settings = Settings(anonymized_telemetry=False)

            if host and port:
                self.settings.chroma_server_host = host
                self.settings.chroma_server_http_port = port
                self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
            else:
                if path is None:
                    path = "db"

                self.settings.persist_directory = path
                self.settings.is_persistent = True

            self.client = chromadb.Client(self.settings)

        self.collection_name = collection_name
        self.collection = self.create_col(collection_name)

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data.

        Returns:
            List[OutputData]: Parsed output data.
        """
        keys = ["ids", "distances", "metadatas"]
        values = []

        for key in keys:
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

        return result

    def create_col(self, name: str, embedding_fn: Optional[callable] = None):
        """
        Create a new collection.

        Args:
            name (str): Name of the collection.
            embedding_fn (Optional[callable]): Embedding function to use. Defaults to None.

        Returns:
            chromadb.Collection: The created or retrieved collection.
        """
        collection = self.client.get_or_create_collection(
            name=name,
            embedding_function=embedding_fn,
        )
        return collection

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors into a collection.

        Args:
            vectors (List[list]): List of vectors to insert.
            payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
            ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)

    def search(
        self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query.
            vectors (List[list]): List of vectors to search.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        where_clause = self._generate_where_clause(filters) if filters else None
        results = self.collection.query(query_embeddings=vectors, where=where_clause, n_results=limit)
        final_results = self._parse_output(results)
        return final_results

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.collection.delete(ids=vector_id)

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (Optional[List[float]], optional): Updated vector. Defaults to None.
            payload (Optional[Dict], optional): Updated payload. Defaults to None.
        """
        self.collection.update(ids=vector_id, embeddings=vector, metadatas=payload)

    def get(self, vector_id: str) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        result = self.collection.get(ids=[vector_id])
        return self._parse_output(result)[0]

    def list_cols(self) -> List[chromadb.Collection]:
        """
        List all collections.

        Returns:
            List[chromadb.Collection]: List of collections.
        """
        return self.client.list_collections()

    def delete_col(self):
        """
        Delete a collection.
        """
        self.client.delete_collection(name=self.collection_name)

    def col_info(self) -> Dict:
        """
        Get information about a collection.

        Returns:
            Dict: Collection information.
        """
        return self.client.get_collection(name=self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in a collection.

        Args:
            filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        where_clause = self._generate_where_clause(filters) if filters else None
        results = self.collection.get(where=where_clause, limit=limit)
        return [self._parse_output(results)]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.collection = self.create_col(self.collection_name)

    @staticmethod
    def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
        """
        Generate a properly formatted where clause for ChromaDB.

        Args:
            where (dict[str, any]): The filter conditions.

        Returns:
            dict[str, any]: Properly formatted where clause for ChromaDB.
        """
        # If only one filter is supplied, return it as is
        # (no need to wrap in $and based on chroma docs)
        if where is None:
            return {}
        if len(where.keys()) <= 1:
            return where
        where_filters = []
        for k, v in where.items():
            if isinstance(v, str):
                where_filters.append({k: v})
        return {"$and": where_filters}
65  neomem/neomem/vector_stores/configs.py  Normal file
@@ -0,0 +1,65 @@
from typing import Dict, Optional

from pydantic import BaseModel, Field, model_validator


class VectorStoreConfig(BaseModel):
    provider: str = Field(
        description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
        default="qdrant",
    )
    config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)

    _provider_configs: Dict[str, str] = {
        "qdrant": "QdrantConfig",
        "chroma": "ChromaDbConfig",
        "pgvector": "PGVectorConfig",
        "pinecone": "PineconeConfig",
        "mongodb": "MongoDBConfig",
        "milvus": "MilvusDBConfig",
        "baidu": "BaiduDBConfig",
        "neptune": "NeptuneAnalyticsConfig",
        "upstash_vector": "UpstashVectorConfig",
        "azure_ai_search": "AzureAISearchConfig",
        "azure_mysql": "AzureMySQLConfig",
        "redis": "RedisDBConfig",
        "valkey": "ValkeyConfig",
        "databricks": "DatabricksConfig",
        "elasticsearch": "ElasticsearchConfig",
        "vertex_ai_vector_search": "GoogleMatchingEngineConfig",
        "opensearch": "OpenSearchConfig",
        "supabase": "SupabaseConfig",
        "weaviate": "WeaviateConfig",
        "faiss": "FAISSConfig",
        "langchain": "LangchainConfig",
        "s3_vectors": "S3VectorsConfig",
    }

    @model_validator(mode="after")
    def validate_and_create_config(self) -> "VectorStoreConfig":
        provider = self.provider
        config = self.config

        if provider not in self._provider_configs:
            raise ValueError(f"Unsupported vector store provider: {provider}")

        module = __import__(
            f"mem0.configs.vector_stores.{provider}",
            fromlist=[self._provider_configs[provider]],
        )
        config_class = getattr(module, self._provider_configs[provider])

        if config is None:
            config = {}

        if not isinstance(config, dict):
            if not isinstance(config, config_class):
                raise ValueError(f"Invalid config type for provider {provider}")
            return self

        # Also check whether "path" is among the allowed keys of the pydantic model, and whether extra config fields are allowed.
        if "path" not in config and "path" in config_class.__annotations__:
            config["path"] = f"/tmp/{provider}"

        self.config = config_class(**config)
        return self
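VectorStoreConfig resolves the provider string to a concrete config class at validation time via a dynamic import of mem0.configs.vector_stores.<provider>. A hedged usage sketch (the provider value comes from the table above; the collection name and path are placeholders, and the concrete field names depend on the provider's config class):

# Hypothetical example: the validator swaps the raw dict for the provider's config instance.
cfg = VectorStoreConfig(provider="chroma", config={"collection_name": "mem0", "path": "/tmp/chroma"})
print(type(cfg.config).__name__)  # expected "ChromaDbConfig", assuming that config module is importable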
759  neomem/neomem/vector_stores/databricks.py  Normal file
@@ -0,0 +1,759 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, date
|
||||
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, TableType, DataSourceFormat
|
||||
from databricks.sdk.service.catalog import TableConstraint, PrimaryKeyConstraint
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from databricks.sdk.service.vectorsearch import (
|
||||
VectorIndexType,
|
||||
DeltaSyncVectorIndexSpecRequest,
|
||||
DirectAccessVectorIndexSpec,
|
||||
EmbeddingSourceColumn,
|
||||
EmbeddingVectorColumn,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryResult(BaseModel):
|
||||
id: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
payload: Optional[dict] = None
|
||||
|
||||
|
||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||
|
||||
|
||||
class Databricks(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
workspace_url: str,
|
||||
access_token: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
client_secret: Optional[str] = None,
|
||||
azure_client_id: Optional[str] = None,
|
||||
azure_client_secret: Optional[str] = None,
|
||||
endpoint_name: str = None,
|
||||
catalog: str = None,
|
||||
schema: str = None,
|
||||
table_name: str = None,
|
||||
collection_name: str = "mem0",
|
||||
index_type: str = "DELTA_SYNC",
|
||||
embedding_model_endpoint_name: Optional[str] = None,
|
||||
embedding_dimension: int = 1536,
|
||||
endpoint_type: str = "STANDARD",
|
||||
pipeline_type: str = "TRIGGERED",
|
||||
warehouse_name: Optional[str] = None,
|
||||
query_type: str = "ANN",
|
||||
):
|
||||
"""
|
||||
Initialize the Databricks Vector Search vector store.
|
||||
|
||||
Args:
|
||||
workspace_url (str): Databricks workspace URL.
|
||||
access_token (str, optional): Personal access token for authentication.
|
||||
client_id (str, optional): Service principal client ID for authentication.
|
||||
client_secret (str, optional): Service principal client secret for authentication.
|
||||
azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
|
||||
azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
|
||||
endpoint_name (str): Vector search endpoint name.
|
||||
catalog (str): Unity Catalog catalog name.
|
||||
schema (str): Unity Catalog schema name.
|
||||
table_name (str): Source Delta table name.
|
||||
collection_name (str, optional): Vector search index name (default: "mem0").
|
||||
index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
|
||||
embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
|
||||
embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
|
||||
endpoint_type (str, optional): Endpoint type, either "STANDARD" or "STORAGE_OPTIMIZED" (default: "STANDARD").
|
||||
pipeline_type (str, optional): Sync pipeline type, either "TRIGGERED" or "CONTINUOUS" (default: "TRIGGERED").
|
||||
warehouse_name (str, optional): Databricks SQL warehouse Name (if using SQL warehouse).
|
||||
query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").
|
||||
"""
|
||||
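# Illustrative usage sketch (not part of the original file); the workspace URL, endpoint
# and Unity Catalog names below are placeholders, not values from this repository:
#
#     store = Databricks(
#         workspace_url="https://adb-1234567890123456.7.azuredatabricks.net",
#         access_token="dapi...",
#         endpoint_name="mem0-endpoint",
#         catalog="main",
#         schema="memories",
#         table_name="mem0_source",
#         embedding_model_endpoint_name="databricks-gte-large-en",
#         warehouse_name="Serverless Starter Warehouse",
#     )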
# Basic identifiers
|
||||
self.workspace_url = workspace_url
|
||||
self.endpoint_name = endpoint_name
|
||||
self.catalog = catalog
|
||||
self.schema = schema
|
||||
self.table_name = table_name
|
||||
self.fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
|
||||
self.index_name = collection_name
|
||||
self.fully_qualified_index_name = f"{self.catalog}.{self.schema}.{self.index_name}"
|
||||
|
||||
# Configuration
|
||||
self.index_type = index_type
|
||||
self.embedding_model_endpoint_name = embedding_model_endpoint_name
|
||||
self.embedding_dimension = embedding_dimension
|
||||
self.endpoint_type = endpoint_type
|
||||
self.pipeline_type = pipeline_type
|
||||
self.query_type = query_type
|
||||
|
||||
# Schema
|
||||
self.columns = [
|
||||
ColumnInfo(
|
||||
name="memory_id",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
nullable=False,
|
||||
comment="Primary key",
|
||||
position=0,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="hash",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="Hash of the memory content",
|
||||
position=1,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="agent_id",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="ID of the agent",
|
||||
position=2,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="run_id",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="ID of the run",
|
||||
position=3,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="user_id",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="ID of the user",
|
||||
position=4,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="memory",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="Memory content",
|
||||
position=5,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="metadata",
|
||||
type_name=ColumnTypeName.STRING,
|
||||
type_text="string",
|
||||
type_json='{"type":"string"}',
|
||||
comment="Additional metadata",
|
||||
position=6,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="created_at",
|
||||
type_name=ColumnTypeName.TIMESTAMP,
|
||||
type_text="timestamp",
|
||||
type_json='{"type":"timestamp"}',
|
||||
comment="Creation timestamp",
|
||||
position=7,
|
||||
),
|
||||
ColumnInfo(
|
||||
name="updated_at",
|
||||
type_name=ColumnTypeName.TIMESTAMP,
|
||||
type_text="timestamp",
|
||||
type_json='{"type":"timestamp"}',
|
||||
comment="Last update timestamp",
|
||||
position=8,
|
||||
),
|
||||
]
|
||||
if self.index_type == VectorIndexType.DIRECT_ACCESS:
|
||||
self.columns.append(
|
||||
ColumnInfo(
|
||||
name="embedding",
|
||||
type_name=ColumnTypeName.ARRAY,
|
||||
type_text="array<float>",
|
||||
type_json='{"type":"array","element":"float","element_nullable":false}',
|
||||
nullable=True,
|
||||
comment="Embedding vector",
|
||||
position=9,
|
||||
)
|
||||
)
|
||||
self.column_names = [col.name for col in self.columns]
|
||||
|
||||
# Initialize Databricks workspace client
|
||||
client_config = {}
|
||||
if client_id and client_secret:
|
||||
client_config.update(
|
||||
{
|
||||
"host": workspace_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
)
|
||||
elif azure_client_id and azure_client_secret:
|
||||
client_config.update(
|
||||
{
|
||||
"host": workspace_url,
|
||||
"azure_client_id": azure_client_id,
|
||||
"azure_client_secret": azure_client_secret,
|
||||
}
|
||||
)
|
||||
elif access_token:
|
||||
client_config.update({"host": workspace_url, "token": access_token})
|
||||
else:
|
||||
# Try automatic authentication
|
||||
client_config["host"] = workspace_url
|
||||
|
||||
try:
|
||||
self.client = WorkspaceClient(**client_config)
|
||||
logger.info("Initialized Databricks workspace client")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Databricks workspace client: {e}")
|
||||
raise
|
||||
|
||||
# Get the warehouse ID by name
|
||||
self.warehouse_id = next((w.id for w in self.client.warehouses.list() if w.name == warehouse_name), None)
|
||||
|
||||
# Initialize endpoint (required in Databricks)
|
||||
self._ensure_endpoint_exists()
|
||||
|
||||
# Check if index exists and create if needed
|
||||
collections = self.list_cols()
|
||||
if self.fully_qualified_index_name not in collections:
|
||||
self.create_col()
|
||||
|
||||
def _ensure_endpoint_exists(self):
|
||||
"""Ensure the vector search endpoint exists, create if it doesn't."""
|
||||
try:
|
||||
self.client.vector_search_endpoints.get_endpoint(endpoint_name=self.endpoint_name)
|
||||
logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
|
||||
except Exception:
|
||||
# Endpoint doesn't exist, create it
|
||||
try:
|
||||
logger.info(f"Creating vector search endpoint '{self.endpoint_name}' with type '{self.endpoint_type}'")
|
||||
self.client.vector_search_endpoints.create_endpoint_and_wait(
|
||||
name=self.endpoint_name, endpoint_type=self.endpoint_type
|
||||
)
|
||||
logger.info(f"Successfully created vector search endpoint '{self.endpoint_name}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create vector search endpoint '{self.endpoint_name}': {e}")
|
||||
raise
|
||||
|
||||
def _ensure_source_table_exists(self):
|
||||
"""Ensure the source Delta table exists with the proper schema."""
|
||||
check = self.client.tables.exists(self.fully_qualified_table_name)
|
||||
|
||||
if check.table_exists:
|
||||
logger.info(f"Source table '{self.fully_qualified_table_name}' already exists")
|
||||
else:
|
||||
logger.info(f"Source table '{self.fully_qualified_table_name}' does not exist, creating it...")
|
||||
self.client.tables.create(
|
||||
name=self.table_name,
|
||||
catalog_name=self.catalog,
|
||||
schema_name=self.schema,
|
||||
table_type=TableType.MANAGED,
|
||||
data_source_format=DataSourceFormat.DELTA,
|
||||
storage_location=None, # Use default storage location
|
||||
columns=self.columns,
|
||||
properties={"delta.enableChangeDataFeed": "true"},
|
||||
)
|
||||
logger.info(f"Successfully created source table '{self.fully_qualified_table_name}'")
|
||||
self.client.table_constraints.create(
|
||||
full_name_arg=self.fully_qualified_table_name,
|
||||
constraint=TableConstraint(
|
||||
primary_key_constraint=PrimaryKeyConstraint(
|
||||
name="pk_dev_memory", # Name of the primary key constraint
|
||||
child_columns=["memory_id"], # Columns that make up the primary key
|
||||
)
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
|
||||
)
|
||||
|
||||
def create_col(self, name=None, vector_size=None, distance=None):
|
||||
"""
|
||||
Create a new collection (index).
|
||||
|
||||
Args:
|
||||
name (str, optional): Index name. If provided, will create a new index using the provided source_table_name.
|
||||
vector_size (int, optional): Vector dimension size.
|
||||
distance (str, optional): Distance metric (not directly applicable for Databricks).
|
||||
|
||||
Returns:
|
||||
The index object.
|
||||
"""
|
||||
# Determine index configuration
|
||||
embedding_dims = vector_size or self.embedding_dimension
|
||||
embedding_source_columns = [
|
||||
EmbeddingSourceColumn(
|
||||
name="memory",
|
||||
embedding_model_endpoint_name=self.embedding_model_endpoint_name,
|
||||
)
|
||||
]
|
||||
|
||||
logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")
|
||||
|
||||
# First, ensure the source Delta table exists
|
||||
self._ensure_source_table_exists()
|
||||
|
||||
if self.index_type not in [VectorIndexType.DELTA_SYNC, VectorIndexType.DIRECT_ACCESS]:
|
||||
raise ValueError("index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'")
|
||||
|
||||
try:
|
||||
if self.index_type == VectorIndexType.DELTA_SYNC:
|
||||
index = self.client.vector_search_indexes.create_index(
|
||||
name=self.fully_qualified_index_name,
|
||||
endpoint_name=self.endpoint_name,
|
||||
primary_key="memory_id",
|
||||
index_type=self.index_type,
|
||||
delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
|
||||
source_table=self.fully_qualified_table_name,
|
||||
pipeline_type=self.pipeline_type,
|
||||
columns_to_sync=self.column_names,
|
||||
embedding_source_columns=embedding_source_columns,
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DELTA_SYNC type"
|
||||
)
|
||||
return index
|
||||
|
||||
elif self.index_type == VectorIndexType.DIRECT_ACCESS:
|
||||
index = self.client.vector_search_indexes.create_index(
|
||||
name=self.fully_qualified_index_name,
|
||||
endpoint_name=self.endpoint_name,
|
||||
primary_key="memory_id",
|
||||
index_type=self.index_type,
|
||||
direct_access_index_spec=DirectAccessVectorIndexSpec(
|
||||
embedding_source_columns=embedding_source_columns,
|
||||
embedding_vector_columns=[
|
||||
EmbeddingVectorColumn(name="embedding", embedding_dimension=embedding_dims)
|
||||
],
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DIRECT_ACCESS type"
|
||||
)
|
||||
return index
|
||||
except Exception as e:
|
||||
logger.error(f"Error making index_type: {self.index_type} for index {self.fully_qualified_index_name}: {e}")
|
||||
|
||||
def _format_sql_value(self, v):
|
||||
"""
|
||||
Format a Python value into a safe SQL literal for Databricks.
|
||||
"""
|
||||
if v is None:
|
||||
return "NULL"
|
||||
if isinstance(v, bool):
|
||||
return "TRUE" if v else "FALSE"
|
||||
if isinstance(v, (int, float)):
|
||||
return str(v)
|
||||
if isinstance(v, (datetime, date)):
|
||||
return f"'{v.isoformat()}'"
|
||||
if isinstance(v, list):
|
||||
# Render arrays (assume numeric or string elements)
|
||||
elems = []
|
||||
for x in v:
|
||||
if x is None:
|
||||
elems.append("NULL")
|
||||
elif isinstance(x, (int, float)):
|
||||
elems.append(str(x))
|
||||
else:
|
||||
s = str(x).replace("'", "''")
|
||||
elems.append(f"'{s}'")
|
||||
return f"array({', '.join(elems)})"
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
s = json.dumps(v)
|
||||
except Exception:
|
||||
s = str(v)
|
||||
s = s.replace("'", "''")
|
||||
return f"'{s}'"
|
||||
# Fallback: treat as string
|
||||
s = str(v).replace("'", "''")
|
||||
return f"'{s}'"
|
||||
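# Illustrative examples (not part of the original file) of the SQL literals this helper
# produces when building the INSERT/UPDATE statements below:
#
#     self._format_sql_value(None)           # -> NULL
#     self._format_sql_value(True)           # -> TRUE
#     self._format_sql_value("it's")         # -> 'it''s'
#     self._format_sql_value([1.0, 2.0])     # -> array(1.0, 2.0)
#     self._format_sql_value({"k": "v"})     # -> '{"k": "v"}'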
|
||||
def insert(self, vectors: list, payloads: list = None, ids: list = None):
|
||||
"""
|
||||
Insert vectors into the index.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert.
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors.
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
# Determine the number of items to process
|
||||
num_items = len(payloads) if payloads else len(vectors) if vectors else 0
|
||||
|
||||
value_tuples = []
|
||||
for i in range(num_items):
|
||||
values = []
|
||||
for col in self.columns:
|
||||
if col.name == "memory_id":
|
||||
val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
|
||||
elif col.name == "embedding":
|
||||
val = vectors[i] if vectors and i < len(vectors) else []
|
||||
elif col.name == "memory":
|
||||
val = payloads[i].get("data") if payloads and i < len(payloads) else None
|
||||
else:
|
||||
val = payloads[i].get(col.name) if payloads and i < len(payloads) else None
|
||||
values.append(val)
|
||||
formatted = [self._format_sql_value(v) for v in values]
|
||||
value_tuples.append(f"({', '.join(formatted)})")
|
||||
|
||||
insert_sql = f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) VALUES {', '.join(value_tuples)}"
|
||||
|
||||
# Execute the insert
|
||||
try:
|
||||
response = self.client.statement_execution.execute_statement(
|
||||
statement=insert_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
|
||||
)
|
||||
if response.status.state.value == "SUCCEEDED":
|
||||
logger.info(
|
||||
f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}"
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.error(f"Failed to insert items: {response.status.error}")
|
||||
raise Exception(f"Insert operation failed: {response.status.error}")
|
||||
except Exception as e:
|
||||
logger.error(f"Insert operation failed: {e}")
|
||||
raise
|
||||
|
||||
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> List[MemoryResult]:
|
||||
"""
|
||||
Search for similar vectors or text using the Databricks Vector Search index.
|
||||
|
||||
Args:
|
||||
query (str): Search query text (for text-based search).
|
||||
vectors (list): Query vector (for vector-based search).
|
||||
limit (int): Maximum number of results.
|
||||
filters (dict): Filters to apply.
|
||||
|
||||
Returns:
|
||||
List of MemoryResult objects.
|
||||
"""
|
||||
try:
|
||||
filters_json = json.dumps(filters) if filters else None
|
||||
|
||||
# Choose query type
|
||||
if self.index_type == VectorIndexType.DELTA_SYNC and query:
|
||||
# Text-based search
|
||||
sdk_results = self.client.vector_search_indexes.query_index(
|
||||
index_name=self.fully_qualified_index_name,
|
||||
columns=self.column_names,
|
||||
query_text=query,
|
||||
num_results=limit,
|
||||
query_type=self.query_type,
|
||||
filters_json=filters_json,
|
||||
)
|
||||
elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
|
||||
# Vector-based search
|
||||
sdk_results = self.client.vector_search_indexes.query_index(
|
||||
index_name=self.fully_qualified_index_name,
|
||||
columns=self.column_names,
|
||||
query_vector=vectors,
|
||||
num_results=limit,
|
||||
query_type=self.query_type,
|
||||
filters_json=filters_json,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS.")
|
||||
|
||||
# Parse results
|
||||
result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
|
||||
data_array = result_data.data_array if getattr(result_data, "data_array", None) else []
|
||||
|
||||
memory_results = []
|
||||
for row in data_array:
|
||||
# Map columns to values
|
||||
row_dict = dict(zip(self.column_names, row)) if isinstance(row, (list, tuple)) else row
|
||||
score = row_dict.get("score") or (
|
||||
row[-1] if isinstance(row, (list, tuple)) and len(row) > len(self.column_names) else None
|
||||
)
|
||||
payload = {k: row_dict.get(k) for k in self.column_names}
|
||||
payload["data"] = payload.get("memory", "")
|
||||
memory_id = row_dict.get("memory_id") or row_dict.get("id")
|
||||
memory_results.append(MemoryResult(id=memory_id, score=score, payload=payload))
|
||||
return memory_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise
|
||||
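# Illustrative usage sketch (not from the original file); the query text, vector and
# filter values are assumptions. A DELTA_SYNC index is queried by text, a DIRECT_ACCESS
# index by an embedding of the configured dimension:
#
#     store.search(query="team standup notes", vectors=None, limit=3,
#                  filters={"user_id": "u1"})                 # DELTA_SYNC
#     store.search(query=None, vectors=[0.1] * 1536, limit=3)  # DIRECT_ACCESS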
|
||||
def delete(self, vector_id):
|
||||
"""
|
||||
Delete a vector by ID from the Delta table.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}")
|
||||
|
||||
delete_sql = f"DELETE FROM {self.fully_qualified_table_name} WHERE memory_id = '{vector_id}'"
|
||||
|
||||
response = self.client.statement_execution.execute_statement(
|
||||
statement=delete_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
|
||||
)
|
||||
|
||||
if response.status.state.value == "SUCCEEDED":
|
||||
logger.info(f"Successfully deleted vector with ID {vector_id}")
|
||||
else:
|
||||
logger.error(f"Failed to delete vector with ID {vector_id}: {response.status.error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def update(self, vector_id=None, vector=None, payload=None):
|
||||
"""
|
||||
Update a vector and its payload in the Delta table.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (list, optional): New vector values.
|
||||
payload (dict, optional): New payload data.
|
||||
"""
|
||||
|
||||
update_sql = f"UPDATE {self.fully_qualified_table_name} SET "
|
||||
set_clauses = []
|
||||
if not vector_id:
|
||||
logger.error("vector_id is required for update operation")
|
||||
return
|
||||
if vector is not None:
|
||||
if not isinstance(vector, list):
|
||||
logger.error("vector must be a list of float values")
|
||||
return
|
||||
set_clauses.append(f"embedding = {vector}")
|
||||
if payload:
|
||||
if not isinstance(payload, dict):
|
||||
logger.error("payload must be a dictionary")
|
||||
return
|
||||
for key, value in payload.items():
|
||||
if key not in excluded_keys:
|
||||
set_clauses.append(f"{key} = '{value}'")
|
||||
|
||||
if not set_clauses:
|
||||
logger.error("No fields to update")
|
||||
return
|
||||
update_sql += ", ".join(set_clauses)
|
||||
update_sql += f" WHERE memory_id = '{vector_id}'"
|
||||
try:
|
||||
logger.info(f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}")
|
||||
|
||||
response = self.client.statement_execution.execute_statement(
|
||||
statement=update_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
|
||||
)
|
||||
|
||||
if response.status.state.value == "SUCCEEDED":
|
||||
logger.info(f"Successfully updated vector with ID {vector_id}")
|
||||
else:
|
||||
logger.error(f"Failed to update vector with ID {vector_id}: {response.status.error}")
|
||||
except Exception as e:
|
||||
logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def get(self, vector_id) -> MemoryResult:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
MemoryResult: The retrieved vector.
|
||||
"""
|
||||
try:
|
||||
# Use query with ID filter to retrieve the specific vector
|
||||
filters = {"memory_id": vector_id}
|
||||
filters_json = json.dumps(filters)
|
||||
|
||||
results = self.client.vector_search_indexes.query_index(
|
||||
index_name=self.fully_qualified_index_name,
|
||||
columns=self.column_names,
|
||||
query_text=" ", # Empty query, rely on filters
|
||||
num_results=1,
|
||||
query_type=self.query_type,
|
||||
filters_json=filters_json,
|
||||
)
|
||||
|
||||
# Process results
|
||||
result_data = results.result if hasattr(results, "result") else results
|
||||
data_array = result_data.data_array if hasattr(result_data, "data_array") else []
|
||||
|
||||
if not data_array:
|
||||
raise KeyError(f"Vector with ID {vector_id} not found")
|
||||
|
||||
result = data_array[0]
|
||||
row_data = result if isinstance(result, dict) else result.__dict__
|
||||
|
||||
# Build payload following the standard schema
|
||||
payload = {
|
||||
"hash": row_data.get("hash", "unknown"),
|
||||
"data": row_data.get("memory", row_data.get("data", "unknown")),
|
||||
"created_at": row_data.get("created_at"),
|
||||
}
|
||||
|
||||
# Add updated_at if available
|
||||
if "updated_at" in row_data:
|
||||
payload["updated_at"] = row_data.get("updated_at")
|
||||
|
||||
# Add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in row_data:
|
||||
payload[field] = row_data[field]
|
||||
|
||||
# Add metadata
|
||||
if "metadata" in row_data:
|
||||
try:
|
||||
metadata = json.loads(extract_json(row_data["metadata"]))
|
||||
payload.update(metadata)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"Failed to parse metadata: {row_data.get('metadata')}")
|
||||
|
||||
memory_id = row_data.get("memory_id", vector_id)
|
||||
return MemoryResult(id=memory_id, payload=payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get vector with ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections (indexes).
|
||||
|
||||
Returns:
|
||||
List of index names.
|
||||
"""
|
||||
try:
|
||||
indexes = self.client.vector_search_indexes.list_indexes(endpoint_name=self.endpoint_name)
|
||||
return [idx.name for idx in indexes]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list collections: {e}")
|
||||
raise
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete the current collection (index).
|
||||
"""
|
||||
try:
|
||||
# Try fully qualified first
|
||||
try:
|
||||
self.client.vector_search_indexes.delete_index(index_name=self.fully_qualified_index_name)
|
||||
logger.info(f"Successfully deleted index '{self.fully_qualified_index_name}'")
|
||||
except Exception:
|
||||
self.client.vector_search_indexes.delete_index(index_name=self.index_name)
|
||||
logger.info(f"Successfully deleted index '{self.index_name}' (short name)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete index '{self.index_name}': {e}")
|
||||
raise
|
||||
|
||||
def col_info(self, name=None):
|
||||
"""
|
||||
Get information about a collection (index).
|
||||
|
||||
Args:
|
||||
name (str, optional): Index name. Defaults to current index.
|
||||
|
||||
Returns:
|
||||
Dict: Index information.
|
||||
"""
|
||||
try:
|
||||
index_name = name or self.index_name
|
||||
index = self.client.vector_search_indexes.get_index(index_name=index_name)
|
||||
return {"name": index.name, "fields": self.columns}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get info for index '{name or self.index_name}': {e}")
|
||||
raise
|
||||
|
||||
def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
|
||||
"""
|
||||
List recently created memories from the vector store.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply.
|
||||
limit (int, optional): Maximum number of results.
|
||||
|
||||
Returns:
|
||||
List containing list of MemoryResult objects.
|
||||
"""
|
||||
try:
|
||||
filters_json = json.dumps(filters) if filters else None
|
||||
num_results = limit or 100
|
||||
columns = self.column_names
|
||||
sdk_results = self.client.vector_search_indexes.query_index(
|
||||
index_name=self.fully_qualified_index_name,
|
||||
columns=columns,
|
||||
query_text=" ",
|
||||
num_results=num_results,
|
||||
query_type=self.query_type,
|
||||
filters_json=filters_json,
|
||||
)
|
||||
result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
|
||||
data_array = result_data.data_array if hasattr(result_data, "data_array") else []
|
||||
|
||||
memory_results = []
|
||||
for row in data_array:
|
||||
row_dict = dict(zip(columns, row)) if isinstance(row, (list, tuple)) else row
|
||||
payload = {k: row_dict.get(k) for k in columns}
|
||||
# Parse metadata if present
|
||||
if "metadata" in payload and payload["metadata"]:
|
||||
try:
|
||||
payload.update(json.loads(payload["metadata"]))
|
||||
except Exception:
|
||||
pass
|
||||
memory_id = row_dict.get("memory_id") or row_dict.get("id")
|
||||
memory_results.append(MemoryResult(id=memory_id, payload=payload))
|
||||
return [memory_results]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list memories: {e}")
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
"""Reset the vector search index and underlying source table.
|
||||
|
||||
This will attempt to delete the existing index (both fully qualified and short name forms
|
||||
for robustness), drop the backing Delta table, recreate the table with the expected schema,
|
||||
and finally recreate the index. Use with caution as all existing data will be removed.
|
||||
"""
|
||||
fq_index = self.fully_qualified_index_name
|
||||
logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
|
||||
try:
|
||||
# Try deleting via fully qualified name first
|
||||
try:
|
||||
self.client.vector_search_indexes.delete_index(index_name=fq_index)
|
||||
logger.info(f"Deleted index '{fq_index}'")
|
||||
except Exception as e_fq:
|
||||
logger.debug(f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name...")
|
||||
try:
|
||||
# Fallback to existing helper which may use short name
|
||||
self.delete_col()
|
||||
except Exception as e_short:
|
||||
logger.debug(f"Failed deleting short index name '{self.index_name}': {e_short}")
|
||||
|
||||
# Drop the backing table (if it exists)
|
||||
try:
|
||||
drop_sql = f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}"
|
||||
resp = self.client.statement_execution.execute_statement(
|
||||
statement=drop_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
|
||||
)
|
||||
if getattr(getattr(resp.status, "state", None), "value", None) == "SUCCEEDED":
|
||||
logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Attempted to drop table '{self.fully_qualified_table_name}' but state was {getattr(resp.status, 'state', 'UNKNOWN')}: {getattr(resp.status, 'error', None)}"
|
||||
)
|
||||
except Exception as e_drop:
|
||||
logger.warning(f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}")
|
||||
|
||||
# Recreate table & index
|
||||
self._ensure_source_table_exists()
|
||||
self.create_col()
|
||||
logger.info(f"Successfully reset index '{fq_index}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting index '{fq_index}': {e}")
|
||||
raise
|
||||
237
neomem/neomem/vector_stores/elasticsearch.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.configs.vector_stores.elasticsearch import ElasticsearchConfig
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
payload: Dict
|
||||
|
||||
|
||||
class ElasticsearchDB(VectorStoreBase):
|
||||
def __init__(self, **kwargs):
|
||||
config = ElasticsearchConfig(**kwargs)
|
||||
|
||||
# Initialize Elasticsearch client
|
||||
if config.cloud_id:
|
||||
self.client = Elasticsearch(
|
||||
cloud_id=config.cloud_id,
|
||||
api_key=config.api_key,
|
||||
verify_certs=config.verify_certs,
|
||||
headers=config.headers or {},
|
||||
)
|
||||
else:
|
||||
self.client = Elasticsearch(
|
||||
hosts=[f"{config.host}" if config.port is None else f"{config.host}:{config.port}"],
|
||||
basic_auth=(config.user, config.password) if (config.user and config.password) else None,
|
||||
verify_certs=config.verify_certs,
|
||||
headers=config.headers or {},
|
||||
)
|
||||
|
||||
self.collection_name = config.collection_name
|
||||
self.embedding_model_dims = config.embedding_model_dims
|
||||
|
||||
# Create index only if auto_create_index is True
|
||||
if config.auto_create_index:
|
||||
self.create_index()
|
||||
|
||||
if config.custom_search_query:
|
||||
self.custom_search_query = config.custom_search_query
|
||||
else:
|
||||
self.custom_search_query = None
|
||||
|
||||
def create_index(self) -> None:
|
||||
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
|
||||
index_settings = {
|
||||
"settings": {"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s"}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": self.embedding_model_dims,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if not self.client.indices.exists(index=self.collection_name):
|
||||
self.client.indices.create(index=self.collection_name, body=index_settings)
|
||||
logger.info(f"Created index {self.collection_name}")
|
||||
else:
|
||||
logger.info(f"Index {self.collection_name} already exists")
|
||||
|
||||
def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None:
|
||||
"""Create a new collection (index in Elasticsearch)."""
|
||||
index_settings = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"vector": {"type": "dense_vector", "dims": vector_size, "index": True, "similarity": "cosine"},
|
||||
"payload": {"type": "object"},
|
||||
"id": {"type": "keyword"},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if not self.client.indices.exists(index=name):
|
||||
self.client.indices.create(index=name, body=index_settings)
|
||||
logger.info(f"Created index {name}")
|
||||
|
||||
def insert(
|
||||
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
|
||||
) -> List[OutputData]:
|
||||
"""Insert vectors into the index."""
|
||||
if not ids:
|
||||
ids = [str(i) for i in range(len(vectors))]
|
||||
|
||||
if payloads is None:
|
||||
payloads = [{} for _ in range(len(vectors))]
|
||||
|
||||
actions = []
|
||||
for i, (vec, id_) in enumerate(zip(vectors, ids)):
|
||||
action = {
|
||||
"_index": self.collection_name,
|
||||
"_id": id_,
|
||||
"_source": {
|
||||
"vector": vec,
|
||||
"metadata": payloads[i], # Store all metadata in the metadata field
|
||||
},
|
||||
}
|
||||
actions.append(action)
|
||||
|
||||
bulk(self.client, actions)
|
||||
|
||||
results = []
|
||||
for i, id_ in enumerate(ids):
|
||||
results.append(
|
||||
OutputData(
|
||||
id=id_,
|
||||
score=1.0, # Default score for inserts
|
||||
payload=payloads[i],
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search with two options:
|
||||
1. Use custom search query if provided
|
||||
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
|
||||
"""
|
||||
if self.custom_search_query:
|
||||
search_query = self.custom_search_query(vectors, limit, filters)
|
||||
else:
|
||||
search_query = {
|
||||
"knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2}
|
||||
}
|
||||
if filters:
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
||||
search_query["knn"]["filter"] = {"bool": {"must": filter_conditions}}
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=search_query)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
results.append(
|
||||
OutputData(id=hit["_id"], score=hit["_score"], payload=hit.get("_source", {}).get("metadata", {}))
|
||||
)
|
||||
|
||||
return results
|
||||
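# Illustrative sketch (not part of the original file) of a custom_search_query callable
# as consumed above: it is passed via ElasticsearchConfig, receives (vectors, limit,
# filters) and must return a complete Elasticsearch request body whose hits carry the
# metadata field. The shape below is an assumption, not the repository's definition.
#
#     def my_custom_search_query(vectors, limit, filters):
#         body = {
#             "knn": {
#                 "field": "vector",
#                 "query_vector": vectors,
#                 "k": limit,
#                 "num_candidates": max(limit * 10, 100),
#             }
#         }
#         if filters:
#             body["knn"]["filter"] = {
#                 "bool": {"must": [{"term": {f"metadata.{k}": v}} for k, v in filters.items()]}
#             }
#         return body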
|
||||
def delete(self, vector_id: str) -> None:
|
||||
"""Delete a vector by ID."""
|
||||
self.client.delete(index=self.collection_name, id=vector_id)
|
||||
|
||||
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
|
||||
"""Update a vector and its payload."""
|
||||
doc = {}
|
||||
if vector is not None:
|
||||
doc["vector"] = vector
|
||||
if payload is not None:
|
||||
doc["metadata"] = payload
|
||||
|
||||
self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
|
||||
|
||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||
"""Retrieve a vector by ID."""
|
||||
try:
|
||||
response = self.client.get(index=self.collection_name, id=vector_id)
|
||||
return OutputData(
|
||||
id=response["_id"],
|
||||
score=1.0, # Default score for direct get
|
||||
payload=response["_source"].get("metadata", {}),
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.warning(f"Missing key in Elasticsearch response: {e}")
|
||||
return None
|
||||
except TypeError as e:
|
||||
logger.warning(f"Invalid response type from Elasticsearch: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while parsing Elasticsearch response: {e}")
|
||||
return None
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""List all collections (indices)."""
|
||||
return list(self.client.indices.get_alias().keys())
|
||||
|
||||
def delete_col(self) -> None:
|
||||
"""Delete a collection (index)."""
|
||||
self.client.indices.delete(index=self.collection_name)
|
||||
|
||||
def col_info(self, name: str) -> Any:
|
||||
"""Get information about a collection (index)."""
|
||||
return self.client.indices.get(index=name)
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
|
||||
"""List all memories."""
|
||||
query: Dict[str, Any] = {"query": {"match_all": {}}}
|
||||
|
||||
if filters:
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"term": {f"metadata.{key}": value}})
|
||||
query["query"] = {"bool": {"must": filter_conditions}}
|
||||
|
||||
if limit:
|
||||
query["size"] = limit
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=query)
|
||||
|
||||
results = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
results.append(
|
||||
OutputData(
|
||||
id=hit["_id"],
|
||||
score=1.0, # Default score for list operation
|
||||
payload=hit.get("_source", {}).get("metadata", {}),
|
||||
)
|
||||
)
|
||||
|
||||
return [results]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_index()
|
||||
479
neomem/neomem/vector_stores/faiss.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
import warnings
|
||||
|
||||
try:
|
||||
# Suppress SWIG deprecation warnings from FAISS
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*swigvarlink.*")
|
||||
|
||||
logging.getLogger("faiss").setLevel(logging.WARNING)
|
||||
logging.getLogger("faiss.loader").setLevel(logging.WARNING)
|
||||
|
||||
import faiss
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import faiss python package. "
|
||||
"Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
|
||||
"or `pip install faiss-cpu` (depending on Python version)."
|
||||
)
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class FAISS(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
path: Optional[str] = None,
|
||||
distance_strategy: str = "euclidean",
|
||||
normalize_L2: bool = False,
|
||||
embedding_model_dims: int = 1536,
|
||||
):
|
||||
"""
|
||||
Initialize the FAISS vector store.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection.
|
||||
path (str, optional): Path for local FAISS database. Defaults to None.
|
||||
distance_strategy (str, optional): Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'.
|
||||
Defaults to "euclidean".
|
||||
normalize_L2 (bool, optional): Whether to normalize L2 vectors. Only applicable for euclidean distance.
|
||||
Defaults to False.
embedding_model_dims (int, optional): Dimension of the embedding vectors. Defaults to 1536.
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.path = path or f"/tmp/faiss/{collection_name}"
|
||||
self.distance_strategy = distance_strategy
|
||||
self.normalize_L2 = normalize_L2
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
|
||||
# Initialize storage structures
|
||||
self.index = None
|
||||
self.docstore = {}
|
||||
self.index_to_id = {}
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
if self.path:
|
||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||
|
||||
# Try to load existing index if available
|
||||
index_path = f"{self.path}/{collection_name}.faiss"
|
||||
docstore_path = f"{self.path}/{collection_name}.pkl"
|
||||
if os.path.exists(index_path) and os.path.exists(docstore_path):
|
||||
self._load(index_path, docstore_path)
|
||||
else:
|
||||
self.create_col(collection_name)
|
||||
|
||||
def _load(self, index_path: str, docstore_path: str):
|
||||
"""
|
||||
Load FAISS index and docstore from disk.
|
||||
|
||||
Args:
|
||||
index_path (str): Path to FAISS index file.
|
||||
docstore_path (str): Path to docstore pickle file.
|
||||
"""
|
||||
try:
|
||||
self.index = faiss.read_index(index_path)
|
||||
with open(docstore_path, "rb") as f:
|
||||
self.docstore, self.index_to_id = pickle.load(f)
|
||||
logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load FAISS index: {e}")
|
||||
|
||||
self.docstore = {}
|
||||
self.index_to_id = {}
|
||||
|
||||
def _save(self):
|
||||
"""Save FAISS index and docstore to disk."""
|
||||
if not self.path or not self.index:
|
||||
return
|
||||
|
||||
try:
|
||||
os.makedirs(self.path, exist_ok=True)
|
||||
index_path = f"{self.path}/{self.collection_name}.faiss"
|
||||
docstore_path = f"{self.path}/{self.collection_name}.pkl"
|
||||
|
||||
faiss.write_index(self.index, index_path)
|
||||
with open(docstore_path, "wb") as f:
|
||||
pickle.dump((self.docstore, self.index_to_id), f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save FAISS index: {e}")
|
||||
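# Illustrative note (not part of the original file): with path left as None and a
# collection named "mem0" (the name used elsewhere in this repo), the two persistence
# files written above end up as sibling files under self.path, e.g.:
#
#     /tmp/faiss/mem0/mem0.faiss   # the FAISS index itself
#     /tmp/faiss/mem0/mem0.pkl     # pickled (docstore, index_to_id) mapping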
|
||||
def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
|
||||
"""
|
||||
Parse the output data.
|
||||
|
||||
Args:
|
||||
scores: Similarity scores from FAISS.
|
||||
ids: Indices from FAISS.
|
||||
limit: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
if limit is None:
|
||||
limit = len(ids)
|
||||
|
||||
results = []
|
||||
for i in range(min(len(ids), limit)):
|
||||
if ids[i] == -1: # FAISS returns -1 for empty results
|
||||
continue
|
||||
|
||||
index_id = int(ids[i])
|
||||
vector_id = self.index_to_id.get(index_id)
|
||||
if vector_id is None:
|
||||
continue
|
||||
|
||||
payload = self.docstore.get(vector_id)
|
||||
if payload is None:
|
||||
continue
|
||||
|
||||
payload_copy = payload.copy()
|
||||
|
||||
score = float(scores[i])
|
||||
entry = OutputData(
|
||||
id=vector_id,
|
||||
score=score,
|
||||
payload=payload_copy,
|
||||
)
|
||||
results.append(entry)
|
||||
|
||||
return results
|
||||
|
||||
def create_col(self, name: str, distance: str = None):
|
||||
"""
|
||||
Create a new collection.
|
||||
|
||||
Args:
|
||||
name (str): Name of the collection.
|
||||
distance (str, optional): Distance metric to use. Overrides the distance_strategy
|
||||
passed during initialization. Defaults to None.
|
||||
|
||||
Returns:
|
||||
self: The FAISS instance.
|
||||
"""
|
||||
distance_strategy = distance or self.distance_strategy
|
||||
|
||||
# Create index based on distance strategy
|
||||
if distance_strategy.lower() in ("inner_product", "cosine"):
|
||||
self.index = faiss.IndexFlatIP(self.embedding_model_dims)
|
||||
else:
|
||||
self.index = faiss.IndexFlatL2(self.embedding_model_dims)
|
||||
|
||||
self.collection_name = name
|
||||
|
||||
self._save()
|
||||
|
||||
return self
|
||||
|
||||
def insert(
|
||||
self,
|
||||
vectors: List[list],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Insert vectors into a collection.
|
||||
|
||||
Args:
|
||||
vectors (List[list]): List of vectors to insert.
|
||||
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
|
||||
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
if self.index is None:
|
||||
raise ValueError("Collection not initialized. Call create_col first.")
|
||||
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
|
||||
|
||||
if payloads is None:
|
||||
payloads = [{} for _ in range(len(vectors))]
|
||||
|
||||
if len(vectors) != len(ids) or len(vectors) != len(payloads):
|
||||
raise ValueError("Vectors, payloads, and IDs must have the same length")
|
||||
|
||||
vectors_np = np.array(vectors, dtype=np.float32)
|
||||
|
||||
if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
|
||||
faiss.normalize_L2(vectors_np)
|
||||
|
||||
self.index.add(vectors_np)
|
||||
|
||||
starting_idx = len(self.index_to_id)
|
||||
for i, (vector_id, payload) in enumerate(zip(ids, payloads)):
|
||||
self.docstore[vector_id] = payload.copy()
|
||||
self.index_to_id[starting_idx + i] = vector_id
|
||||
|
||||
self._save()
|
||||
|
||||
logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}")
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query (not used, kept for API compatibility).
|
||||
vectors (List[list]): List of vectors to search.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
if self.index is None:
|
||||
raise ValueError("Collection not initialized. Call create_col first.")
|
||||
|
||||
query_vectors = np.array(vectors, dtype=np.float32)
|
||||
|
||||
if len(query_vectors.shape) == 1:
|
||||
query_vectors = query_vectors.reshape(1, -1)
|
||||
|
||||
if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
|
||||
faiss.normalize_L2(query_vectors)
|
||||
|
||||
fetch_k = limit * 2 if filters else limit
|
||||
scores, indices = self.index.search(query_vectors, fetch_k)
|
||||
|
||||
results = self._parse_output(scores[0], indices[0], limit)
|
||||
|
||||
if filters:
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if self._apply_filters(result.payload, filters):
|
||||
filtered_results.append(result)
|
||||
if len(filtered_results) >= limit:
|
||||
break
|
||||
results = filtered_results[:limit]
|
||||
|
||||
return results
|
||||
|
||||
def _apply_filters(self, payload: Dict, filters: Dict) -> bool:
|
||||
"""
|
||||
Apply filters to a payload.
|
||||
|
||||
Args:
|
||||
payload (Dict): Payload to filter.
|
||||
filters (Dict): Filters to apply.
|
||||
|
||||
Returns:
|
||||
bool: True if payload passes filters, False otherwise.
|
||||
"""
|
||||
if not filters or not payload:
|
||||
return True
|
||||
|
||||
for key, value in filters.items():
|
||||
if key not in payload:
|
||||
return False
|
||||
|
||||
if isinstance(value, list):
|
||||
if payload[key] not in value:
|
||||
return False
|
||||
elif payload[key] != value:
|
||||
return False
|
||||
|
||||
return True
|
||||
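# Illustrative examples (not part of the original file) of the in-memory filter
# semantics above: list values act as an "is one of" match, scalars as exact equality,
# and a missing key fails the filter.
#
#     self._apply_filters({"user_id": "u1", "topic": "a"}, {"user_id": ["u1", "u2"]})  # True
#     self._apply_filters({"user_id": "u1"}, {"topic": "a"})                           # False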
|
||||
def delete(self, vector_id: str):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
if self.index is None:
|
||||
raise ValueError("Collection not initialized. Call create_col first.")
|
||||
|
||||
index_to_delete = None
|
||||
for idx, vid in self.index_to_id.items():
|
||||
if vid == vector_id:
|
||||
index_to_delete = idx
|
||||
break
|
||||
|
||||
if index_to_delete is not None:
|
||||
self.docstore.pop(vector_id, None)
|
||||
self.index_to_id.pop(index_to_delete, None)
|
||||
|
||||
self._save()
|
||||
|
||||
logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}")
|
||||
else:
|
||||
logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}")
|
||||
|
||||
def update(
|
||||
self,
|
||||
vector_id: str,
|
||||
vector: Optional[List[float]] = None,
|
||||
payload: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (Optional[List[float]], optional): Updated vector. Defaults to None.
|
||||
payload (Optional[Dict], optional): Updated payload. Defaults to None.
|
||||
"""
|
||||
if self.index is None:
|
||||
raise ValueError("Collection not initialized. Call create_col first.")
|
||||
|
||||
if vector_id not in self.docstore:
|
||||
raise ValueError(f"Vector {vector_id} not found")
|
||||
|
||||
current_payload = self.docstore[vector_id].copy()
|
||||
|
||||
if payload is not None:
|
||||
self.docstore[vector_id] = payload.copy()
|
||||
current_payload = self.docstore[vector_id].copy()
|
||||
|
||||
if vector is not None:
|
||||
self.delete(vector_id)
|
||||
self.insert([vector], [current_payload], [vector_id])
|
||||
else:
|
||||
self._save()
|
||||
|
||||
logger.info(f"Updated vector {vector_id} in collection {self.collection_name}")
|
||||
|
||||
def get(self, vector_id: str) -> OutputData:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
OutputData: Retrieved vector.
|
||||
"""
|
||||
if self.index is None:
|
||||
raise ValueError("Collection not initialized. Call create_col first.")
|
||||
|
||||
if vector_id not in self.docstore:
|
||||
return None
|
||||
|
||||
payload = self.docstore[vector_id].copy()
|
||||
|
||||
return OutputData(
|
||||
id=vector_id,
|
||||
score=None,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
if not self.path:
|
||||
return [self.collection_name] if self.index else []
|
||||
|
||||
try:
|
||||
collections = []
|
||||
path = Path(self.path).parent
|
||||
for file in path.glob("*.faiss"):
|
||||
collections.append(file.stem)
|
||||
return collections
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list collections: {e}")
|
||||
return [self.collection_name] if self.index else []
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete a collection.
|
||||
"""
|
||||
if self.path:
|
||||
try:
|
||||
index_path = f"{self.path}/{self.collection_name}.faiss"
|
||||
docstore_path = f"{self.path}/{self.collection_name}.pkl"
|
||||
|
||||
if os.path.exists(index_path):
|
||||
os.remove(index_path)
|
||||
if os.path.exists(docstore_path):
|
||||
os.remove(docstore_path)
|
||||
|
||||
logger.info(f"Deleted collection {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete collection: {e}")
|
||||
|
||||
self.index = None
|
||||
self.docstore = {}
|
||||
self.index_to_id = {}
|
||||
|
||||
def col_info(self) -> Dict:
|
||||
"""
|
||||
Get information about a collection.
|
||||
|
||||
Returns:
|
||||
Dict: Collection information.
|
||||
"""
|
||||
if self.index is None:
|
||||
return {"name": self.collection_name, "count": 0}
|
||||
|
||||
return {
|
||||
"name": self.collection_name,
|
||||
"count": self.index.ntotal,
|
||||
"dimension": self.index.d,
|
||||
"distance": self.distance_strategy,
|
||||
}
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
|
||||
Args:
|
||||
filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors.
|
||||
"""
|
||||
if self.index is None:
|
||||
return []
|
||||
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
for vector_id, payload in self.docstore.items():
|
||||
if filters and not self._apply_filters(payload, filters):
|
||||
continue
|
||||
|
||||
payload_copy = payload.copy()
|
||||
|
||||
results.append(
|
||||
OutputData(
|
||||
id=vector_id,
|
||||
score=None,
|
||||
payload=payload_copy,
|
||||
)
|
||||
)
|
||||
|
||||
count += 1
|
||||
if count >= limit:
|
||||
break
|
||||
|
||||
return [results]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.collection_name)
|
||||
180
neomem/neomem/vector_stores/langchain.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from langchain_community.vectorstores import VectorStore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
|
||||
)
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class Langchain(VectorStoreBase):
|
||||
def __init__(self, client: VectorStore, collection_name: str = "mem0"):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
def _parse_output(self, data: Dict) -> List[OutputData]:
|
||||
"""
|
||||
Parse the output data.
|
||||
|
||||
Args:
|
||||
data (Dict): Output data or list of Document objects.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
# Check if input is a list of Document objects
|
||||
if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
|
||||
result = []
|
||||
for doc in data:
|
||||
entry = OutputData(
|
||||
id=getattr(doc, "id", None),
|
||||
score=None, # Document objects typically don't include scores
|
||||
payload=getattr(doc, "metadata", {}),
|
||||
)
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
# Original format handling
|
||||
keys = ["ids", "distances", "metadatas"]
|
||||
values = []
|
||||
|
||||
for key in keys:
|
||||
value = data.get(key, [])
|
||||
if isinstance(value, list) and value and isinstance(value[0], list):
|
||||
value = value[0]
|
||||
values.append(value)
|
||||
|
||||
ids, distances, metadatas = values
|
||||
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
|
||||
|
||||
result = []
|
||||
for i in range(max_length):
|
||||
entry = OutputData(
|
||||
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
|
||||
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
|
||||
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
|
||||
)
|
||||
result.append(entry)
|
||||
|
||||
return result
|
||||
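# Illustrative sketch (not part of the original file): _parse_output accepts either a
# list of LangChain Document objects or a Chroma-style dict such as
#
#     {"ids": [["id1"]], "distances": [[0.12]], "metadatas": [[{"data": "hi"}]]}
#
# and in both cases returns a flat list of OutputData entries.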
|
||||
def create_col(self, name, vector_size=None, distance=None):
|
||||
self.collection_name = name
|
||||
return self.client
|
||||
|
||||
def insert(
|
||||
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
Insert vectors into the LangChain vectorstore.
|
||||
"""
|
||||
# Check if client has add_embeddings method
|
||||
if hasattr(self.client, "add_embeddings"):
|
||||
# Some LangChain vectorstores have a direct add_embeddings method
|
||||
self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
|
||||
else:
|
||||
# Fallback to add_texts method
|
||||
texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
|
||||
self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)
|
||||
|
||||
def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
|
||||
"""
|
||||
Search for similar vectors in LangChain.
|
||||
"""
|
||||
# For each vector, perform a similarity search
|
||||
if filters:
|
||||
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
|
||||
else:
|
||||
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)
|
||||
|
||||
final_results = self._parse_output(results)
|
||||
return final_results
|
||||
|
||||
def delete(self, vector_id):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
"""
|
||||
self.client.delete(ids=[vector_id])
|
||||
|
||||
def update(self, vector_id, vector=None, payload=None):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
"""
|
||||
self.delete(vector_id)
|
||||
self.insert([vector], [payload], [vector_id])
|
||||
|
||||
def get(self, vector_id):
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
"""
|
||||
docs = self.client.get_by_ids([vector_id])
|
||||
if docs and len(docs) > 0:
|
||||
doc = docs[0]
|
||||
return self._parse_output([doc])[0]
|
||||
return None
|
||||
|
||||
def list_cols(self):
|
||||
"""
|
||||
List all collections.
|
||||
"""
|
||||
# LangChain doesn't have collections
|
||||
return [self.collection_name]
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete a collection.
|
||||
"""
|
||||
logger.warning("Deleting collection")
|
||||
if hasattr(self.client, "delete_collection"):
|
||||
self.client.delete_collection()
|
||||
elif hasattr(self.client, "reset_collection"):
|
||||
self.client.reset_collection()
|
||||
else:
|
||||
self.client.delete(ids=None)
|
||||
|
||||
def col_info(self):
|
||||
"""
|
||||
Get information about a collection.
|
||||
"""
|
||||
return {"name": self.collection_name}
|
||||
|
||||
def list(self, filters=None, limit=None):
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
"""
|
||||
try:
|
||||
if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
|
||||
# Convert mem0 filters to Chroma where clause if needed
|
||||
where_clause = None
|
||||
if filters:
|
||||
# Handle all filters, not just user_id
|
||||
where_clause = filters
|
||||
|
||||
result = self.client._collection.get(where=where_clause, limit=limit)
|
||||
|
||||
# Convert the result to the expected format
|
||||
if result and isinstance(result, dict):
|
||||
return [self._parse_output(result)]
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing vectors from Chroma: {e}")
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting collection: {self.collection_name}")
|
||||
self.delete_col()
|
||||
247
neomem/neomem/vector_stores/milvus.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.configs.vector_stores.milvus import MetricType
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
import pymilvus # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
|
||||
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class MilvusDB(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
token: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
metric_type: MetricType,
|
||||
db_name: str,
|
||||
) -> None:
|
||||
"""Initialize the MilvusDB database.
|
||||
|
||||
Args:
|
||||
url (str): Full URL for Milvus/Zilliz server.
|
||||
token (str): Token/API key for the Zilliz server (defaults to None for a local setup).
|
||||
collection_name (str): Name of the collection (defaults to mem0).
|
||||
embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
|
||||
metric_type (MetricType): Metric type for similarity search (defaults to L2).
|
||||
db_name (str): Name of the database (defaults to "").
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.metric_type = metric_type
|
||||
self.client = MilvusClient(uri=url, token=token, db_name=db_name)
|
||||
self.create_col(
|
||||
collection_name=self.collection_name,
|
||||
vector_size=self.embedding_model_dims,
|
||||
metric_type=self.metric_type,
|
||||
)
|
||||
|
||||
def create_col(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_size: int,
|
||||
metric_type: MetricType = MetricType.COSINE,
|
||||
) -> None:
|
||||
"""Create a new collection with index_type AUTOINDEX.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection (defaults to mem0).
|
||||
vector_size (str): Dimensions of the embedding model (defaults to 1536).
|
||||
metric_type (MetricType, optional): Metric type for similarity search. Defaults to MetricType.COSINE.
|
||||
"""
|
||||
|
||||
if self.client.has_collection(collection_name):
|
||||
logger.info(f"Collection {collection_name} already exists. Skipping creation.")
|
||||
else:
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
|
||||
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name="metadata", dtype=DataType.JSON),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields, enable_dynamic_field=True)
|
||||
|
||||
index = self.client.prepare_index_params(
|
||||
field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
|
||||
)
|
||||
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
|
||||
|
||||
def insert(self, ids, vectors, payloads, **kwargs):
|
||||
"""Insert vectors into a collection.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert.
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors.
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
for idx, embedding, metadata in zip(ids, vectors, payloads):
|
||||
data = {"id": idx, "vectors": embedding, "metadata": metadata}
|
||||
self.client.insert(collection_name=self.collection_name, data=data, **kwargs)
|
||||
|
||||
def _create_filter(self, filters: dict):
|
||||
"""Prepare filters for efficient query.
|
||||
|
||||
Args:
|
||||
filters (dict): filters [user_id, agent_id, run_id]
|
||||
|
||||
Returns:
|
||||
str: Formatted filter string.
|
||||
"""
|
||||
operands = []
|
||||
for key, value in filters.items():
|
||||
if isinstance(value, str):
|
||||
operands.append(f'(metadata["{key}"] == "{value}")')
|
||||
else:
|
||||
operands.append(f'(metadata["{key}"] == {value})')
|
||||
|
||||
return " and ".join(operands)
|
||||
|
||||
def _parse_output(self, data: list):
|
||||
"""
|
||||
Parse the output data.
|
||||
|
||||
Args:
|
||||
data (Dict): Output data.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
memory = []
|
||||
|
||||
for value in data:
|
||||
uid, score, metadata = (
|
||||
value.get("id"),
|
||||
value.get("distance"),
|
||||
value.get("entity", {}).get("metadata"),
|
||||
)
|
||||
|
||||
memory_obj = OutputData(id=uid, score=score, payload=metadata)
|
||||
memory.append(memory_obj)
|
||||
|
||||
return memory
|
||||
|
||||
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: Search results.
|
||||
"""
|
||||
query_filter = self._create_filter(filters) if filters else None
|
||||
hits = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[vectors],
|
||||
limit=limit,
|
||||
filter=query_filter,
|
||||
output_fields=["*"],
|
||||
)
|
||||
result = self._parse_output(data=hits[0])
|
||||
return result
|
||||
|
||||
def delete(self, vector_id):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
self.client.delete(collection_name=self.collection_name, ids=vector_id)
|
||||
|
||||
def update(self, vector_id=None, vector=None, payload=None):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (List[float], optional): Updated vector.
|
||||
payload (Dict, optional): Updated payload.
|
||||
"""
|
||||
schema = {"id": vector_id, "vectors": vector, "metadata": payload}
|
||||
self.client.upsert(collection_name=self.collection_name, data=schema)
|
||||
|
||||
def get(self, vector_id):
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
OutputData: Retrieved vector.
|
||||
"""
|
||||
result = self.client.get(collection_name=self.collection_name, ids=vector_id)
|
||||
output = OutputData(
|
||||
id=result[0].get("id", None),
|
||||
score=None,
|
||||
payload=result[0].get("metadata", None),
|
||||
)
|
||||
return output
|
||||
|
||||
def list_cols(self):
|
||||
"""
|
||||
List all collections.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
return self.client.list_collections()
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete a collection."""
|
||||
return self.client.drop_collection(collection_name=self.collection_name)
|
||||
|
||||
def col_info(self):
|
||||
"""
|
||||
Get information about a collection.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Collection information.
|
||||
"""
|
||||
return self.client.get_collection_stats(collection_name=self.collection_name)
|
||||
|
||||
def list(self, filters: dict = None, limit: int = 100) -> list:
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply to the list.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors.
|
||||
"""
|
||||
query_filter = self._create_filter(filters) if filters else None
|
||||
result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
|
||||
memories = []
|
||||
for data in result:
|
||||
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
|
||||
memories.append(obj)
|
||||
return [memories]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)
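# Minimal usage sketch (assumes a locally running Milvus; the URL, dimensions,
# and vector values below are placeholders, not part of this module):
#
#   from mem0.configs.vector_stores.milvus import MetricType
#   db = MilvusDB(
#       url="http://localhost:19530",
#       token="",
#       collection_name="mem0",
#       embedding_model_dims=1536,
#       metric_type=MetricType.COSINE,
#       db_name="",
#   )
#   db.insert(ids=["m1"], vectors=[[0.1] * 1536], payloads=[{"user_id": "alice"}])
#   hits = db.search(query="", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"})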
|
||||
310
neomem/neomem/vector_stores/mongodb.py
Normal file
@@ -0,0 +1,310 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import PyMongoError
|
||||
from pymongo.operations import SearchIndexModel
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
score: Optional[float]
|
||||
payload: Optional[dict]
|
||||
|
||||
|
||||
class MongoDB(VectorStoreBase):
|
||||
VECTOR_TYPE = "knnVector"
|
||||
SIMILARITY_METRIC = "cosine"
|
||||
|
||||
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
|
||||
"""
|
||||
Initialize the MongoDB vector store with vector search capabilities.
|
||||
|
||||
Args:
|
||||
db_name (str): Database name
|
||||
collection_name (str): Collection name
|
||||
embedding_model_dims (int): Dimension of the embedding vector
|
||||
mongo_uri (str): MongoDB connection URI
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.db_name = db_name
|
||||
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[db_name]
|
||||
self.collection = self.create_col()
|
||||
|
||||
def create_col(self):
|
||||
"""Create new collection with vector search index."""
|
||||
try:
|
||||
database = self.client[self.db_name]
|
||||
collection_names = database.list_collection_names()
|
||||
if self.collection_name not in collection_names:
|
||||
logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
|
||||
collection = database[self.collection_name]
|
||||
# Insert and remove a placeholder document to create the collection
|
||||
collection.insert_one({"_id": 0, "placeholder": True})
|
||||
collection.delete_one({"_id": 0})
|
||||
logger.info(f"Collection '{self.collection_name}' created successfully.")
|
||||
else:
|
||||
collection = database[self.collection_name]
|
||||
|
||||
self.index_name = f"{self.collection_name}_vector_index"
|
||||
found_indexes = list(collection.list_search_indexes(name=self.index_name))
|
||||
if found_indexes:
|
||||
logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
|
||||
else:
|
||||
search_index_model = SearchIndexModel(
|
||||
name=self.index_name,
|
||||
definition={
|
||||
"mappings": {
|
||||
"dynamic": False,
|
||||
"fields": {
|
||||
"embedding": {
|
||||
"type": self.VECTOR_TYPE,
|
||||
"dimensions": self.embedding_model_dims,
|
||||
"similarity": self.SIMILARITY_METRIC,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
collection.create_search_index(search_index_model)
|
||||
logger.info(
|
||||
f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
|
||||
)
|
||||
return collection
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error creating collection and search index: {e}")
|
||||
return None
|
||||
|
||||
def insert(
|
||||
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Insert vectors into the collection.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert.
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors.
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")
|
||||
|
||||
data = []
|
||||
for vector, payload, _id in zip(vectors, payloads or [{}] * len(vectors), ids or [None] * len(vectors)):
|
||||
document = {"_id": _id, "embedding": vector, "payload": payload}
|
||||
data.append(document)
|
||||
try:
|
||||
self.collection.insert_many(data)
|
||||
logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error inserting data: {e}")
|
||||
|
||||
def search(self, query: str, vectors: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors using the vector search index.
|
||||
|
||||
Args:
|
||||
query (str): Query string
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
|
||||
found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
|
||||
if not found_indexes:
|
||||
logger.error(f"Index '{self.index_name}' does not exist.")
|
||||
return []
|
||||
|
||||
results = []
|
||||
try:
|
||||
collection = self.client[self.db_name][self.collection_name]
|
||||
pipeline = [
|
||||
{
|
||||
"$vectorSearch": {
|
||||
"index": self.index_name,
|
||||
"limit": limit,
|
||||
"numCandidates": limit,
|
||||
"queryVector": vectors,
|
||||
"path": "embedding",
|
||||
}
|
||||
},
|
||||
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
||||
{"$project": {"embedding": 0}},
|
||||
]
|
||||
|
||||
# Add filter stage if filters are provided
|
||||
if filters:
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"payload." + key: value})
|
||||
|
||||
if filter_conditions:
|
||||
# Add a $match stage after vector search to apply filters
|
||||
pipeline.insert(1, {"$match": {"$and": filter_conditions}})
|
||||
|
||||
results = list(collection.aggregate(pipeline))
|
||||
logger.info(f"Vector search completed. Found {len(results)} documents.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during vector search for query {query}: {e}")
|
||||
return []
|
||||
|
||||
output = [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]
|
||||
return output
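# Rough shape of the aggregation pipeline built above when called with limit=5
# and filters={"user_id": "alice"} (illustrative, values are placeholders):
#
#   [
#       {"$vectorSearch": {"index": "<collection>_vector_index", "limit": 5,
#                          "numCandidates": 5, "queryVector": [...], "path": "embedding"}},
#       {"$match": {"$and": [{"payload.user_id": "alice"}]}},
#       {"$set": {"score": {"$meta": "vectorSearchScore"}}},
#       {"$project": {"embedding": 0}},
#   ]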
|
||||
|
||||
def delete(self, vector_id: str) -> None:
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
try:
|
||||
result = self.collection.delete_one({"_id": vector_id})
|
||||
if result.deleted_count > 0:
|
||||
logger.info(f"Deleted document with ID '{vector_id}'.")
|
||||
else:
|
||||
logger.warning(f"No document found with ID '{vector_id}' to delete.")
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting document: {e}")
|
||||
|
||||
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (List[float], optional): Updated vector.
|
||||
payload (Dict, optional): Updated payload.
|
||||
"""
|
||||
update_fields = {}
|
||||
if vector is not None:
|
||||
update_fields["embedding"] = vector
|
||||
if payload is not None:
|
||||
update_fields["payload"] = payload
|
||||
|
||||
if update_fields:
|
||||
try:
|
||||
result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
|
||||
if result.matched_count > 0:
|
||||
logger.info(f"Updated document with ID '{vector_id}'.")
|
||||
else:
|
||||
logger.warning(f"No document found with ID '{vector_id}' to update.")
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error updating document: {e}")
|
||||
|
||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[OutputData]: Retrieved vector or None if not found.
|
||||
"""
|
||||
try:
|
||||
doc = self.collection.find_one({"_id": vector_id})
|
||||
if doc:
|
||||
logger.info(f"Retrieved document with ID '{vector_id}'.")
|
||||
return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
|
||||
else:
|
||||
logger.warning(f"Document with ID '{vector_id}' not found.")
|
||||
return None
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error retrieving document: {e}")
|
||||
return None
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections in the database.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
try:
|
||||
collections = self.db.list_collection_names()
|
||||
logger.info(f"Listing collections in database '{self.db_name}': {collections}")
|
||||
return collections
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error listing collections: {e}")
|
||||
return []
|
||||
|
||||
def delete_col(self) -> None:
|
||||
"""Delete the collection."""
|
||||
try:
|
||||
self.collection.drop()
|
||||
logger.info(f"Deleted collection '{self.collection_name}'.")
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting collection: {e}")
|
||||
|
||||
def col_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the collection.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Collection information.
|
||||
"""
|
||||
try:
|
||||
stats = self.db.command("collstats", self.collection_name)
|
||||
info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
|
||||
logger.info(f"Collection info: {info}")
|
||||
return info
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error getting collection info: {e}")
|
||||
return {}
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
"""
|
||||
List vectors in the collection.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply to the list.
|
||||
limit (int, optional): Number of vectors to return.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors.
|
||||
"""
|
||||
try:
|
||||
query = {}
|
||||
if filters:
|
||||
# Apply filters to the payload field
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
filter_conditions.append({"payload." + key: value})
|
||||
if filter_conditions:
|
||||
query = {"$and": filter_conditions}
|
||||
|
||||
cursor = self.collection.find(query).limit(limit)
|
||||
results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
|
||||
logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
|
||||
return results
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error listing documents: {e}")
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.collection = self.create_col()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Close the database connection when the object is deleted."""
|
||||
if hasattr(self, "client"):
|
||||
self.client.close()
|
||||
logger.info("MongoClient connection closed.")
|
||||
467
neomem/neomem/vector_stores/neptune_analytics.py
Normal file
@@ -0,0 +1,467 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from langchain_aws import NeptuneAnalyticsGraph
|
||||
except ImportError:
|
||||
raise ImportError("langchain_aws is not installed. Please install it using pip install langchain_aws")
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class NeptuneAnalyticsVector(VectorStoreBase):
|
||||
"""
|
||||
Neptune Analytics vector store implementation for Mem0.
|
||||
|
||||
Provides vector storage and similarity search capabilities using Amazon Neptune Analytics,
|
||||
a serverless graph analytics service that supports vector operations.
|
||||
"""
|
||||
|
||||
_COLLECTION_PREFIX = "MEM0_VECTOR_"
|
||||
_FIELD_N = 'n'
|
||||
_FIELD_ID = '~id'
|
||||
_FIELD_PROP = '~properties'
|
||||
_FIELD_SCORE = 'score'
|
||||
_FIELD_LABEL = 'label'
|
||||
_TIMEZONE = "UTC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
collection_name: str,
|
||||
):
|
||||
"""
|
||||
Initialize the Neptune Analytics vector store.
|
||||
|
||||
Args:
|
||||
endpoint (str): Neptune Analytics endpoint in format 'neptune-graph://<graphid>'.
|
||||
collection_name (str): Name of the collection to store vectors.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint format is invalid.
|
||||
ImportError: If langchain_aws is not installed.
|
||||
"""
|
||||
|
||||
if not endpoint.startswith("neptune-graph://"):
|
||||
raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
|
||||
|
||||
graph_id = endpoint.replace("neptune-graph://", "")
|
||||
self.graph = NeptuneAnalyticsGraph(graph_id)
|
||||
self.collection_name = self._COLLECTION_PREFIX + collection_name
|
||||
|
||||
|
||||
def create_col(self, name, vector_size, distance):
|
||||
"""
|
||||
Create a collection (no-op for Neptune Analytics).
|
||||
|
||||
Neptune Analytics supports dynamic indices that are created implicitly
|
||||
when vectors are inserted, so this method performs no operation.
|
||||
|
||||
Args:
|
||||
name: Collection name (unused).
|
||||
vector_size: Vector dimension (unused).
|
||||
distance: Distance metric (unused).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def insert(self, vectors: List[list],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[str]] = None):
|
||||
"""
|
||||
Insert vectors into the collection.
|
||||
|
||||
Creates or updates nodes in Neptune Analytics with vector embeddings and metadata.
|
||||
Uses MERGE operation to handle both creation and updates.
|
||||
|
||||
Args:
|
||||
vectors (List[list]): List of embedding vectors to insert.
|
||||
payloads (Optional[List[Dict]]): Optional metadata for each vector.
|
||||
ids (Optional[List[str]]): Optional IDs for vectors. Generated if not provided.
|
||||
"""
|
||||
|
||||
para_list = []
|
||||
for index, data_vector in enumerate(vectors):
|
||||
if payloads:
|
||||
payload = payloads[index]
|
||||
payload[self._FIELD_LABEL] = self.collection_name
|
||||
payload["updated_at"] = str(int(time.time()))
|
||||
else:
|
||||
payload = {}
|
||||
para_list.append(dict(
|
||||
node_id=ids[index] if ids else str(uuid.uuid4()),
|
||||
properties=payload,
|
||||
embedding=data_vector,
|
||||
))
|
||||
|
||||
para_map_to_insert = {"rows": para_list}
|
||||
|
||||
query_string = (f"""
|
||||
UNWIND $rows AS row
|
||||
MERGE (n :{self.collection_name} {{`~id`: row.node_id}})
|
||||
ON CREATE SET n = row.properties
|
||||
ON MATCH SET n += row.properties
|
||||
"""
|
||||
)
|
||||
self.execute_query(query_string, para_map_to_insert)
|
||||
|
||||
|
||||
query_string_vector = (f"""
|
||||
UNWIND $rows AS row
|
||||
MATCH (n
|
||||
:{self.collection_name}
|
||||
{{`~id`: row.node_id}})
|
||||
WITH n, row.embedding AS embedding
|
||||
CALL neptune.algo.vectors.upsert(n, embedding)
|
||||
YIELD success
|
||||
RETURN success
|
||||
"""
|
||||
)
|
||||
result = self.execute_query(query_string_vector, para_map_to_insert)
|
||||
self._process_success_message(result, "Vector store - Insert")
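# Note: insertion above is a two-step operation — the first query MERGEs the
# node with its metadata, the second attaches the embedding via
# neptune.algo.vectors.upsert and checks the returned success flags.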
|
||||
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors using embedding similarity.
|
||||
|
||||
Performs vector similarity search using Neptune Analytics' topKByEmbeddingWithFiltering
|
||||
algorithm to find the most similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Search query text (unused in vector search).
|
||||
vectors (List[float]): Query embedding vector.
|
||||
limit (int, optional): Maximum number of results to return. Defaults to 5.
|
||||
filters (Optional[Dict]): Optional filters to apply to search results.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of similar vectors with scores and metadata.
|
||||
"""
|
||||
|
||||
if not filters:
|
||||
filters = {}
|
||||
filters[self._FIELD_LABEL] = self.collection_name
|
||||
|
||||
filter_clause = self._get_node_filter_clause(filters)
|
||||
|
||||
query_string = f"""
|
||||
CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{
|
||||
topK: {limit},
|
||||
embedding: {vectors}
|
||||
{filter_clause}
|
||||
}}
|
||||
)
|
||||
YIELD node, score
|
||||
RETURN node as n, score
|
||||
"""
|
||||
query_response = self.execute_query(query_string)
|
||||
if len(query_response) > 0:
|
||||
return self._parse_query_responses(query_response, with_score=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def delete(self, vector_id: str):
|
||||
"""
|
||||
Delete a vector by its ID.
|
||||
|
||||
Removes the node and all its relationships from the Neptune Analytics graph.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
params = dict(node_id=vector_id)
|
||||
query_string = f"""
|
||||
MATCH (n :{self.collection_name})
|
||||
WHERE id(n) = $node_id
|
||||
DETACH DELETE n
|
||||
"""
|
||||
self.execute_query(query_string, params)
|
||||
|
||||
def update(
|
||||
self,
|
||||
vector_id: str,
|
||||
vector: Optional[List[float]] = None,
|
||||
payload: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Update a vector's embedding and/or metadata.
|
||||
|
||||
Updates the node properties and/or vector embedding for an existing vector.
|
||||
Can update either the payload, the vector, or both.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (Optional[List[float]]): New embedding vector.
|
||||
payload (Optional[Dict]): New metadata to replace existing payload.
|
||||
"""
|
||||
|
||||
if payload:
|
||||
# Replace payload
|
||||
payload[self._FIELD_LABEL] = self.collection_name
|
||||
payload["updated_at"] = str(int(time.time()))
|
||||
para_payload = {
|
||||
"properties": payload,
|
||||
"vector_id": vector_id
|
||||
}
|
||||
query_string_payload = f"""
|
||||
MATCH (n :{self.collection_name})
|
||||
WHERE id(n) = $vector_id
|
||||
SET n = $properties
|
||||
"""
|
||||
self.execute_query(query_string_payload, para_payload)
|
||||
|
||||
if vector:
|
||||
para_embedding = {
|
||||
"embedding": vector,
|
||||
"vector_id": vector_id
|
||||
}
|
||||
query_string_embedding = f"""
|
||||
MATCH (n :{self.collection_name})
|
||||
WHERE id(n) = $vector_id
|
||||
WITH $embedding as embedding, n as n
|
||||
CALL neptune.algo.vectors.upsert(n, embedding)
|
||||
YIELD success
|
||||
RETURN success
|
||||
"""
|
||||
self.execute_query(query_string_embedding, para_embedding)
|
||||
|
||||
|
||||
|
||||
def get(self, vector_id: str):
|
||||
"""
|
||||
Retrieve a vector by its ID.
|
||||
|
||||
Fetches the node data including metadata for the specified vector ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
OutputData: Vector data with metadata, or None if not found.
|
||||
"""
|
||||
params = dict(node_id=vector_id)
|
||||
query_string = f"""
|
||||
MATCH (n :{self.collection_name})
|
||||
WHERE id(n) = $node_id
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
# Composite the query
|
||||
result = self.execute_query(query_string, params)
|
||||
|
||||
if len(result) != 0:
|
||||
return self._parse_query_responses(result)[0]
|
||||
|
||||
|
||||
def list_cols(self):
|
||||
"""
|
||||
List all collections with the Mem0 prefix.
|
||||
|
||||
Queries the Neptune Analytics schema to find all node labels that start
|
||||
with the Mem0 collection prefix.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
query_string = f"""
|
||||
CALL neptune.graph.pg_schema()
|
||||
YIELD schema
|
||||
RETURN [ label IN schema.nodeLabels WHERE label STARTS WITH '{self.collection_name}'] AS result
|
||||
"""
|
||||
result = self.execute_query(query_string)
|
||||
if len(result) == 1 and "result" in result[0]:
|
||||
return result[0]["result"]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete the entire collection.
|
||||
|
||||
Removes all nodes with the collection label and their relationships
|
||||
from the Neptune Analytics graph.
|
||||
"""
|
||||
self.execute_query(f"MATCH (n :{self.collection_name}) DETACH DELETE n")
|
||||
|
||||
|
||||
def col_info(self):
|
||||
"""
|
||||
Get collection information (no-op for Neptune Analytics).
|
||||
|
||||
Collections are created dynamically in Neptune Analytics, so no
|
||||
collection-specific metadata is available.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
"""
|
||||
List all vectors in the collection with optional filtering.
|
||||
|
||||
Retrieves vectors from the collection, optionally filtered by metadata properties.
|
||||
|
||||
Args:
|
||||
filters (Optional[Dict]): Optional filters to apply based on metadata.
|
||||
limit (int, optional): Maximum number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors with their metadata.
|
||||
"""
|
||||
where_clause = self._get_where_clause(filters) if filters else ""
|
||||
|
||||
para = {
|
||||
"limit": limit,
|
||||
}
|
||||
query_string = f"""
|
||||
MATCH (n :{self.collection_name})
|
||||
{where_clause}
|
||||
RETURN n
|
||||
LIMIT $limit
|
||||
"""
|
||||
query_response = self.execute_query(query_string, para)
|
||||
|
||||
if len(query_response) > 0:
|
||||
# Handle if there is no match.
|
||||
return [self._parse_query_responses(query_response)]
|
||||
return [[]]
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the collection by deleting all vectors.
|
||||
|
||||
Removes all vectors from the collection, effectively resetting it to empty state.
|
||||
"""
|
||||
self.delete_col()
|
||||
|
||||
|
||||
def _parse_query_responses(self, response: dict, with_score: bool = False):
|
||||
"""
|
||||
Parse Neptune Analytics query responses into OutputData objects.
|
||||
|
||||
Args:
|
||||
response (dict): Raw query response from Neptune Analytics.
|
||||
with_score (bool, optional): Whether to include similarity scores. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Parsed response data.
|
||||
"""
|
||||
result = []
|
||||
# Handle if there is no match.
|
||||
for item in response:
|
||||
id = item[self._FIELD_N][self._FIELD_ID]
|
||||
properties = item[self._FIELD_N][self._FIELD_PROP]
|
||||
properties.pop("label", None)
|
||||
if with_score:
|
||||
score = item[self._FIELD_SCORE]
|
||||
else:
|
||||
score = None
|
||||
result.append(OutputData(
|
||||
id=id,
|
||||
score=score,
|
||||
payload=properties,
|
||||
))
|
||||
return result
|
||||
|
||||
|
||||
def execute_query(self, query_string: str, params=None):
|
||||
"""
|
||||
Execute an openCypher query on Neptune Analytics.
|
||||
|
||||
This is a wrapper method around the Neptune Analytics graph query execution
|
||||
that provides debug logging for query monitoring and troubleshooting.
|
||||
|
||||
Args:
|
||||
query_string (str): The openCypher query string to execute.
|
||||
params (dict): Parameters to bind to the query.
|
||||
|
||||
Returns:
|
||||
Query result from Neptune Analytics graph execution.
|
||||
"""
|
||||
if params is None:
|
||||
params = {}
|
||||
logger.debug(f"Executing openCypher query:[{query_string}], with parameters:[{params}].")
|
||||
return self.graph.query(query_string, params)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_where_clause(filters: dict):
|
||||
"""
|
||||
Build WHERE clause for Cypher queries from filters.
|
||||
|
||||
Args:
|
||||
filters (dict): Filter conditions as key-value pairs.
|
||||
|
||||
Returns:
|
||||
str: Formatted WHERE clause for Cypher query.
|
||||
"""
|
||||
where_clause = ""
|
||||
for i, (k, v) in enumerate(filters.items()):
|
||||
if i == 0:
|
||||
where_clause += f"WHERE n.{k} = '{v}' "
|
||||
else:
|
||||
where_clause += f"AND n.{k} = '{v}' "
|
||||
return where_clause
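# Illustrative only: _get_where_clause({"user_id": "alice", "run_id": "r1"})
# yields the Cypher fragment
#   "WHERE n.user_id = 'alice' AND n.run_id = 'r1' "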
|
||||
|
||||
@staticmethod
|
||||
def _get_node_filter_clause(filters: dict):
|
||||
"""
|
||||
Build node filter clause for vector search operations.
|
||||
|
||||
Creates filter conditions for Neptune Analytics vector search operations
|
||||
using the nodeFilter parameter format.
|
||||
|
||||
Args:
|
||||
filters (dict): Filter conditions as key-value pairs.
|
||||
|
||||
Returns:
|
||||
str: Formatted node filter clause for vector search.
|
||||
"""
|
||||
conditions = []
|
||||
for k, v in filters.items():
|
||||
conditions.append(f"{{equals:{{property: '{k}', value: '{v}'}}}}")
|
||||
|
||||
if len(conditions) == 1:
|
||||
filter_clause = f", nodeFilter: {conditions[0]}"
|
||||
else:
|
||||
filter_clause = f"""
|
||||
, nodeFilter: {{andAll: [ {", ".join(conditions)} ]}}
|
||||
"""
|
||||
return filter_clause
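# Illustrative only: with a single filter {"label": "MEM0_VECTOR_mem0"} the
# clause is ", nodeFilter: {equals:{property: 'label', value: 'MEM0_VECTOR_mem0'}}";
# with multiple filters the conditions are wrapped in {andAll: [ ... ]}.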
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _process_success_message(response, context):
|
||||
"""
|
||||
Process and validate success messages from Neptune Analytics operations.
|
||||
|
||||
Checks the response from vector operations (insert/update) to ensure they
|
||||
completed successfully. Logs errors if operations fail.
|
||||
|
||||
Args:
|
||||
response: Response from Neptune Analytics vector operation.
|
||||
context (str): Context description for logging (e.g., "Vector store - Insert").
|
||||
"""
|
||||
for success_message in response:
|
||||
if "success" not in success_message:
|
||||
logger.error(f"Query execution status is absent on action: [{context}]")
|
||||
break
|
||||
|
||||
if success_message["success"] is not True:
|
||||
logger.error(f"Abnormal response status on action: [{context}] with message: [{success_message['success']}] ")
|
||||
break
|
||||
281
neomem/neomem/vector_stores/opensearch.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from opensearchpy import OpenSearch, RequestsHttpConnection
|
||||
except ImportError:
|
||||
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.configs.vector_stores.opensearch import OpenSearchConfig
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
payload: Dict
|
||||
|
||||
|
||||
class OpenSearchDB(VectorStoreBase):
|
||||
def __init__(self, **kwargs):
|
||||
config = OpenSearchConfig(**kwargs)
|
||||
|
||||
# Initialize OpenSearch client
|
||||
self.client = OpenSearch(
|
||||
hosts=[{"host": config.host, "port": config.port or 9200}],
|
||||
http_auth=config.http_auth
|
||||
if config.http_auth
|
||||
else ((config.user, config.password) if (config.user and config.password) else None),
|
||||
use_ssl=config.use_ssl,
|
||||
verify_certs=config.verify_certs,
|
||||
connection_class=RequestsHttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
self.collection_name = config.collection_name
|
||||
self.embedding_model_dims = config.embedding_model_dims
|
||||
self.create_col(self.collection_name, self.embedding_model_dims)
|
||||
|
||||
def create_index(self) -> None:
|
||||
"""Create OpenSearch index with proper mappings if it doesn't exist."""
|
||||
index_settings = {
|
||||
"settings": {
|
||||
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"text": {"type": "text"},
|
||||
"vector_field": {
|
||||
"type": "knn_vector",
|
||||
"dimension": self.embedding_model_dims,
|
||||
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
|
||||
},
|
||||
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if not self.client.indices.exists(index=self.collection_name):
|
||||
self.client.indices.create(index=self.collection_name, body=index_settings)
|
||||
logger.info(f"Created index {self.collection_name}")
|
||||
else:
|
||||
logger.info(f"Index {self.collection_name} already exists")
|
||||
|
||||
def create_col(self, name: str, vector_size: int) -> None:
|
||||
"""Create a new collection (index in OpenSearch)."""
|
||||
index_settings = {
|
||||
"settings": {"index.knn": True},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"vector_field": {
|
||||
"type": "knn_vector",
|
||||
"dimension": vector_size,
|
||||
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
|
||||
},
|
||||
"payload": {"type": "object"},
|
||||
"id": {"type": "keyword"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if not self.client.indices.exists(index=name):
|
||||
logger.warning(f"Creating index {name}, it might take 1-2 minutes...")
|
||||
self.client.indices.create(index=name, body=index_settings)
|
||||
|
||||
# Wait for index to be ready
|
||||
max_retries = 180 # 3 minutes timeout
|
||||
retry_count = 0
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# Check if index is ready by attempting a simple search
|
||||
self.client.search(index=name, body={"query": {"match_all": {}}})
|
||||
time.sleep(1)
|
||||
logger.info(f"Index {name} is ready")
|
||||
return
|
||||
except Exception:
|
||||
retry_count += 1
|
||||
if retry_count == max_retries:
|
||||
raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
|
||||
time.sleep(0.5)
|
||||
|
||||
def insert(
|
||||
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
|
||||
) -> List[OutputData]:
|
||||
"""Insert vectors into the index."""
|
||||
if not ids:
|
||||
ids = [str(i) for i in range(len(vectors))]
|
||||
|
||||
if payloads is None:
|
||||
payloads = [{} for _ in range(len(vectors))]
|
||||
|
||||
for i, (vec, id_) in enumerate(zip(vectors, ids)):
|
||||
body = {
|
||||
"vector_field": vec,
|
||||
"payload": payloads[i],
|
||||
"id": id_,
|
||||
}
|
||||
self.client.index(index=self.collection_name, body=body)
|
||||
|
||||
results = []
|
||||
|
||||
return results
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""Search for similar vectors using OpenSearch k-NN search with optional filters."""
|
||||
|
||||
# Base KNN query
|
||||
knn_query = {
|
||||
"knn": {
|
||||
"vector_field": {
|
||||
"vector": vectors,
|
||||
"k": limit * 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Start building the full query
|
||||
query_body = {"size": limit * 2, "query": None}
|
||||
|
||||
# Prepare filter conditions if applicable
|
||||
filter_clauses = []
|
||||
if filters:
|
||||
for key in ["user_id", "run_id", "agent_id"]:
|
||||
value = filters.get(key)
|
||||
if value:
|
||||
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
|
||||
|
||||
# Combine knn with filters if needed
|
||||
if filter_clauses:
|
||||
query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
|
||||
else:
|
||||
query_body["query"] = knn_query
|
||||
|
||||
# Execute search
|
||||
response = self.client.search(index=self.collection_name, body=query_body)
|
||||
|
||||
hits = response["hits"]["hits"]
|
||||
results = [
|
||||
OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
|
||||
for hit in hits
|
||||
]
|
||||
return results
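# Rough shape of the request body sent above when called with limit=5 and a
# user_id filter (illustrative; field names follow the mappings in create_col):
#
#   {
#       "size": 10,
#       "query": {
#           "bool": {
#               "must": {"knn": {"vector_field": {"vector": [...], "k": 10}}},
#               "filter": [{"term": {"payload.user_id.keyword": "alice"}}],
#           }
#       },
#   }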
|
||||
|
||||
def delete(self, vector_id: str) -> None:
|
||||
"""Delete a vector by custom ID."""
|
||||
# First, find the document by custom ID
|
||||
search_query = {"query": {"term": {"id": vector_id}}}
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=search_query)
|
||||
hits = response.get("hits", {}).get("hits", [])
|
||||
|
||||
if not hits:
|
||||
return
|
||||
|
||||
opensearch_id = hits[0]["_id"]
|
||||
|
||||
# Delete using the actual document ID
|
||||
self.client.delete(index=self.collection_name, id=opensearch_id)
|
||||
|
||||
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
|
||||
"""Update a vector and its payload using the custom 'id' field."""
|
||||
|
||||
# First, find the document by custom ID
|
||||
search_query = {"query": {"term": {"id": vector_id}}}
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=search_query)
|
||||
hits = response.get("hits", {}).get("hits", [])
|
||||
|
||||
if not hits:
|
||||
return
|
||||
|
||||
opensearch_id = hits[0]["_id"] # The actual document ID in OpenSearch
|
||||
|
||||
# Prepare updated fields
|
||||
doc = {}
|
||||
if vector is not None:
|
||||
doc["vector_field"] = vector
|
||||
if payload is not None:
|
||||
doc["payload"] = payload
|
||||
|
||||
if doc:
|
||||
try:
|
||||
response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
|
||||
except Exception as e:
logger.error(f"Error updating document with custom id '{vector_id}': {e}")
|
||||
|
||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||
"""Retrieve a vector by ID."""
|
||||
try:
|
||||
# First check if index exists
|
||||
if not self.client.indices.exists(index=self.collection_name):
|
||||
logger.info(f"Index {self.collection_name} does not exist, creating it...")
|
||||
self.create_col(self.collection_name, self.embedding_model_dims)
|
||||
return None
|
||||
|
||||
search_query = {"query": {"term": {"id": vector_id}}}
|
||||
response = self.client.search(index=self.collection_name, body=search_query)
|
||||
|
||||
hits = response["hits"]["hits"]
|
||||
|
||||
if not hits:
|
||||
return None
|
||||
|
||||
return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""List all collections (indices)."""
|
||||
return list(self.client.indices.get_alias().keys())
|
||||
|
||||
def delete_col(self) -> None:
|
||||
"""Delete a collection (index)."""
|
||||
self.client.indices.delete(index=self.collection_name)
|
||||
|
||||
def col_info(self, name: str) -> Any:
|
||||
"""Get information about a collection (index)."""
|
||||
return self.client.indices.get(index=name)
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
|
||||
"""List all memories with optional filters."""
try:
|
||||
query: Dict = {"query": {"match_all": {}}}
|
||||
|
||||
filter_clauses = []
|
||||
if filters:
|
||||
for key in ["user_id", "run_id", "agent_id"]:
|
||||
value = filters.get(key)
|
||||
if value:
|
||||
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
|
||||
|
||||
if filter_clauses:
|
||||
query["query"] = {"bool": {"filter": filter_clauses}}
|
||||
|
||||
if limit:
|
||||
query["size"] = limit
|
||||
|
||||
response = self.client.search(index=self.collection_name, body=query)
|
||||
hits = response["hits"]["hits"]
|
||||
|
||||
return [
|
||||
[
|
||||
OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
|
||||
for hit in hits
|
||||
]
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.collection_name, self.embedding_model_dims)
|
||||
404
neomem/neomem/vector_stores/pgvector.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import json
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Try to import psycopg (psycopg3) first, then fall back to psycopg2
|
||||
try:
|
||||
from psycopg.types.json import Json
|
||||
from psycopg_pool import ConnectionPool
|
||||
PSYCOPG_VERSION = 3
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Using psycopg (psycopg3) with ConnectionPool for PostgreSQL connections")
|
||||
except ImportError:
|
||||
try:
|
||||
from psycopg2.extras import Json, execute_values
|
||||
from psycopg2.pool import ThreadedConnectionPool as ConnectionPool
|
||||
PSYCOPG_VERSION = 2
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Using psycopg2 with ThreadedConnectionPool for PostgreSQL connections")
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Neither 'psycopg' nor 'psycopg2' library is available. "
|
||||
"Please install one of them using 'pip install psycopg[pool]' or 'pip install psycopg2'"
|
||||
)
|
||||
|
||||
from neomem.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
score: Optional[float]
|
||||
payload: Optional[dict]
|
||||
|
||||
|
||||
class PGVector(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
dbname,
|
||||
collection_name,
|
||||
embedding_model_dims,
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
port,
|
||||
diskann,
|
||||
hnsw,
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
sslmode=None,
|
||||
connection_string=None,
|
||||
connection_pool=None,
|
||||
):
|
||||
"""
|
||||
Initialize the PGVector database.
|
||||
|
||||
Args:
|
||||
dbname (str): Database name
|
||||
collection_name (str): Collection name
|
||||
embedding_model_dims (int): Dimension of the embedding vector
|
||||
user (str): Database user
|
||||
password (str): Database password
|
||||
host (str, optional): Database host
|
||||
port (int, optional): Database port
|
||||
diskann (bool, optional): Use DiskANN for faster search
|
||||
hnsw (bool, optional): Use HNSW for faster search
|
||||
minconn (int): Minimum number of connections to keep in the connection pool
|
||||
maxconn (int): Maximum number of connections allowed in the connection pool
|
||||
sslmode (str, optional): SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')
|
||||
connection_string (str, optional): PostgreSQL connection string (overrides individual connection parameters)
|
||||
connection_pool (Any, optional): psycopg2 connection pool object (overrides connection string and individual parameters)
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.use_diskann = diskann
|
||||
self.use_hnsw = hnsw
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.connection_pool = None
|
||||
|
||||
# Connection setup with priority: connection_pool > connection_string > individual parameters
|
||||
if connection_pool is not None:
|
||||
# Use provided connection pool
|
||||
self.connection_pool = connection_pool
|
||||
elif connection_string:
|
||||
if sslmode:
|
||||
# Append sslmode to connection string if provided
|
||||
if 'sslmode=' in connection_string:
|
||||
# Replace existing sslmode
|
||||
import re
|
||||
connection_string = re.sub(r'sslmode=[^ ]*', f'sslmode={sslmode}', connection_string)
|
||||
else:
|
||||
# Add sslmode to connection string
|
||||
connection_string = f"{connection_string} sslmode={sslmode}"
|
||||
else:
|
||||
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
|
||||
if sslmode:
|
||||
connection_string = f"{connection_string} sslmode={sslmode}"
|
||||
|
||||
if self.connection_pool is None:
|
||||
if PSYCOPG_VERSION == 3:
|
||||
# psycopg3 ConnectionPool
|
||||
self.connection_pool = ConnectionPool(conninfo=connection_string, min_size=minconn, max_size=maxconn, open=True)
|
||||
else:
|
||||
# psycopg2 ThreadedConnectionPool
|
||||
self.connection_pool = ConnectionPool(minconn=minconn, maxconn=maxconn, dsn=connection_string)
|
||||
|
||||
collections = self.list_cols()
|
||||
if collection_name not in collections:
|
||||
self.create_col()
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self, commit: bool = False):
|
||||
"""
|
||||
Unified context manager to get a cursor from the appropriate pool.
|
||||
Auto-commits or rolls back based on exception, and returns the connection to the pool.
|
||||
"""
|
||||
if PSYCOPG_VERSION == 3:
|
||||
# psycopg3 auto-manages commit/rollback and pool return
|
||||
with self.connection_pool.connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
yield cur
|
||||
if commit:
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
logger.error("Error in cursor context (psycopg3)", exc_info=True)
|
||||
raise
|
||||
else:
|
||||
# psycopg2 manual getconn/putconn
|
||||
conn = self.connection_pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
if commit:
|
||||
conn.commit()
|
||||
except Exception as exc:
|
||||
conn.rollback()
|
||||
logger.error(f"Error occurred: {exc}")
|
||||
raise exc
|
||||
finally:
|
||||
cur.close()
|
||||
self.connection_pool.putconn(conn)
|
||||
|
||||
def create_col(self) -> None:
|
||||
"""
|
||||
Create a new collection (table in PostgreSQL).
|
||||
Will also initialize vector search index if specified.
|
||||
"""
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
vector vector({self.embedding_model_dims}),
|
||||
payload JSONB
|
||||
);
|
||||
"""
|
||||
)
|
||||
if self.use_diskann and self.embedding_model_dims < 2000:
|
||||
cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
|
||||
if cur.fetchone():
|
||||
# Create DiskANN index if extension is installed for faster search
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
|
||||
ON {self.collection_name}
|
||||
USING diskann (vector);
|
||||
"""
|
||||
)
|
||||
elif self.use_hnsw:
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
|
||||
ON {self.collection_name}
|
||||
USING hnsw (vector vector_cosine_ops)
|
||||
"""
|
||||
)
|
||||
|
||||
def insert(self, vectors: list[list[float]], payloads=None, ids=None) -> None:
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
json_payloads = [json.dumps(payload) for payload in (payloads or [{}] * len(vectors))]
|
||||
|
||||
data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
|
||||
if PSYCOPG_VERSION == 3:
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES (%s, %s, %s)",
|
||||
data,
|
||||
)
|
||||
else:
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
execute_values(
|
||||
cur,
|
||||
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
|
||||
data,
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
vectors: list[float],
|
||||
limit: Optional[int] = 5,
|
||||
filters: Optional[dict] = None,
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: Search results.
|
||||
"""
|
||||
filter_conditions = []
|
||||
filter_params = []
|
||||
|
||||
if filters:
|
||||
for k, v in filters.items():
|
||||
filter_conditions.append("payload->>%s = %s")
|
||||
filter_params.extend([k, str(v)])
|
||||
|
||||
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT id, vector <=> %s::vector AS distance, payload
|
||||
FROM {self.collection_name}
|
||||
{filter_clause}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(vectors, *filter_params, limit),
|
||||
)
|
||||
|
||||
results = cur.fetchall()
|
||||
return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
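# Illustrative only: with filters={"user_id": "alice"} and limit=5, the query
# built above is roughly
#
#   SELECT id, vector <=> %s::vector AS distance, payload
#   FROM <collection>
#   WHERE payload->>%s = %s
#   ORDER BY distance
#   LIMIT %s
#
# executed with parameters (vectors, "user_id", "alice", 5).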
|
||||
|
||||
def delete(self, vector_id: str) -> None:
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
|
||||
|
||||
def update(
|
||||
self,
|
||||
vector_id: str,
|
||||
vector: Optional[list[float]] = None,
|
||||
payload: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (List[float], optional): Updated vector.
|
||||
payload (Dict, optional): Updated payload.
|
||||
"""
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
if vector:
|
||||
cur.execute(
|
||||
f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
|
||||
(vector, vector_id),
|
||||
)
|
||||
if payload:
|
||||
# Handle JSON serialization based on psycopg version
|
||||
if PSYCOPG_VERSION == 3:
|
||||
# psycopg3 uses psycopg.types.json.Json
|
||||
cur.execute(
|
||||
f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
|
||||
(Json(payload), vector_id),
|
||||
)
|
||||
else:
|
||||
# psycopg2 uses psycopg2.extras.Json
|
||||
cur.execute(
|
||||
f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
|
||||
(Json(payload), vector_id),
|
||||
)
|
||||
|
||||
|
||||
def get(self, vector_id: str) -> OutputData:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
OutputData: Retrieved vector.
|
||||
"""
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
|
||||
(vector_id,),
|
||||
)
|
||||
result = cur.fetchone()
|
||||
if not result:
|
||||
return None
|
||||
return OutputData(id=str(result[0]), score=None, payload=result[2])
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
|
||||
return [row[0] for row in cur.fetchall()]
|
||||
|
||||
def delete_col(self) -> None:
|
||||
"""Delete a collection."""
|
||||
with self._get_cursor(commit=True) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
|
||||
|
||||
def col_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about a collection.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Collection information.
|
||||
"""
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT
|
||||
table_name,
|
||||
(SELECT COUNT(*) FROM {self.collection_name}) as row_count,
|
||||
(SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = %s
|
||||
""",
|
||||
(self.collection_name,),
|
||||
)
|
||||
result = cur.fetchone()
|
||||
return {"name": result[0], "count": result[1], "size": result[2]}
|
||||
|
||||
def list(
|
||||
self,
|
||||
filters: Optional[dict] = None,
|
||||
limit: Optional[int] = 100
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply to the list.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors.
|
||||
"""
|
||||
filter_conditions = []
|
||||
filter_params = []
|
||||
|
||||
if filters:
|
||||
for k, v in filters.items():
|
||||
filter_conditions.append("payload->>%s = %s")
|
||||
filter_params.extend([k, str(v)])
|
||||
|
||||
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT id, vector, payload
|
||||
FROM {self.collection_name}
|
||||
{filter_clause}
|
||||
LIMIT %s
|
||||
"""
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(query, (*filter_params, limit))
|
||||
results = cur.fetchall()
|
||||
return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Close the database connection pool when the object is deleted.
|
||||
"""
|
||||
try:
|
||||
# Close pool appropriately
|
||||
if PSYCOPG_VERSION == 3:
|
||||
self.connection_pool.close()
|
||||
else:
|
||||
self.connection_pool.closeall()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col()
|
||||
382
neomem/neomem/vector_stores/pinecone.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
|
||||
) from None
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class PineconeDB(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
client: Optional["Pinecone"],
|
||||
api_key: Optional[str],
|
||||
environment: Optional[str],
|
||||
serverless_config: Optional[Dict[str, Any]],
|
||||
pod_config: Optional[Dict[str, Any]],
|
||||
hybrid_search: bool,
|
||||
metric: str,
|
||||
batch_size: int,
|
||||
extra_params: Optional[Dict[str, Any]],
|
||||
namespace: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Pinecone vector store.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the index/collection.
|
||||
embedding_model_dims (int): Dimensions of the embedding model.
|
||||
client (Pinecone, optional): Existing Pinecone client instance. Defaults to None.
|
||||
api_key (str, optional): API key for Pinecone. Defaults to None.
|
||||
environment (str, optional): Pinecone environment. Defaults to None.
|
||||
serverless_config (Dict, optional): Configuration for serverless deployment. Defaults to None.
|
||||
pod_config (Dict, optional): Configuration for pod-based deployment. Defaults to None.
|
||||
hybrid_search (bool, optional): Whether to enable hybrid search. Defaults to False.
|
||||
metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
|
||||
batch_size (int, optional): Batch size for operations. Defaults to 100.
|
||||
extra_params (Dict, optional): Additional parameters for Pinecone client. Defaults to None.
|
||||
namespace (str, optional): Namespace for the collection. Defaults to None.
|
||||
"""
|
||||
if client:
|
||||
self.client = client
|
||||
else:
|
||||
api_key = api_key or os.environ.get("PINECONE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"Pinecone API key must be provided either as a parameter or as an environment variable"
|
||||
)
|
||||
|
||||
params = extra_params or {}
|
||||
self.client = Pinecone(api_key=api_key, **params)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.environment = environment
|
||||
self.serverless_config = serverless_config
|
||||
self.pod_config = pod_config
|
||||
self.hybrid_search = hybrid_search
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.namespace = namespace
|
||||
|
||||
self.sparse_encoder = None
|
||||
if self.hybrid_search:
|
||||
try:
|
||||
from pinecone_text.sparse import BM25Encoder
|
||||
|
||||
logger.info("Initializing BM25Encoder for sparse vectors...")
|
||||
self.sparse_encoder = BM25Encoder.default()
|
||||
except ImportError:
|
||||
logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
|
||||
self.hybrid_search = False
|
||||
|
||||
self.create_col(embedding_model_dims, metric)
|
||||
|
||||
def create_col(self, vector_size: int, metric: str = "cosine"):
|
||||
"""
|
||||
Create a new index/collection.
|
||||
|
||||
Args:
|
||||
vector_size (int): Size of the vectors to be stored.
|
||||
metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
|
||||
"""
|
||||
existing_indexes = self.list_cols().names()
|
||||
|
||||
if self.collection_name in existing_indexes:
|
||||
logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
|
||||
self.index = self.client.Index(self.collection_name)
|
||||
return
|
||||
|
||||
if self.serverless_config:
|
||||
spec = ServerlessSpec(**self.serverless_config)
|
||||
elif self.pod_config:
|
||||
spec = PodSpec(**self.pod_config)
|
||||
else:
|
||||
spec = ServerlessSpec(cloud="aws", region="us-west-2")
|
||||
|
||||
self.client.create_index(
|
||||
name=self.collection_name,
|
||||
dimension=vector_size,
|
||||
metric=metric,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
self.index = self.client.Index(self.collection_name)
|
||||
|
||||
def insert(
|
||||
self,
|
||||
vectors: List[List[float]],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[Union[str, int]]] = None,
|
||||
):
|
||||
"""
|
||||
Insert vectors into an index.
|
||||
|
||||
Args:
|
||||
vectors (list): List of vectors to insert.
|
||||
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
|
||||
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
|
||||
items = []
|
||||
|
||||
for idx, vector in enumerate(vectors):
|
||||
item_id = str(ids[idx]) if ids is not None else str(idx)
|
||||
payload = payloads[idx] if payloads else {}
|
||||
|
||||
vector_record = {"id": item_id, "values": vector, "metadata": payload}
|
||||
|
||||
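# For hybrid search, attach a BM25 sparse vector derived from the payload's "text" field.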
if self.hybrid_search and self.sparse_encoder and "text" in payload:
|
||||
sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
|
||||
vector_record["sparse_values"] = sparse_vector
|
||||
|
||||
items.append(vector_record)
|
||||
|
||||
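# Flush to Pinecone whenever the buffer reaches batch_size; any remainder is upserted after the loop.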
if len(items) >= self.batch_size:
|
||||
self.index.upsert(vectors=items, namespace=self.namespace)
|
||||
items = []
|
||||
|
||||
if items:
|
||||
self.index.upsert(vectors=items, namespace=self.namespace)
|
||||
|
||||
def _parse_output(self, data: Dict) -> List[OutputData]:
|
||||
"""
|
||||
Parse the output data from Pinecone search results.
|
||||
|
||||
Args:
|
||||
data (Dict): Output data from Pinecone query.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
if isinstance(data, Vector):
|
||||
result = OutputData(
|
||||
id=data.id,
|
||||
score=0.0,
|
||||
payload=data.metadata,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
result = []
|
||||
for match in data:
|
||||
entry = OutputData(
|
||||
id=match.get("id"),
|
||||
score=match.get("score"),
|
||||
payload=match.get("metadata"),
|
||||
)
|
||||
result.append(entry)
|
||||
|
||||
return result
|
||||
|
||||
def _create_filter(self, filters: Optional[Dict]) -> Dict:
|
||||
"""
|
||||
Create a filter dictionary from the provided filters.
|
||||
"""
|
||||
if not filters:
|
||||
return {}
|
||||
|
||||
pinecone_filter = {}
|
||||
|
||||
for key, value in filters.items():
|
||||
if isinstance(value, dict) and "gte" in value and "lte" in value:
|
||||
pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
|
||||
else:
|
||||
pinecone_filter[key] = {"$eq": value}
|
||||
|
||||
return pinecone_filter
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (list): List of vectors to search.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: Search results.
|
||||
"""
|
||||
filter_dict = self._create_filter(filters) if filters else None
|
||||
|
||||
query_params = {
|
||||
"vector": vectors,
|
||||
"top_k": limit,
|
||||
"include_metadata": True,
|
||||
"include_values": False,
|
||||
}
|
||||
|
||||
if filter_dict:
|
||||
query_params["filter"] = filter_dict
|
||||
|
||||
if self.hybrid_search and self.sparse_encoder and "text" in filters:
|
||||
query_text = filters.get("text")
|
||||
if query_text:
|
||||
sparse_vector = self.sparse_encoder.encode_queries(query_text)
|
||||
query_params["sparse_vector"] = sparse_vector
|
||||
|
||||
response = self.index.query(**query_params, namespace=self.namespace)
|
||||
|
||||
results = self._parse_output(response.matches)
|
||||
return results
|
||||
|
||||
def delete(self, vector_id: Union[str, int]):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (Union[str, int]): ID of the vector to delete.
|
||||
"""
|
||||
self.index.delete(ids=[str(vector_id)], namespace=self.namespace)
|
||||
|
||||
def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (Union[str, int]): ID of the vector to update.
|
||||
vector (list, optional): Updated vector. Defaults to None.
|
||||
payload (dict, optional): Updated payload. Defaults to None.
|
||||
"""
|
||||
item = {
|
||||
"id": str(vector_id),
|
||||
}
|
||||
|
||||
if vector is not None:
|
||||
item["values"] = vector
|
||||
|
||||
if payload is not None:
|
||||
item["metadata"] = payload
|
||||
|
||||
if self.hybrid_search and self.sparse_encoder and "text" in payload:
|
||||
sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
|
||||
item["sparse_values"] = sparse_vector
|
||||
|
||||
self.index.upsert(vectors=[item], namespace=self.namespace)
|
||||
|
||||
def get(self, vector_id: Union[str, int]) -> OutputData:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (Union[str, int]): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: Retrieved vector or None if not found.
|
||||
"""
|
||||
try:
|
||||
response = self.index.fetch(ids=[str(vector_id)], namespace=self.namespace)
|
||||
if str(vector_id) in response.vectors:
|
||||
return self._parse_output(response.vectors[str(vector_id)])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector {vector_id}: {e}")
|
||||
return None
|
||||
|
||||
def list_cols(self):
|
||||
"""
|
||||
List all indexes/collections.
|
||||
|
||||
Returns:
|
||||
list: List of index information.
|
||||
"""
|
||||
return self.client.list_indexes()
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete an index/collection."""
|
||||
try:
|
||||
self.client.delete_index(self.collection_name)
|
||||
logger.info(f"Index {self.collection_name} deleted successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting index {self.collection_name}: {e}")
|
||||
|
||||
def col_info(self) -> Dict:
|
||||
"""
|
||||
Get information about an index/collection.
|
||||
|
||||
Returns:
|
||||
dict: Index information.
|
||||
"""
|
||||
return self.client.describe_index(self.collection_name)
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
"""
|
||||
List vectors in an index with optional filtering.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply to the list. Defaults to None.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
dict: List of vectors with their metadata.
|
||||
"""
|
||||
filter_dict = self._create_filter(filters) if filters else None
|
||||
|
||||
stats = self.index.describe_index_stats()
|
||||
dimension = stats.dimension
|
||||
|
||||
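# Approximate a listing by querying with a zero vector of the index's dimension, relying on top_k and metadata filters.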
zero_vector = [0.0] * dimension
|
||||
|
||||
query_params = {
|
||||
"vector": zero_vector,
|
||||
"top_k": limit,
|
||||
"include_metadata": True,
|
||||
"include_values": True,
|
||||
}
|
||||
|
||||
if filter_dict:
|
||||
query_params["filter"] = filter_dict
|
||||
|
||||
try:
|
||||
response = self.index.query(**query_params, namespace=self.namespace)
|
||||
response = response.to_dict()
|
||||
results = self._parse_output(response["matches"])
|
||||
return [results]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing vectors: {e}")
|
||||
return {"points": [], "next_page_token": None}
|
||||
|
||||
def count(self) -> int:
|
||||
"""
|
||||
Count number of vectors in the index.
|
||||
|
||||
Returns:
|
||||
int: Total number of vectors.
|
||||
"""
|
||||
stats = self.index.describe_index_stats()
|
||||
if self.namespace:
|
||||
# Safely get the namespace stats and return vector_count, defaulting to 0 if not found
|
||||
namespace_summary = (stats.namespaces or {}).get(self.namespace)
|
||||
if namespace_summary:
|
||||
return namespace_summary.vector_count or 0
|
||||
return 0
|
||||
return stats.total_vector_count or 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the index by deleting and recreating it.
|
||||
"""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.embedding_model_dims, self.metric)
|
||||
270
neomem/neomem/vector_stores/qdrant.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PointIdsList,
|
||||
PointStruct,
|
||||
Range,
|
||||
VectorParams,
|
||||
)
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qdrant(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
client: QdrantClient = None,
|
||||
host: str = None,
|
||||
port: int = None,
|
||||
path: str = None,
|
||||
url: str = None,
|
||||
api_key: str = None,
|
||||
on_disk: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Qdrant vector store.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection.
|
||||
embedding_model_dims (int): Dimensions of the embedding model.
|
||||
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
|
||||
host (str, optional): Host address for Qdrant server. Defaults to None.
|
||||
port (int, optional): Port for Qdrant server. Defaults to None.
|
||||
path (str, optional): Path for local Qdrant database. Defaults to None.
|
||||
url (str, optional): Full URL for Qdrant server. Defaults to None.
|
||||
api_key (str, optional): API key for Qdrant server. Defaults to None.
|
||||
on_disk (bool, optional): Enables persistent storage. Defaults to False.
|
||||
"""
|
||||
if client:
|
||||
self.client = client
|
||||
self.is_local = False
|
||||
else:
|
||||
params = {}
|
||||
if api_key:
|
||||
params["api_key"] = api_key
|
||||
if url:
|
||||
params["url"] = url
|
||||
if host and port:
|
||||
params["host"] = host
|
||||
params["port"] = port
|
||||
|
||||
if not params:
|
||||
params["path"] = path
|
||||
self.is_local = True
|
||||
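# In-memory local mode: remove any stale on-disk data left at this path by a previous run.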
if not on_disk:
|
||||
if os.path.exists(path) and os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
self.is_local = False
|
||||
|
||||
self.client = QdrantClient(**params)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.on_disk = on_disk
|
||||
self.create_col(embedding_model_dims, on_disk)
|
||||
|
||||
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
|
||||
"""
|
||||
Create a new collection.
|
||||
|
||||
Args:
|
||||
vector_size (int): Size of the vectors to be stored.
|
||||
on_disk (bool): Enables persistent storage.
|
||||
distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
|
||||
"""
|
||||
# Skip creating collection if already exists
|
||||
response = self.list_cols()
|
||||
for collection in response.collections:
|
||||
if collection.name == self.collection_name:
|
||||
logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
|
||||
self._create_filter_indexes()
|
||||
return
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
|
||||
)
|
||||
self._create_filter_indexes()
|
||||
|
||||
def _create_filter_indexes(self):
|
||||
"""Create indexes for commonly used filter fields to enable filtering."""
|
||||
# Only create payload indexes for remote Qdrant servers
|
||||
if self.is_local:
|
||||
logger.debug("Skipping payload index creation for local Qdrant (not supported)")
|
||||
return
|
||||
|
||||
common_fields = ["user_id", "agent_id", "run_id", "actor_id"]
|
||||
|
||||
for field in common_fields:
|
||||
try:
|
||||
self.client.create_payload_index(
|
||||
collection_name=self.collection_name,
|
||||
field_name=field,
|
||||
field_schema="keyword"
|
||||
)
|
||||
logger.info(f"Created index for {field} in collection {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Index for {field} might already exist: {e}")
|
||||
|
||||
def insert(self, vectors: list, payloads: list = None, ids: list = None):
|
||||
"""
|
||||
Insert vectors into a collection.
|
||||
|
||||
Args:
|
||||
vectors (list): List of vectors to insert.
|
||||
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
|
||||
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx if ids is None else ids[idx],
|
||||
vector=vector,
|
||||
payload=payloads[idx] if payloads else {},
|
||||
)
|
||||
for idx, vector in enumerate(vectors)
|
||||
]
|
||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
def _create_filter(self, filters: dict) -> Filter:
|
||||
"""
|
||||
Create a Filter object from the provided filters.
|
||||
|
||||
Args:
|
||||
filters (dict): Filters to apply.
|
||||
|
||||
Returns:
|
||||
Filter: The created Filter object.
|
||||
"""
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
conditions = []
|
||||
for key, value in filters.items():
|
||||
if isinstance(value, dict) and "gte" in value and "lte" in value:
|
||||
conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
|
||||
else:
|
||||
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
|
||||
return Filter(must=conditions) if conditions else None
|
||||
|
||||
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (list): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: Search results.
|
||||
"""
|
||||
query_filter = self._create_filter(filters) if filters else None
|
||||
hits = self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
query=vectors,
|
||||
query_filter=query_filter,
|
||||
limit=limit,
|
||||
)
|
||||
return hits.points
|
||||
|
||||
def delete(self, vector_id: int):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to delete.
|
||||
"""
|
||||
self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=PointIdsList(
|
||||
points=[vector_id],
|
||||
),
|
||||
)
|
||||
|
||||
def update(self, vector_id: int, vector: list = None, payload: dict = None):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to update.
|
||||
vector (list, optional): Updated vector. Defaults to None.
|
||||
payload (dict, optional): Updated payload. Defaults to None.
|
||||
"""
|
||||
point = PointStruct(id=vector_id, vector=vector, payload=payload)
|
||||
self.client.upsert(collection_name=self.collection_name, points=[point])
|
||||
|
||||
def get(self, vector_id: int) -> dict:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: Retrieved vector.
|
||||
"""
|
||||
result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
|
||||
return result[0] if result else None
|
||||
|
||||
def list_cols(self) -> list:
|
||||
"""
|
||||
List all collections.
|
||||
|
||||
Returns:
|
||||
list: List of collection names.
|
||||
"""
|
||||
return self.client.get_collections()
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete a collection."""
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
def col_info(self) -> dict:
|
||||
"""
|
||||
Get information about a collection.
|
||||
|
||||
Returns:
|
||||
dict: Collection information.
|
||||
"""
|
||||
return self.client.get_collection(collection_name=self.collection_name)
|
||||
|
||||
def list(self, filters: dict = None, limit: int = 100) -> list:
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply to the list. Defaults to None.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
list: List of vectors.
|
||||
"""
|
||||
query_filter = self._create_filter(filters) if filters else None
|
||||
result = self.client.scroll(
|
||||
collection_name=self.collection_name,
|
||||
scroll_filter=query_filter,
|
||||
limit=limit,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
return result
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.embedding_model_dims, self.on_disk)
|
||||
295
neomem/neomem/vector_stores/redis.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
import pytz
|
||||
import redis
|
||||
from redis.commands.search.query import Query
|
||||
from redisvl.index import SearchIndex
|
||||
from redisvl.query import VectorQuery
|
||||
from redisvl.query.filter import Tag
|
||||
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Improve these; they are not the best fields from Redis's perspective. We might do away with them.
|
||||
DEFAULT_FIELDS = [
|
||||
{"name": "memory_id", "type": "tag"},
|
||||
{"name": "hash", "type": "tag"},
|
||||
{"name": "agent_id", "type": "tag"},
|
||||
{"name": "run_id", "type": "tag"},
|
||||
{"name": "user_id", "type": "tag"},
|
||||
{"name": "memory", "type": "text"},
|
||||
{"name": "metadata", "type": "text"},
|
||||
# TODO: Although this field is numeric, it also accepts strings.
|
||||
{"name": "created_at", "type": "numeric"},
|
||||
{"name": "updated_at", "type": "numeric"},
|
||||
{
|
||||
"name": "embedding",
|
||||
"type": "vector",
|
||||
"attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
|
||||
},
|
||||
]
|
||||
|
||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||
|
||||
|
||||
class MemoryResult:
|
||||
def __init__(self, id: str, payload: dict, score: float = None):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.score = score
|
||||
|
||||
|
||||
class RedisDB(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
):
|
||||
"""
|
||||
Initialize the Redis vector store.
|
||||
|
||||
Args:
|
||||
redis_url (str): Redis URL.
|
||||
collection_name (str): Collection name.
|
||||
embedding_model_dims (int): Embedding model dimensions.
|
||||
"""
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
index_schema = {
|
||||
"name": collection_name,
|
||||
"prefix": f"mem0:{collection_name}",
|
||||
}
|
||||
|
||||
fields = DEFAULT_FIELDS.copy()
|
||||
fields[-1]["attrs"]["dims"] = embedding_model_dims
|
||||
|
||||
self.schema = {"index": index_schema, "fields": fields}
|
||||
|
||||
self.client = redis.Redis.from_url(redis_url)
|
||||
self.index = SearchIndex.from_dict(self.schema)
|
||||
self.index.set_client(self.client)
|
||||
self.index.create(overwrite=True)
|
||||
|
||||
def create_col(self, name=None, vector_size=None, distance=None):
|
||||
"""
|
||||
Create a new collection (index) in Redis.
|
||||
|
||||
Args:
|
||||
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
|
||||
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
|
||||
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
|
||||
|
||||
Returns:
|
||||
The created index object.
|
||||
"""
|
||||
# Use provided parameters or fall back to instance attributes
|
||||
collection_name = name or self.schema["index"]["name"]
|
||||
embedding_dims = vector_size or self.embedding_model_dims
|
||||
distance_metric = distance or "cosine"
|
||||
|
||||
# Create a new schema with the specified parameters
|
||||
index_schema = {
|
||||
"name": collection_name,
|
||||
"prefix": f"mem0:{collection_name}",
|
||||
}
|
||||
|
||||
# Copy the default fields and update the vector field with the specified dimensions
|
||||
fields = DEFAULT_FIELDS.copy()
|
||||
fields[-1]["attrs"]["dims"] = embedding_dims
|
||||
fields[-1]["attrs"]["distance_metric"] = distance_metric
|
||||
|
||||
# Create the schema
|
||||
schema = {"index": index_schema, "fields": fields}
|
||||
|
||||
# Create the index
|
||||
index = SearchIndex.from_dict(schema)
|
||||
index.set_client(self.client)
|
||||
index.create(overwrite=True)
|
||||
|
||||
# Update instance attributes if creating a new collection
|
||||
if name:
|
||||
self.schema = schema
|
||||
self.index = index
|
||||
|
||||
return index
|
||||
|
||||
def insert(self, vectors: list, payloads: list = None, ids: list = None):
|
||||
data = []
|
||||
for vector, payload, id in zip(vectors, payloads, ids):
|
||||
# Start with required fields
|
||||
entry = {
|
||||
"memory_id": id,
|
||||
"hash": payload["hash"],
|
||||
"memory": payload["data"],
|
||||
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
||||
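# Serialize the embedding as raw float32 bytes to match the index's vector field datatype.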
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
||||
}
|
||||
|
||||
# Conditionally add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in payload:
|
||||
entry[field] = payload[field]
|
||||
|
||||
# Add metadata excluding specific keys
|
||||
entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
||||
|
||||
data.append(entry)
|
||||
self.index.load(data, id_field="memory_id")
|
||||
|
||||
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
|
||||
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
|
||||
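# Combine all tag conditions into a single filter expression with AND.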
filter = reduce(lambda x, y: x & y, conditions)
|
||||
|
||||
v = VectorQuery(
|
||||
vector=np.array(vectors, dtype=np.float32).tobytes(),
|
||||
vector_field_name="embedding",
|
||||
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
|
||||
filter_expression=filter,
|
||||
num_results=limit,
|
||||
)
|
||||
|
||||
results = self.index.query(v)
|
||||
|
||||
return [
|
||||
MemoryResult(
|
||||
id=result["memory_id"],
|
||||
score=result["vector_distance"],
|
||||
payload={
|
||||
"hash": result["hash"],
|
||||
"data": result["memory"],
|
||||
"created_at": datetime.fromtimestamp(
|
||||
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
|
||||
).isoformat(timespec="microseconds"),
|
||||
**(
|
||||
{
|
||||
"updated_at": datetime.fromtimestamp(
|
||||
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
|
||||
).isoformat(timespec="microseconds")
|
||||
}
|
||||
if "updated_at" in result
|
||||
else {}
|
||||
),
|
||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
},
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
def delete(self, vector_id):
|
||||
self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")
|
||||
|
||||
def update(self, vector_id=None, vector=None, payload=None):
|
||||
data = {
|
||||
"memory_id": vector_id,
|
||||
"hash": payload["hash"],
|
||||
"memory": payload["data"],
|
||||
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
||||
"updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
|
||||
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
||||
}
|
||||
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in payload:
|
||||
data[field] = payload[field]
|
||||
|
||||
data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
||||
self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")
|
||||
|
||||
def get(self, vector_id):
|
||||
result = self.index.fetch(vector_id)
|
||||
payload = {
|
||||
"hash": result["hash"],
|
||||
"data": result["memory"],
|
||||
"created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
|
||||
timespec="microseconds"
|
||||
),
|
||||
**(
|
||||
{
|
||||
"updated_at": datetime.fromtimestamp(
|
||||
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
|
||||
).isoformat(timespec="microseconds")
|
||||
}
|
||||
if "updated_at" in result
|
||||
else {}
|
||||
),
|
||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
}
|
||||
|
||||
return MemoryResult(id=result["memory_id"], payload=payload)
|
||||
|
||||
def list_cols(self):
|
||||
return self.index.listall()
|
||||
|
||||
def delete_col(self):
|
||||
self.index.delete()
|
||||
|
||||
def col_info(self, name):
|
||||
return self.index.info()
|
||||
|
||||
def reset(self):
    """
    Reset the index by deleting and recreating it.
    """
    collection_name = self.schema["index"]["name"]
    logger.warning(f"Resetting index {collection_name}...")
    self.delete_col()
    # Recreate the index with the same name and dimensions; create_col also
    # refreshes self.schema and self.index.
    self.create_col(collection_name, self.embedding_model_dims)
|
||||
|
||||
def list(self, filters: dict = None, limit: int = None) -> list:
|
||||
"""
|
||||
List all recent created memories from the vector store.
|
||||
"""
|
||||
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
|
||||
filter = reduce(lambda x, y: x & y, conditions)
|
||||
query = Query(str(filter)).sort_by("created_at", asc=False)
|
||||
if limit is not None:
|
||||
query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)
|
||||
|
||||
results = self.index.search(query)
|
||||
return [
|
||||
[
|
||||
MemoryResult(
|
||||
id=result["memory_id"],
|
||||
payload={
|
||||
"hash": result["hash"],
|
||||
"data": result["memory"],
|
||||
"created_at": datetime.fromtimestamp(
|
||||
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
|
||||
).isoformat(timespec="microseconds"),
|
||||
**(
|
||||
{
|
||||
"updated_at": datetime.fromtimestamp(
|
||||
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
|
||||
).isoformat(timespec="microseconds")
|
||||
}
|
||||
if result.__dict__.get("updated_at")
|
||||
else {}
|
||||
),
|
||||
**{
|
||||
field: result[field]
|
||||
for field in ["agent_id", "run_id", "user_id"]
|
||||
if field in result.__dict__
|
||||
},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
},
|
||||
)
|
||||
for result in results.docs
|
||||
]
|
||||
]
|
||||
176
neomem/neomem/vector_stores/s3_vectors.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
except ImportError:
|
||||
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
score: Optional[float]
|
||||
payload: Optional[Dict]
|
||||
|
||||
|
||||
class S3Vectors(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
vector_bucket_name: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
distance_metric: str = "cosine",
|
||||
region_name: Optional[str] = None,
|
||||
):
|
||||
self.client = boto3.client("s3vectors", region_name=region_name)
|
||||
self.vector_bucket_name = vector_bucket_name
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.distance_metric = distance_metric
|
||||
|
||||
self._ensure_bucket_exists()
|
||||
self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
|
||||
|
||||
def _ensure_bucket_exists(self):
|
||||
try:
|
||||
self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
|
||||
logger.info(f"Vector bucket '{self.vector_bucket_name}' already exists.")
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NotFoundException":
|
||||
logger.info(f"Vector bucket '{self.vector_bucket_name}' not found. Creating it.")
|
||||
self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
|
||||
logger.info(f"Vector bucket '{self.vector_bucket_name}' created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
def create_col(self, name, vector_size, distance="cosine"):
|
||||
try:
|
||||
self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
|
||||
logger.info(f"Index '{name}' already exists in bucket '{self.vector_bucket_name}'.")
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NotFoundException":
|
||||
logger.info(f"Index '{name}' not found in bucket '{self.vector_bucket_name}'. Creating it.")
|
||||
self.client.create_index(
|
||||
vectorBucketName=self.vector_bucket_name,
|
||||
indexName=name,
|
||||
dataType="float32",
|
||||
dimension=vector_size,
|
||||
distanceMetric=distance,
|
||||
)
|
||||
logger.info(f"Index '{name}' created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
|
||||
results = []
|
||||
for v in vectors:
|
||||
payload = v.get("metadata", {})
|
||||
# Boto3 might return metadata as a JSON string
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
payload = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse metadata for key {v.get('key')}")
|
||||
payload = {}
|
||||
results.append(OutputData(id=v.get("key"), score=v.get("distance"), payload=payload))
|
||||
return results
|
||||
|
||||
def insert(self, vectors, payloads=None, ids=None):
|
||||
vectors_to_put = []
|
||||
for i, vec in enumerate(vectors):
|
||||
vectors_to_put.append(
|
||||
{
|
||||
"key": ids[i],
|
||||
"data": {"float32": vec},
|
||||
"metadata": payloads[i] if payloads else {},
|
||||
}
|
||||
)
|
||||
self.client.put_vectors(
|
||||
vectorBucketName=self.vector_bucket_name,
|
||||
indexName=self.collection_name,
|
||||
vectors=vectors_to_put,
|
||||
)
|
||||
|
||||
def search(self, query, vectors, limit=5, filters=None):
|
||||
params = {
|
||||
"vectorBucketName": self.vector_bucket_name,
|
||||
"indexName": self.collection_name,
|
||||
"queryVector": {"float32": vectors},
|
||||
"topK": limit,
|
||||
"returnMetadata": True,
|
||||
"returnDistance": True,
|
||||
}
|
||||
if filters:
|
||||
params["filter"] = filters
|
||||
|
||||
response = self.client.query_vectors(**params)
|
||||
return self._parse_output(response.get("vectors", []))
|
||||
|
||||
def delete(self, vector_id):
|
||||
self.client.delete_vectors(
|
||||
vectorBucketName=self.vector_bucket_name,
|
||||
indexName=self.collection_name,
|
||||
keys=[vector_id],
|
||||
)
|
||||
|
||||
def update(self, vector_id, vector=None, payload=None):
|
||||
# S3 Vectors uses put_vectors for updates (overwrite)
|
||||
self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
|
||||
|
||||
def get(self, vector_id) -> Optional[OutputData]:
|
||||
response = self.client.get_vectors(
|
||||
vectorBucketName=self.vector_bucket_name,
|
||||
indexName=self.collection_name,
|
||||
keys=[vector_id],
|
||||
returnData=False,
|
||||
returnMetadata=True,
|
||||
)
|
||||
vectors = response.get("vectors", [])
|
||||
if not vectors:
|
||||
return None
|
||||
return self._parse_output(vectors)[0]
|
||||
|
||||
def list_cols(self):
|
||||
response = self.client.list_indexes(vectorBucketName=self.vector_bucket_name)
|
||||
return [idx["indexName"] for idx in response.get("indexes", [])]
|
||||
|
||||
def delete_col(self):
|
||||
self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
|
||||
|
||||
def col_info(self):
|
||||
response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
|
||||
return response.get("index", {})
|
||||
|
||||
def list(self, filters=None, limit=None):
|
||||
# Note: list_vectors does not support metadata filtering.
|
||||
if filters:
|
||||
logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")
|
||||
|
||||
params = {
|
||||
"vectorBucketName": self.vector_bucket_name,
|
||||
"indexName": self.collection_name,
|
||||
"returnData": False,
|
||||
"returnMetadata": True,
|
||||
}
|
||||
if limit:
|
||||
params["maxResults"] = limit
|
||||
|
||||
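# Page through list_vectors results, accumulating vectors from every page.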
paginator = self.client.get_paginator("list_vectors")
|
||||
pages = paginator.paginate(**params)
|
||||
all_vectors = []
|
||||
for page in pages:
|
||||
all_vectors.extend(page.get("vectors", []))
|
||||
return [self._parse_output(all_vectors)]
|
||||
|
||||
def reset(self):
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
|
||||
237
neomem/neomem/vector_stores/supabase.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
import vecs
|
||||
except ImportError:
|
||||
raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
|
||||
|
||||
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
score: Optional[float]
|
||||
payload: Optional[dict]
|
||||
|
||||
|
||||
class Supabase(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
index_method: IndexMethod = IndexMethod.AUTO,
|
||||
index_measure: IndexMeasure = IndexMeasure.COSINE,
|
||||
):
|
||||
"""
|
||||
Initialize the Supabase vector store using vecs.
|
||||
|
||||
Args:
|
||||
connection_string (str): PostgreSQL connection string
|
||||
collection_name (str): Collection name
|
||||
embedding_model_dims (int): Dimension of the embedding vector
|
||||
index_method (IndexMethod): Index method to use. Defaults to AUTO.
|
||||
index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
|
||||
"""
|
||||
self.db = vecs.create_client(connection_string)
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.index_method = index_method
|
||||
self.index_measure = index_measure
|
||||
|
||||
collections = self.list_cols()
|
||||
if collection_name not in collections:
|
||||
self.create_col(embedding_model_dims)
|
||||
|
||||
def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
"""
|
||||
Preprocess filters to be compatible with vecs.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to preprocess. Multiple filters will be
|
||||
combined with AND logic.
|
||||
"""
|
||||
if filters is None:
|
||||
return None
|
||||
|
||||
if len(filters) == 1:
|
||||
# For single filter, keep the simple format
|
||||
key, value = next(iter(filters.items()))
|
||||
return {key: {"$eq": value}}
|
||||
|
||||
# For multiple filters, use $and clause
|
||||
return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}
|
||||
|
||||
def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
|
||||
"""
|
||||
Create a new collection with vector support.
|
||||
Will also initialize vector search index.
|
||||
|
||||
Args:
|
||||
embedding_model_dims (int, optional): Dimension of the embedding vector.
|
||||
If not provided, uses the dimension specified in initialization.
|
||||
"""
|
||||
dims = embedding_model_dims or self.embedding_model_dims
|
||||
if not dims:
|
||||
raise ValueError(
|
||||
"embedding_model_dims must be provided either during initialization or when creating collection"
|
||||
)
|
||||
|
||||
logger.info(f"Creating new collection: {self.collection_name}")
|
||||
try:
|
||||
self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
|
||||
self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
|
||||
logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection: {str(e)}")
|
||||
raise
|
||||
|
||||
def insert(
|
||||
self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
Insert vectors into the collection.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors
|
||||
ids (List[str], optional): List of IDs corresponding to vectors
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
|
||||
if not ids:
|
||||
ids = [str(uuid.uuid4()) for _ in vectors]
|
||||
if not payloads:
|
||||
payloads = [{} for _ in vectors]
|
||||
|
||||
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
|
||||
|
||||
self.collection.upsert(records)
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Search results
|
||||
"""
|
||||
filters = self._preprocess_filters(filters)
|
||||
results = self.collection.query(
|
||||
data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
|
||||
)
|
||||
|
||||
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
|
||||
|
||||
def delete(self, vector_id: str):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete
|
||||
"""
|
||||
self.collection.delete([(vector_id,)])
|
||||
|
||||
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
|
||||
"""
|
||||
Update a vector and/or its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update
|
||||
vector (List[float], optional): Updated vector
|
||||
payload (Dict, optional): Updated payload
|
||||
"""
|
||||
if vector is None:
|
||||
# If only updating metadata, we need to get the existing vector
|
||||
existing = self.get(vector_id)
|
||||
if existing and existing.payload:
|
||||
vector = existing.payload.get("vector", [])
|
||||
|
||||
if vector:
|
||||
self.collection.upsert([(vector_id, vector, payload or {})])
|
||||
|
||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[OutputData]: Retrieved vector data or None if not found
|
||||
"""
|
||||
result = self.collection.fetch([(vector_id,)])
|
||||
if not result:
|
||||
return None
|
||||
|
||||
record = result[0]
|
||||
return OutputData(id=str(record.id), score=None, payload=record.metadata)
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections.
|
||||
|
||||
Returns:
|
||||
List[str]: List of collection names
|
||||
"""
|
||||
return self.db.list_collections()
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete the collection."""
|
||||
self.db.delete_collection(self.collection_name)
|
||||
|
||||
def col_info(self) -> dict:
|
||||
"""
|
||||
Get information about the collection.
|
||||
|
||||
Returns:
|
||||
Dict: Collection information including name and configuration
|
||||
"""
|
||||
info = self.collection.describe()
|
||||
return {
|
||||
"name": info.name,
|
||||
"count": info.vectors,
|
||||
"dimension": info.dimension,
|
||||
"index": {"method": info.index_method, "metric": info.distance_metric},
|
||||
}
|
||||
|
||||
def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
"""
|
||||
List vectors in the collection.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply
|
||||
limit (int, optional): Maximum number of results to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors
|
||||
"""
|
||||
filters = self._preprocess_filters(filters)
|
||||
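# List by querying with a zero vector to collect up to `limit` matching IDs, then fetching their full records.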
query = [0] * self.embedding_model_dims
|
||||
ids = self.collection.query(
|
||||
data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
|
||||
)
|
||||
ids = [id[0] for id in ids]
|
||||
records = self.collection.fetch(ids=ids)
|
||||
|
||||
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.collection_name}...")
|
||||
self.delete_col()
|
||||
self.create_col(self.embedding_model_dims)
|
||||
293
neomem/neomem/vector_stores/upstash_vector.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
from upstash_vector import Index
|
||||
except ImportError:
|
||||
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # is None for `get` method
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class UpstashVector(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
client: Optional[Index] = None,
|
||||
enable_embeddings: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the UpstashVector vector store.
|
||||
|
||||
Args:
|
||||
url (str, optional): URL for Upstash Vector index. Defaults to None.
|
||||
token (int, optional): Token for Upstash Vector index. Defaults to None.
|
||||
client (Index, optional): Existing `upstash_vector.Index` client instance. Defaults to None.
|
||||
namespace (str, optional): Default namespace for the index. Defaults to None.
|
||||
"""
|
||||
if client:
|
||||
self.client = client
|
||||
elif url and token:
|
||||
self.client = Index(url, token)
|
||||
else:
|
||||
raise ValueError("Either a client or URL and token must be provided.")
|
||||
|
||||
self.collection_name = collection_name
|
||||
|
||||
self.enable_embeddings = enable_embeddings
|
||||
|
||||
def insert(
|
||||
self,
|
||||
vectors: List[list],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Insert vectors
|
||||
|
||||
Args:
|
||||
vectors (list): List of vectors to insert.
|
||||
payloads (list, optional): List of payloads corresponding to vectors. These will be passed as metadatas to the Upstash Vector client. Defaults to None.
|
||||
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into namespace {self.collection_name}")
|
||||
|
||||
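# With server-side embeddings enabled, send each payload's "data" text and let Upstash embed it;
# otherwise send the precomputed vectors directly.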
if self.enable_embeddings:
|
||||
if not payloads or any("data" not in m or m["data"] is None for m in payloads):
|
||||
raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field.")
|
||||
processed_vectors = [
|
||||
{
|
||||
"id": ids[i] if ids else None,
|
||||
"data": payloads[i]["data"],
|
||||
"metadata": payloads[i],
|
||||
}
|
||||
for i, v in enumerate(vectors)
|
||||
]
|
||||
else:
|
||||
processed_vectors = [
|
||||
{
|
||||
"id": ids[i] if ids else None,
|
||||
"vector": vectors[i],
|
||||
"metadata": payloads[i] if payloads else None,
|
||||
}
|
||||
for i, v in enumerate(vectors)
|
||||
]
|
||||
|
||||
self.client.upsert(
|
||||
vectors=processed_vectors,
|
||||
namespace=self.collection_name,
|
||||
)
|
||||
|
||||
def _stringify(self, x):
|
||||
return f'"{x}"' if isinstance(x, str) else x
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
vectors: List[list],
|
||||
limit: int = 5,
|
||||
filters: Optional[Dict] = None,
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (list): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
|
||||
filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
|
||||
|
||||
response = []
|
||||
|
||||
if self.enable_embeddings:
|
||||
response = self.client.query(
|
||||
data=query,
|
||||
top_k=limit,
|
||||
filter=filters_str or "",
|
||||
include_metadata=True,
|
||||
namespace=self.collection_name,
|
||||
)
|
||||
else:
|
||||
queries = [
|
||||
{
|
||||
"vector": v,
|
||||
"top_k": limit,
|
||||
"filter": filters_str or "",
|
||||
"include_metadata": True,
|
||||
"namespace": self.collection_name,
|
||||
}
|
||||
for v in vectors
|
||||
]
|
||||
responses = self.client.query_many(queries=queries)
|
||||
# flatten
|
||||
response = [res for res_list in responses for res in res_list]
|
||||
|
||||
return [
|
||||
OutputData(
|
||||
id=res.id,
|
||||
score=res.score,
|
||||
payload=res.metadata,
|
||||
)
|
||||
for res in response
|
||||
]
|
||||
|
||||
def delete(self, vector_id: int):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to delete.
|
||||
"""
|
||||
self.client.delete(
|
||||
ids=[str(vector_id)],
|
||||
namespace=self.collection_name,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
vector_id: int,
|
||||
vector: Optional[list] = None,
|
||||
payload: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to update.
|
||||
vector (list, optional): Updated vector. Defaults to None.
|
||||
payload (dict, optional): Updated payload. Defaults to None.
|
||||
"""
|
||||
self.client.update(
|
||||
id=str(vector_id),
|
||||
vector=vector,
|
||||
data=payload.get("data") if payload else None,
|
||||
metadata=payload,
|
||||
namespace=self.collection_name,
|
||||
)
|
||||
|
||||
def get(self, vector_id: int) -> Optional[OutputData]:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (int): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: Retrieved vector.
|
||||
"""
|
||||
response = self.client.fetch(
|
||||
ids=[str(vector_id)],
|
||||
namespace=self.collection_name,
|
||||
include_metadata=True,
|
||||
)
|
||||
if len(response) == 0:
|
||||
return None
|
||||
vector = response[0]
|
||||
if not vector:
|
||||
return None
|
||||
return OutputData(id=vector.id, score=None, payload=vector.metadata)
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
|
||||
"""
|
||||
List all memories.
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
limit (int, optional): Number of results to return. Defaults to 100.
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
|
||||
|
||||
        info = self.client.info()
        ns_info = info.namespaces.get(self.collection_name)

        if not ns_info or ns_info.vector_count == 0:
            return [[]]

        # Reuse the index info fetched above rather than calling info() a second time.
        random_vector = [1.0] * info.dimension
|
||||
|
||||
results, query = self.client.resumable_query(
|
||||
vector=random_vector,
|
||||
filter=filters_str or "",
|
||||
include_metadata=True,
|
||||
namespace=self.collection_name,
|
||||
top_k=100,
|
||||
)
|
||||
with query:
|
||||
while True:
|
||||
if len(results) >= limit:
|
||||
break
|
||||
res = query.fetch_next(100)
|
||||
if not res:
|
||||
break
|
||||
results.extend(res)
|
||||
|
||||
parsed_result = [
|
||||
OutputData(
|
||||
id=res.id,
|
||||
score=res.score,
|
||||
payload=res.metadata,
|
||||
)
|
||||
for res in results
|
||||
]
|
||||
return [parsed_result]
|
||||
|
||||
def create_col(self, name, vector_size, distance):
|
||||
"""
|
||||
Upstash Vector has namespaces instead of collections. A namespace is created when the first vector is inserted.
|
||||
|
||||
This method is a placeholder to maintain the interface.
|
||||
"""
|
||||
pass
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
Lists all namespaces in the Upstash Vector index.
|
||||
Returns:
|
||||
List[str]: List of namespaces.
|
||||
"""
|
||||
return self.client.list_namespaces()
|
||||
|
||||
    def delete_col(self):
        """
        Delete the namespace and all vectors in it.
        """
        self.client.reset(namespace=self.collection_name)
|
||||
|
||||
def col_info(self):
|
||||
"""
|
||||
Return general information about the Upstash Vector index.
|
||||
|
||||
- Total number of vectors across all namespaces
|
||||
- Total number of vectors waiting to be indexed across all namespaces
|
||||
- Total size of the index on disk in bytes
|
||||
- Vector dimension
|
||||
- Similarity function used
|
||||
- Per-namespace vector and pending vector counts
|
||||
"""
|
||||
return self.client.info()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the Upstash Vector index.
|
||||
"""
|
||||
self.delete_col()
|
||||
824
neomem/neomem/vector_stores/valkey.py
Normal file
@@ -0,0 +1,824 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import pytz
|
||||
import valkey
|
||||
from pydantic import BaseModel
|
||||
from valkey.exceptions import ResponseError
|
||||
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default fields for the Valkey index
|
||||
DEFAULT_FIELDS = [
|
||||
{"name": "memory_id", "type": "tag"},
|
||||
{"name": "hash", "type": "tag"},
|
||||
{"name": "agent_id", "type": "tag"},
|
||||
{"name": "run_id", "type": "tag"},
|
||||
{"name": "user_id", "type": "tag"},
|
||||
{"name": "memory", "type": "tag"}, # Using TAG instead of TEXT for Valkey compatibility
|
||||
{"name": "metadata", "type": "tag"}, # Using TAG instead of TEXT for Valkey compatibility
|
||||
{"name": "created_at", "type": "numeric"},
|
||||
{"name": "updated_at", "type": "numeric"},
|
||||
{
|
||||
"name": "embedding",
|
||||
"type": "vector",
|
||||
"attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
|
||||
},
|
||||
]
|
||||
|
||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
payload: Dict
|
||||
|
||||
|
||||
class ValkeyDB(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
valkey_url: str,
|
||||
collection_name: str,
|
||||
embedding_model_dims: int,
|
||||
timezone: str = "UTC",
|
||||
index_type: str = "hnsw",
|
||||
hnsw_m: int = 16,
|
||||
hnsw_ef_construction: int = 200,
|
||||
hnsw_ef_runtime: int = 10,
|
||||
):
|
||||
"""
|
||||
Initialize the Valkey vector store.
|
||||
|
||||
Args:
|
||||
valkey_url (str): Valkey URL.
|
||||
collection_name (str): Collection name.
|
||||
embedding_model_dims (int): Embedding model dimensions.
|
||||
timezone (str, optional): Timezone for timestamps. Defaults to "UTC".
|
||||
index_type (str, optional): Index type ('hnsw' or 'flat'). Defaults to "hnsw".
|
||||
hnsw_m (int, optional): HNSW M parameter (connections per node). Defaults to 16.
|
||||
hnsw_ef_construction (int, optional): HNSW ef_construction parameter. Defaults to 200.
|
||||
hnsw_ef_runtime (int, optional): HNSW ef_runtime parameter. Defaults to 10.
|
||||
"""
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.collection_name = collection_name
|
||||
self.prefix = f"mem0:{collection_name}"
|
||||
self.timezone = timezone
|
||||
self.index_type = index_type.lower()
|
||||
self.hnsw_m = hnsw_m
|
||||
self.hnsw_ef_construction = hnsw_ef_construction
|
||||
self.hnsw_ef_runtime = hnsw_ef_runtime
|
||||
|
||||
# Validate index type
|
||||
if self.index_type not in ["hnsw", "flat"]:
|
||||
raise ValueError(f"Invalid index_type: {index_type}. Must be 'hnsw' or 'flat'")
|
||||
|
||||
# Connect to Valkey
|
||||
try:
|
||||
self.client = valkey.from_url(valkey_url)
|
||||
logger.debug(f"Successfully connected to Valkey at {valkey_url}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to connect to Valkey at {valkey_url}: {e}")
|
||||
raise
|
||||
|
||||
# Create the index schema
|
||||
self._create_index(embedding_model_dims)
|
||||
|
||||
def _build_index_schema(self, collection_name, embedding_dims, distance_metric, prefix):
|
||||
"""
|
||||
Build the FT.CREATE command for index creation.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection/index
|
||||
embedding_dims (int): Vector embedding dimensions
|
||||
distance_metric (str): Distance metric (e.g., "COSINE", "L2", "IP")
|
||||
prefix (str): Key prefix for the index
|
||||
|
||||
Returns:
|
||||
list: Complete FT.CREATE command as list of arguments
|
||||
"""
|
||||
# Build the vector field configuration based on index type
|
||||
if self.index_type == "hnsw":
|
||||
vector_config = [
|
||||
"embedding",
|
||||
"VECTOR",
|
||||
"HNSW",
|
||||
"12", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric, M, m, EF_CONSTRUCTION, ef_construction, EF_RUNTIME, ef_runtime
|
||||
"TYPE",
|
||||
"FLOAT32",
|
||||
"DIM",
|
||||
str(embedding_dims),
|
||||
"DISTANCE_METRIC",
|
||||
distance_metric,
|
||||
"M",
|
||||
str(self.hnsw_m),
|
||||
"EF_CONSTRUCTION",
|
||||
str(self.hnsw_ef_construction),
|
||||
"EF_RUNTIME",
|
||||
str(self.hnsw_ef_runtime),
|
||||
]
|
||||
elif self.index_type == "flat":
|
||||
vector_config = [
|
||||
"embedding",
|
||||
"VECTOR",
|
||||
"FLAT",
|
||||
"6", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric
|
||||
"TYPE",
|
||||
"FLOAT32",
|
||||
"DIM",
|
||||
str(embedding_dims),
|
||||
"DISTANCE_METRIC",
|
||||
distance_metric,
|
||||
]
|
||||
else:
|
||||
# This should never happen due to constructor validation, but be defensive
|
||||
raise ValueError(f"Unsupported index_type: {self.index_type}. Must be 'hnsw' or 'flat'")
|
||||
|
||||
# Build the complete command (comma is default separator for TAG fields)
|
||||
cmd = [
|
||||
"FT.CREATE",
|
||||
collection_name,
|
||||
"ON",
|
||||
"HASH",
|
||||
"PREFIX",
|
||||
"1",
|
||||
prefix,
|
||||
"SCHEMA",
|
||||
"memory_id",
|
||||
"TAG",
|
||||
"hash",
|
||||
"TAG",
|
||||
"agent_id",
|
||||
"TAG",
|
||||
"run_id",
|
||||
"TAG",
|
||||
"user_id",
|
||||
"TAG",
|
||||
"memory",
|
||||
"TAG",
|
||||
"metadata",
|
||||
"TAG",
|
||||
"created_at",
|
||||
"NUMERIC",
|
||||
"updated_at",
|
||||
"NUMERIC",
|
||||
] + vector_config
|
||||
|
||||
return cmd
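    # Illustrative example (not part of the original file): for a collection named "mem0"
    # with 1536-dim embeddings and the HNSW defaults above, the command built here is roughly
    #
    #   FT.CREATE mem0 ON HASH PREFIX 1 mem0:mem0 SCHEMA
    #       memory_id TAG hash TAG agent_id TAG run_id TAG user_id TAG memory TAG metadata TAG
    #       created_at NUMERIC updated_at NUMERIC
    #       embedding VECTOR HNSW 12 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
    #       M 16 EF_CONSTRUCTION 200 EF_RUNTIME 10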
|
||||
|
||||
def _create_index(self, embedding_model_dims):
|
||||
"""
|
||||
Create the search index with the specified schema.
|
||||
|
||||
Args:
|
||||
embedding_model_dims (int): Dimensions for the vector embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: If the search module is not available.
|
||||
Exception: For other errors during index creation.
|
||||
"""
|
||||
# Check if the search module is available
|
||||
try:
|
||||
# Try to execute a search command
|
||||
self.client.execute_command("FT._LIST")
|
||||
except ResponseError as e:
|
||||
if "unknown command" in str(e).lower():
|
||||
raise ValueError(
|
||||
"Valkey search module is not available. Please ensure Valkey is running with the search module enabled. "
|
||||
"The search module can be loaded using the --loadmodule option with the valkey-search library. "
|
||||
"For installation and setup instructions, refer to the Valkey Search documentation."
|
||||
)
|
||||
else:
|
||||
logger.exception(f"Error checking search module: {e}")
|
||||
raise
|
||||
|
||||
# Check if the index already exists
|
||||
try:
|
||||
self.client.ft(self.collection_name).info()
|
||||
return
|
||||
except ResponseError as e:
|
||||
if "not found" not in str(e).lower():
|
||||
logger.exception(f"Error checking index existence: {e}")
|
||||
raise
|
||||
|
||||
# Build and execute the index creation command
|
||||
cmd = self._build_index_schema(
|
||||
self.collection_name,
|
||||
embedding_model_dims,
|
||||
"COSINE", # Fixed distance metric for initialization
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
try:
|
||||
self.client.execute_command(*cmd)
|
||||
logger.info(f"Successfully created {self.index_type.upper()} index {self.collection_name}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error creating index {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
def create_col(self, name=None, vector_size=None, distance=None):
|
||||
"""
|
||||
Create a new collection (index) in Valkey.
|
||||
|
||||
Args:
|
||||
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
|
||||
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
|
||||
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
|
||||
|
||||
Returns:
|
||||
The created index object.
|
||||
"""
|
||||
# Use provided parameters or fall back to instance attributes
|
||||
collection_name = name or self.collection_name
|
||||
embedding_dims = vector_size or self.embedding_model_dims
|
||||
distance_metric = distance or "COSINE"
|
||||
prefix = f"mem0:{collection_name}"
|
||||
|
||||
# Try to drop the index if it exists (cleanup before creation)
|
||||
self._drop_index(collection_name, log_level="silent")
|
||||
|
||||
# Build and execute the index creation command
|
||||
cmd = self._build_index_schema(
|
||||
collection_name,
|
||||
embedding_dims,
|
||||
distance_metric, # Configurable distance metric
|
||||
prefix,
|
||||
)
|
||||
|
||||
try:
|
||||
self.client.execute_command(*cmd)
|
||||
logger.info(f"Successfully created {self.index_type.upper()} index {collection_name}")
|
||||
|
||||
# Update instance attributes if creating a new collection
|
||||
if name:
|
||||
self.collection_name = collection_name
|
||||
self.prefix = prefix
|
||||
|
||||
return self.client.ft(collection_name)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error creating collection {collection_name}: {e}")
|
||||
raise
|
||||
|
||||
def insert(self, vectors: list, payloads: list = None, ids: list = None):
|
||||
"""
|
||||
Insert vectors and their payloads into the index.
|
||||
|
||||
Args:
|
||||
vectors (list): List of vectors to insert.
|
||||
payloads (list, optional): List of payloads corresponding to the vectors.
|
||||
ids (list, optional): List of IDs for the vectors.
|
||||
"""
|
||||
for vector, payload, id in zip(vectors, payloads, ids):
|
||||
try:
|
||||
# Create the key for the hash
|
||||
key = f"{self.prefix}:{id}"
|
||||
|
||||
                # A missing "data" field is tolerated here; hash_data below falls back to
                # a default "data_<id>" value via payload.get().
|
||||
|
||||
# Ensure created_at is present
|
||||
if "created_at" not in payload:
|
||||
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
|
||||
|
||||
# Prepare the hash data
|
||||
hash_data = {
|
||||
"memory_id": id,
|
||||
"hash": payload.get("hash", f"hash_{id}"), # Use a default hash if not provided
|
||||
"memory": payload.get("data", f"data_{id}"), # Use a default data if not provided
|
||||
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
||||
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in payload:
|
||||
hash_data[field] = payload[field]
|
||||
|
||||
# Add metadata
|
||||
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
||||
|
||||
# Store in Valkey
|
||||
self.client.hset(key, mapping=hash_data)
|
||||
logger.debug(f"Successfully inserted vector with ID {id}")
|
||||
except KeyError as e:
|
||||
logger.error(f"Error inserting vector with ID {id}: Missing required field {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error inserting vector with ID {id}: {e}")
|
||||
raise
|
||||
|
||||
def _build_search_query(self, knn_part, filters=None):
|
||||
"""
|
||||
Build a search query string with filters.
|
||||
|
||||
Args:
|
||||
knn_part (str): The KNN part of the query.
|
||||
filters (dict, optional): Filters to apply to the search. Each key-value pair
|
||||
becomes a tag filter (@key:{value}). None values are ignored.
|
||||
Values are used as-is (no validation) - wildcards, lists, etc. are
|
||||
passed through literally to Valkey search. Multiple filters are
|
||||
combined with AND logic (space-separated).
|
||||
|
||||
Returns:
|
||||
str: The complete search query string in format "filter_expr =>[KNN...]"
|
||||
or "*=>[KNN...]" if no valid filters.
|
||||
"""
|
||||
# No filters, just use the KNN search
|
||||
if not filters or not any(value is not None for key, value in filters.items()):
|
||||
return f"*=>{knn_part}"
|
||||
|
||||
# Build filter expression
|
||||
filter_parts = []
|
||||
for key, value in filters.items():
|
||||
if value is not None:
|
||||
# Use the correct filter syntax for Valkey
|
||||
filter_parts.append(f"@{key}:{{{value}}}")
|
||||
|
||||
# No valid filter parts
|
||||
if not filter_parts:
|
||||
return f"*=>{knn_part}"
|
||||
|
||||
# Combine filter parts with proper syntax
|
||||
filter_expr = " ".join(filter_parts)
|
||||
return f"{filter_expr} =>{knn_part}"
|
||||
|
||||
def _execute_search(self, query, params):
|
||||
"""
|
||||
Execute a search query.
|
||||
|
||||
Args:
|
||||
query (str): The search query to execute.
|
||||
params (dict): The query parameters.
|
||||
|
||||
Returns:
|
||||
The search results.
|
||||
"""
|
||||
try:
|
||||
return self.client.ft(self.collection_name).search(query, query_params=params)
|
||||
except ResponseError as e:
|
||||
logger.error(f"Search failed with query '{query}': {e}")
|
||||
raise
|
||||
|
||||
def _process_search_results(self, results):
|
||||
"""
|
||||
Process search results into OutputData objects.
|
||||
|
||||
Args:
|
||||
results: The search results from Valkey.
|
||||
|
||||
Returns:
|
||||
list: List of OutputData objects.
|
||||
"""
|
||||
memory_results = []
|
||||
for doc in results.docs:
|
||||
# Extract the score
|
||||
score = float(doc.vector_score) if hasattr(doc, "vector_score") else None
|
||||
|
||||
# Create the payload
|
||||
payload = {
|
||||
"hash": doc.hash,
|
||||
"data": doc.memory,
|
||||
"created_at": self._format_timestamp(int(doc.created_at), self.timezone),
|
||||
}
|
||||
|
||||
# Add updated_at if available
|
||||
if hasattr(doc, "updated_at"):
|
||||
payload["updated_at"] = self._format_timestamp(int(doc.updated_at), self.timezone)
|
||||
|
||||
# Add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if hasattr(doc, field):
|
||||
payload[field] = getattr(doc, field)
|
||||
|
||||
# Add metadata
|
||||
if hasattr(doc, "metadata"):
|
||||
try:
|
||||
metadata = json.loads(extract_json(doc.metadata))
|
||||
payload.update(metadata)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse metadata: {e}")
|
||||
|
||||
# Create the result
|
||||
memory_results.append(OutputData(id=doc.memory_id, score=score, payload=payload))
|
||||
|
||||
return memory_results
|
||||
|
||||
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None, ef_runtime: int = None):
|
||||
"""
|
||||
Search for similar vectors in the index.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
vectors (list): The vector to search for.
|
||||
limit (int, optional): Maximum number of results to return. Defaults to 5.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||
ef_runtime (int, optional): HNSW ef_runtime parameter for this query. Only used with HNSW index. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: List of OutputData objects.
|
||||
"""
|
||||
# Convert the vector to bytes
|
||||
vector_bytes = np.array(vectors, dtype=np.float32).tobytes()
|
||||
|
||||
# Build the KNN part with optional EF_RUNTIME for HNSW
|
||||
if self.index_type == "hnsw" and ef_runtime is not None:
|
||||
knn_part = f"[KNN {limit} @embedding $vec_param EF_RUNTIME {ef_runtime} AS vector_score]"
|
||||
else:
|
||||
# For FLAT indexes or when ef_runtime is None, use basic KNN
|
||||
knn_part = f"[KNN {limit} @embedding $vec_param AS vector_score]"
|
||||
|
||||
# Build the complete query
|
||||
q = self._build_search_query(knn_part, filters)
|
||||
|
||||
# Log the query for debugging (only in debug mode)
|
||||
logger.debug(f"Valkey search query: {q}")
|
||||
|
||||
# Set up the query parameters
|
||||
params = {"vec_param": vector_bytes}
|
||||
|
||||
# Execute the search
|
||||
results = self._execute_search(q, params)
|
||||
|
||||
# Process the results
|
||||
return self._process_search_results(results)
|
||||
|
||||
def delete(self, vector_id):
|
||||
"""
|
||||
Delete a vector from the index.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
try:
|
||||
key = f"{self.prefix}:{vector_id}"
|
||||
self.client.delete(key)
|
||||
logger.debug(f"Successfully deleted vector with ID {vector_id}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error deleting vector with ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def update(self, vector_id=None, vector=None, payload=None):
|
||||
"""
|
||||
Update a vector in the index.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (list, optional): New vector data.
|
||||
payload (dict, optional): New payload data.
|
||||
"""
|
||||
try:
|
||||
key = f"{self.prefix}:{vector_id}"
|
||||
|
||||
            # A missing "data" field is tolerated here; hash_data below falls back to
            # a default "data_<vector_id>" value via payload.get().
|
||||
|
||||
# Ensure created_at is present
|
||||
if "created_at" not in payload:
|
||||
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
|
||||
|
||||
# Prepare the hash data
|
||||
hash_data = {
|
||||
"memory_id": vector_id,
|
||||
"hash": payload.get("hash", f"hash_{vector_id}"), # Use a default hash if not provided
|
||||
"memory": payload.get("data", f"data_{vector_id}"), # Use a default data if not provided
|
||||
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
||||
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
||||
}
|
||||
|
||||
# Add updated_at if available
|
||||
if "updated_at" in payload:
|
||||
hash_data["updated_at"] = int(datetime.fromisoformat(payload["updated_at"]).timestamp())
|
||||
|
||||
# Add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in payload:
|
||||
hash_data[field] = payload[field]
|
||||
|
||||
# Add metadata
|
||||
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
||||
|
||||
# Update in Valkey
|
||||
self.client.hset(key, mapping=hash_data)
|
||||
logger.debug(f"Successfully updated vector with ID {vector_id}")
|
||||
except KeyError as e:
|
||||
logger.error(f"Error updating vector with ID {vector_id}: Missing required field {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error updating vector with ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def _format_timestamp(self, timestamp, timezone=None):
|
||||
"""
|
||||
Format a timestamp with the specified timezone.
|
||||
|
||||
Args:
|
||||
timestamp (int): The timestamp to format.
|
||||
timezone (str, optional): The timezone to use. Defaults to UTC.
|
||||
|
||||
Returns:
|
||||
str: The formatted timestamp.
|
||||
"""
|
||||
# Use UTC as default timezone if not specified
|
||||
tz = pytz.timezone(timezone or "UTC")
|
||||
return datetime.fromtimestamp(timestamp, tz=tz).isoformat(timespec="microseconds")
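    # Illustrative example (not part of the original file):
    #   _format_timestamp(1700000000, "UTC") -> "2023-11-14T22:13:20.000000+00:00"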
|
||||
|
||||
def _process_document_fields(self, result, vector_id):
|
||||
"""
|
||||
Process document fields from a Valkey hash result.
|
||||
|
||||
Args:
|
||||
result (dict): The hash result from Valkey.
|
||||
vector_id (str): The vector ID.
|
||||
|
||||
Returns:
|
||||
dict: The processed payload.
|
||||
str: The memory ID.
|
||||
"""
|
||||
# Create the payload with error handling
|
||||
payload = {}
|
||||
|
||||
# Convert bytes to string for text fields
|
||||
for k in result:
|
||||
if k not in ["embedding"]:
|
||||
if isinstance(result[k], bytes):
|
||||
try:
|
||||
result[k] = result[k].decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# If decoding fails, keep the bytes
|
||||
pass
|
||||
|
||||
# Add required fields with error handling
|
||||
for field in ["hash", "memory", "created_at"]:
|
||||
if field in result:
|
||||
if field == "created_at":
|
||||
try:
|
||||
payload[field] = self._format_timestamp(int(result[field]), self.timezone)
|
||||
except (ValueError, TypeError):
|
||||
payload[field] = result[field]
|
||||
else:
|
||||
payload[field] = result[field]
|
||||
else:
|
||||
# Use default values for missing fields
|
||||
if field == "hash":
|
||||
payload[field] = "unknown"
|
||||
elif field == "memory":
|
||||
payload[field] = "unknown"
|
||||
elif field == "created_at":
|
||||
payload[field] = self._format_timestamp(
|
||||
int(datetime.now(tz=pytz.timezone(self.timezone)).timestamp()), self.timezone
|
||||
)
|
||||
|
||||
# Rename memory to data for consistency
|
||||
if "memory" in payload:
|
||||
payload["data"] = payload.pop("memory")
|
||||
|
||||
# Add updated_at if available
|
||||
if "updated_at" in result:
|
||||
try:
|
||||
payload["updated_at"] = self._format_timestamp(int(result["updated_at"]), self.timezone)
|
||||
except (ValueError, TypeError):
|
||||
payload["updated_at"] = result["updated_at"]
|
||||
|
||||
# Add optional fields
|
||||
for field in ["agent_id", "run_id", "user_id"]:
|
||||
if field in result:
|
||||
payload[field] = result[field]
|
||||
|
||||
# Add metadata
|
||||
if "metadata" in result:
|
||||
try:
|
||||
metadata = json.loads(extract_json(result["metadata"]))
|
||||
payload.update(metadata)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"Failed to parse metadata: {result.get('metadata')}")
|
||||
|
||||
# Use memory_id from result if available, otherwise use vector_id
|
||||
memory_id = result.get("memory_id", vector_id)
|
||||
|
||||
return payload, memory_id
|
||||
|
||||
def _convert_bytes(self, data):
|
||||
"""Convert bytes data back to string"""
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
return data.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return {self._convert_bytes(key): self._convert_bytes(value) for key, value in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [self._convert_bytes(item) for item in data]
|
||||
if isinstance(data, tuple):
|
||||
return tuple(self._convert_bytes(item) for item in data)
|
||||
return data
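    # Illustrative example (not part of the original file): hgetall() returns bytes keys and
    # values, which this helper decodes recursively, e.g.
    #   _convert_bytes({b"memory_id": b"abc"}) -> {"memory_id": "abc"}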
|
||||
|
||||
def get(self, vector_id):
|
||||
"""
|
||||
Get a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to get.
|
||||
|
||||
Returns:
|
||||
OutputData: The retrieved vector.
|
||||
"""
|
||||
try:
|
||||
key = f"{self.prefix}:{vector_id}"
|
||||
result = self.client.hgetall(key)
|
||||
|
||||
if not result:
|
||||
raise KeyError(f"Vector with ID {vector_id} not found")
|
||||
|
||||
# Convert bytes keys/values to strings
|
||||
result = self._convert_bytes(result)
|
||||
|
||||
logger.debug(f"Retrieved result keys: {result.keys()}")
|
||||
|
||||
# Process the document fields
|
||||
payload, memory_id = self._process_document_fields(result, vector_id)
|
||||
|
||||
return OutputData(id=memory_id, payload=payload, score=0.0)
|
||||
except KeyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting vector with ID {vector_id}: {e}")
|
||||
raise
|
||||
|
||||
def list_cols(self):
|
||||
"""
|
||||
List all collections (indices) in Valkey.
|
||||
|
||||
Returns:
|
||||
list: List of collection names.
|
||||
"""
|
||||
try:
|
||||
# Use the FT._LIST command to list all indices
|
||||
return self.client.execute_command("FT._LIST")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error listing collections: {e}")
|
||||
raise
|
||||
|
||||
def _drop_index(self, collection_name, log_level="error"):
|
||||
"""
|
||||
Drop an index by name using the documented FT.DROPINDEX command.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the index to drop.
|
||||
log_level (str): Logging level for missing index ("silent", "info", "error").
|
||||
"""
|
||||
try:
|
||||
self.client.execute_command("FT.DROPINDEX", collection_name)
|
||||
logger.info(f"Successfully deleted index {collection_name}")
|
||||
return True
|
||||
except ResponseError as e:
|
||||
if "Unknown index name" in str(e):
|
||||
# Index doesn't exist - handle based on context
|
||||
if log_level == "silent":
|
||||
pass # No logging in situations where this is expected such as initial index creation
|
||||
elif log_level == "info":
|
||||
logger.info(f"Index {collection_name} doesn't exist, skipping deletion")
|
||||
return False
|
||||
else:
|
||||
# Real error - always log and raise
|
||||
logger.error(f"Error deleting index {collection_name}: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Non-ResponseError exceptions - always log and raise
|
||||
logger.error(f"Error deleting index {collection_name}: {e}")
|
||||
raise
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete the current collection (index).
|
||||
"""
|
||||
return self._drop_index(self.collection_name, log_level="info")
|
||||
|
||||
def col_info(self, name=None):
|
||||
"""
|
||||
Get information about a collection (index).
|
||||
|
||||
Args:
|
||||
name (str, optional): Name of the collection. Defaults to None, which uses the current collection_name.
|
||||
|
||||
Returns:
|
||||
dict: Information about the collection.
|
||||
"""
|
||||
try:
|
||||
collection_name = name or self.collection_name
|
||||
return self.client.ft(collection_name).info()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting collection info for {collection_name}: {e}")
|
||||
raise
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the index by deleting and recreating it.
|
||||
"""
|
||||
try:
|
||||
collection_name = self.collection_name
|
||||
logger.warning(f"Resetting index {collection_name}...")
|
||||
|
||||
# Delete the index
|
||||
self.delete_col()
|
||||
|
||||
# Recreate the index
|
||||
self._create_index(self.embedding_model_dims)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error resetting index {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
def _build_list_query(self, filters=None):
|
||||
"""
|
||||
Build a query for listing vectors.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply to the list. Each key-value pair
|
||||
becomes a tag filter (@key:{value}). None values are ignored.
|
||||
Values are used as-is (no validation) - wildcards, lists, etc. are
|
||||
passed through literally to Valkey search.
|
||||
|
||||
Returns:
|
||||
str: The query string. Returns "*" if no valid filters provided.
|
||||
"""
|
||||
# Default query
|
||||
q = "*"
|
||||
|
||||
# Add filters if provided
|
||||
if filters and any(value is not None for key, value in filters.items()):
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
if value is not None:
|
||||
filter_conditions.append(f"@{key}:{{{value}}}")
|
||||
|
||||
if filter_conditions:
|
||||
q = " ".join(filter_conditions)
|
||||
|
||||
return q
|
||||
|
||||
def list(self, filters: dict = None, limit: int = None) -> list:
|
||||
"""
|
||||
List all recent created memories from the vector store.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply to the list. Each key-value pair
|
||||
becomes a tag filter (@key:{value}). None values are ignored.
|
||||
Values are used as-is without validation - wildcards, special characters,
|
||||
lists, etc. are passed through literally to Valkey search.
|
||||
Multiple filters are combined with AND logic.
|
||||
limit (int, optional): Maximum number of results to return. Defaults to 1000
|
||||
if not specified.
|
||||
|
||||
Returns:
|
||||
list: Nested list format [[MemoryResult(), ...]] matching Redis implementation.
|
||||
Each MemoryResult contains id and payload with hash, data, timestamps, etc.
|
||||
"""
|
||||
try:
|
||||
# Since Valkey search requires vector format, use a dummy vector search
|
||||
# that returns all documents by using a zero vector and large K
|
||||
dummy_vector = [0.0] * self.embedding_model_dims
|
||||
search_limit = limit if limit is not None else 1000 # Large default
|
||||
|
||||
# Use the existing search method which handles filters properly
|
||||
search_results = self.search("", dummy_vector, limit=search_limit, filters=filters)
|
||||
|
||||
# Convert search results to list format (match Redis format)
|
||||
class MemoryResult:
|
||||
def __init__(self, id: str, payload: dict, score: float = None):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.score = score
|
||||
|
||||
memory_results = []
|
||||
for result in search_results:
|
||||
# Create payload in the expected format
|
||||
payload = {
|
||||
"hash": result.payload.get("hash", ""),
|
||||
"data": result.payload.get("data", ""),
|
||||
"created_at": result.payload.get("created_at"),
|
||||
"updated_at": result.payload.get("updated_at"),
|
||||
}
|
||||
|
||||
# Add metadata (exclude system fields)
|
||||
for key, value in result.payload.items():
|
||||
if key not in ["data", "hash", "created_at", "updated_at"]:
|
||||
payload[key] = value
|
||||
|
||||
# Create MemoryResult object (matching Redis format)
|
||||
memory_results.append(MemoryResult(id=result.id, payload=payload))
|
||||
|
||||
# Return nested list format like Redis
|
||||
return [memory_results]
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in list method: {e}")
|
||||
return [[]] # Return empty result on error
|
||||
629
neomem/neomem/vector_stores/vertex_ai_vector_search.py
Normal file
@@ -0,0 +1,629 @@
|
||||
import logging
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import google.api_core.exceptions
|
||||
from google.cloud import aiplatform, aiplatform_v1
|
||||
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
|
||||
Namespace,
|
||||
)
|
||||
from google.oauth2 import service_account
|
||||
from langchain.schema import Document
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.configs.vector_stores.vertex_ai_vector_search import (
|
||||
GoogleMatchingEngineConfig,
|
||||
)
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
class GoogleMatchingEngine(VectorStoreBase):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize Google Matching Engine client."""
|
||||
logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
|
||||
|
||||
# If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
|
||||
if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
|
||||
kwargs["deployment_index_id"] = kwargs["collection_name"]
|
||||
logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
|
||||
elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
|
||||
kwargs["collection_name"] = kwargs["deployment_index_id"]
|
||||
logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
|
||||
|
||||
try:
|
||||
config = GoogleMatchingEngineConfig(**kwargs)
|
||||
logger.debug("Config created: %s", config.model_dump())
|
||||
logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate config: %s", str(e))
|
||||
raise
|
||||
|
||||
self.project_id = config.project_id
|
||||
self.project_number = config.project_number
|
||||
self.region = config.region
|
||||
self.endpoint_id = config.endpoint_id
|
||||
self.index_id = config.index_id # The actual index ID
|
||||
self.deployment_index_id = config.deployment_index_id # The deployment-specific ID
|
||||
self.collection_name = config.collection_name
|
||||
self.vector_search_api_endpoint = config.vector_search_api_endpoint
|
||||
|
||||
logger.debug("Using project=%s, location=%s", self.project_id, self.region)
|
||||
|
||||
# Initialize Vertex AI with credentials if provided
|
||||
init_args = {
|
||||
"project": self.project_id,
|
||||
"location": self.region,
|
||||
}
|
||||
if hasattr(config, "credentials_path") and config.credentials_path:
|
||||
logger.debug("Using credentials from: %s", config.credentials_path)
|
||||
credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
|
||||
init_args["credentials"] = credentials
|
||||
|
||||
try:
|
||||
aiplatform.init(**init_args)
|
||||
logger.debug("Vertex AI initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Vertex AI: %s", str(e))
|
||||
raise
|
||||
|
||||
try:
|
||||
# Format the index path properly using the configured index_id
|
||||
index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
|
||||
logger.debug("Initializing index with path: %s", index_path)
|
||||
self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
|
||||
logger.debug("Index initialized successfully")
|
||||
|
||||
# Format the endpoint name properly
|
||||
endpoint_name = self.endpoint_id
|
||||
logger.debug("Initializing endpoint with name: %s", endpoint_name)
|
||||
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
|
||||
logger.debug("Endpoint initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Matching Engine components: %s", str(e))
|
||||
raise ValueError(f"Invalid configuration: {str(e)}")
|
||||
|
||||
def _parse_output(self, data: Dict) -> List[OutputData]:
|
||||
"""
|
||||
Parse the output data.
|
||||
Args:
|
||||
data (Dict): Output data.
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
results = data.get("nearestNeighbors", {}).get("neighbors", [])
|
||||
output_data = []
|
||||
for result in results:
|
||||
output_data.append(
|
||||
OutputData(
|
||||
id=result.get("datapoint").get("datapointId"),
|
||||
score=result.get("distance"),
|
||||
payload=result.get("datapoint").get("metadata"),
|
||||
)
|
||||
)
|
||||
return output_data
|
||||
|
||||
def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
|
||||
"""Create a restriction object for the Matching Engine index.
|
||||
|
||||
Args:
|
||||
key: The namespace/key for the restriction
|
||||
value: The value to restrict on
|
||||
|
||||
Returns:
|
||||
Restriction object for the index
|
||||
"""
|
||||
str_value = str(value) if value is not None else ""
|
||||
return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
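    # Illustrative example (not part of the original file): each payload entry becomes a
    # single-value allow-list restriction, e.g. _create_restriction("user_id", "alice")
    # yields Restriction(namespace="user_id", allow_list=["alice"]).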
|
||||
|
||||
def _create_datapoint(
|
||||
self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
|
||||
) -> aiplatform_v1.types.index.IndexDatapoint:
|
||||
"""Create a datapoint object for the Matching Engine index.
|
||||
|
||||
Args:
|
||||
vector_id: The ID for the datapoint
|
||||
vector: The vector to store
|
||||
payload: Optional metadata to store with the vector
|
||||
|
||||
Returns:
|
||||
IndexDatapoint object
|
||||
"""
|
||||
restrictions = []
|
||||
if payload:
|
||||
restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
|
||||
|
||||
return aiplatform_v1.types.index.IndexDatapoint(
|
||||
datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
|
||||
)
|
||||
|
||||
def insert(
|
||||
self,
|
||||
vectors: List[list],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Insert vectors into the Matching Engine index.
|
||||
|
||||
Args:
|
||||
vectors: List of vectors to insert
|
||||
payloads: Optional list of metadata dictionaries
|
||||
ids: Optional list of IDs for the vectors
|
||||
|
||||
Raises:
|
||||
ValueError: If vectors is empty or lengths don't match
|
||||
GoogleAPIError: If the API call fails
|
||||
"""
|
||||
if not vectors:
|
||||
raise ValueError("No vectors provided for insertion")
|
||||
|
||||
if payloads and len(payloads) != len(vectors):
|
||||
raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
|
||||
|
||||
if ids and len(ids) != len(vectors):
|
||||
raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
|
||||
|
||||
logger.debug("Starting insert of %d vectors", len(vectors))
|
||||
|
||||
try:
|
||||
datapoints = [
|
||||
self._create_datapoint(
|
||||
vector_id=ids[i] if ids else str(uuid.uuid4()),
|
||||
vector=vector,
|
||||
payload=payloads[i] if payloads and i < len(payloads) else None,
|
||||
)
|
||||
for i, vector in enumerate(vectors)
|
||||
]
|
||||
|
||||
logger.debug("Created %d datapoints", len(datapoints))
|
||||
self.index.upsert_datapoints(datapoints=datapoints)
|
||||
logger.debug("Successfully inserted datapoints")
|
||||
|
||||
except google.api_core.exceptions.GoogleAPIError as e:
|
||||
logger.error("Failed to insert vectors: %s", str(e))
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during insert: %s", str(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
|
||||
Returns:
|
||||
List[OutputData]: Search results (unwrapped)
|
||||
"""
|
||||
logger.debug("Starting search")
|
||||
logger.debug("Limit: %d, Filters: %s", limit, filters)
|
||||
|
||||
try:
|
||||
filter_namespaces = []
|
||||
if filters:
|
||||
logger.debug("Processing filters")
|
||||
for key, value in filters.items():
|
||||
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
|
||||
if isinstance(value, (str, int, float)):
|
||||
logger.debug("Adding simple filter for %s", key)
|
||||
filter_namespaces.append(Namespace(key, [str(value)], []))
|
||||
elif isinstance(value, dict):
|
||||
logger.debug("Adding complex filter for %s", key)
|
||||
includes = value.get("include", [])
|
||||
excludes = value.get("exclude", [])
|
||||
filter_namespaces.append(Namespace(key, includes, excludes))
|
||||
|
||||
logger.debug("Final filter_namespaces: %s", filter_namespaces)
|
||||
|
||||
response = self.index_endpoint.find_neighbors(
|
||||
deployed_index_id=self.deployment_index_id,
|
||||
queries=[vectors],
|
||||
num_neighbors=limit,
|
||||
filter=filter_namespaces if filter_namespaces else None,
|
||||
return_full_datapoint=True,
|
||||
)
|
||||
|
||||
if not response or len(response) == 0 or len(response[0]) == 0:
|
||||
logger.debug("No results found")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for neighbor in response[0]:
|
||||
logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
|
||||
|
||||
payload = {}
|
||||
if hasattr(neighbor, "restricts"):
|
||||
logger.debug("Processing restricts")
|
||||
for restrict in neighbor.restricts:
|
||||
if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
|
||||
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
|
||||
payload[restrict.name] = restrict.allow_tokens[0]
|
||||
|
||||
output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
|
||||
results.append(output_data)
|
||||
|
||||
logger.debug("Returning %d results", len(results))
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error occurred: %s", str(e))
|
||||
logger.error("Error type: %s", type(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
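    # Illustrative note (not part of the original file): filter values accepted by search()
    # are either scalars or include/exclude dicts, each mapped to a Namespace restriction:
    #
    #   {"user_id": "alice"}                                     -> Namespace("user_id", ["alice"], [])
    #   {"tag": {"include": ["work"], "exclude": ["archived"]}}  -> Namespace("tag", ["work"], ["archived"])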
|
||||
|
||||
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
|
||||
"""
|
||||
Delete vectors from the Matching Engine index.
|
||||
Args:
|
||||
vector_id (Optional[str]): Single ID to delete (for backward compatibility)
|
||||
ids (Optional[List[str]]): List of IDs of vectors to delete
|
||||
Returns:
|
||||
bool: True if vectors were deleted successfully or already deleted, False if error
|
||||
"""
|
||||
logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
|
||||
try:
|
||||
# Handle both single vector_id and list of ids
|
||||
if vector_id:
|
||||
datapoint_ids = [vector_id]
|
||||
elif ids:
|
||||
datapoint_ids = ids
|
||||
else:
|
||||
raise ValueError("Either vector_id or ids must be provided")
|
||||
|
||||
logger.debug("Deleting ids: %s", datapoint_ids)
|
||||
try:
|
||||
self.index.remove_datapoints(datapoint_ids=datapoint_ids)
|
||||
logger.debug("Delete completed successfully")
|
||||
return True
|
||||
except google.api_core.exceptions.NotFound:
|
||||
# If the datapoint is already deleted, consider it a success
|
||||
logger.debug("Datapoint already deleted")
|
||||
return True
|
||||
except google.api_core.exceptions.PermissionDenied as e:
|
||||
logger.error("Permission denied: %s", str(e))
|
||||
return False
|
||||
except google.api_core.exceptions.InvalidArgument as e:
|
||||
logger.error("Invalid argument: %s", str(e))
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error occurred: %s", str(e))
|
||||
logger.error("Error type: %s", type(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
return False
|
||||
|
||||
def update(
|
||||
self,
|
||||
vector_id: str,
|
||||
vector: Optional[List[float]] = None,
|
||||
payload: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id: ID of the vector to update
|
||||
vector: Optional new vector values
|
||||
payload: Optional new metadata payload
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: If neither vector nor payload is provided
|
||||
GoogleAPIError: If the API call fails
|
||||
"""
|
||||
logger.debug("Starting update for vector_id: %s", vector_id)
|
||||
|
||||
if vector is None and payload is None:
|
||||
raise ValueError("Either vector or payload must be provided for update")
|
||||
|
||||
# First check if the vector exists
|
||||
try:
|
||||
existing = self.get(vector_id)
|
||||
if existing is None:
|
||||
logger.error("Vector ID not found: %s", vector_id)
|
||||
return False
|
||||
|
||||
datapoint = self._create_datapoint(
|
||||
vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
|
||||
)
|
||||
|
||||
logger.debug("Upserting datapoint: %s", datapoint)
|
||||
self.index.upsert_datapoints(datapoints=[datapoint])
|
||||
logger.debug("Update completed successfully")
|
||||
return True
|
||||
|
||||
except google.api_core.exceptions.GoogleAPIError as e:
|
||||
logger.error("API error during update: %s", str(e))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during update: %s", str(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
Returns:
|
||||
Optional[OutputData]: Retrieved vector or None if not found.
|
||||
"""
|
||||
logger.debug("Starting get for vector_id: %s", vector_id)
|
||||
|
||||
try:
|
||||
if not self.vector_search_api_endpoint:
|
||||
raise ValueError("vector_search_api_endpoint is required for get operation")
|
||||
|
||||
vector_search_client = aiplatform_v1.MatchServiceClient(
|
||||
client_options={"api_endpoint": self.vector_search_api_endpoint},
|
||||
)
|
||||
datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
|
||||
|
||||
query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
|
||||
request = aiplatform_v1.FindNeighborsRequest(
|
||||
index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
|
||||
deployed_index_id=self.deployment_index_id,
|
||||
queries=[query],
|
||||
return_full_datapoint=True,
|
||||
)
|
||||
|
||||
try:
|
||||
response = vector_search_client.find_neighbors(request)
|
||||
logger.debug("Got response")
|
||||
|
||||
if response and response.nearest_neighbors:
|
||||
nearest = response.nearest_neighbors[0]
|
||||
if nearest.neighbors:
|
||||
neighbor = nearest.neighbors[0]
|
||||
|
||||
payload = {}
|
||||
if hasattr(neighbor.datapoint, "restricts"):
|
||||
for restrict in neighbor.datapoint.restricts:
|
||||
if restrict.allow_list:
|
||||
payload[restrict.namespace] = restrict.allow_list[0]
|
||||
|
||||
return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
|
||||
|
||||
logger.debug("No results found")
|
||||
return None
|
||||
|
||||
except google.api_core.exceptions.NotFound:
|
||||
logger.debug("Datapoint not found")
|
||||
return None
|
||||
except google.api_core.exceptions.PermissionDenied as e:
|
||||
logger.error("Permission denied: %s", str(e))
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error occurred: %s", str(e))
|
||||
logger.error("Error type: %s", type(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections (indexes).
|
||||
Returns:
|
||||
List[str]: List of collection names.
|
||||
"""
|
||||
return [self.deployment_index_id]
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete a collection (index).
|
||||
Note: This operation is not supported through the API.
|
||||
"""
|
||||
logger.warning("Delete collection operation is not supported for Google Matching Engine")
|
||||
pass
|
||||
|
||||
def col_info(self) -> Dict:
|
||||
"""
|
||||
Get information about a collection (index).
|
||||
Returns:
|
||||
Dict: Collection information.
|
||||
"""
|
||||
return {
|
||||
"index_id": self.index_id,
|
||||
"endpoint_id": self.endpoint_id,
|
||||
"project_id": self.project_id,
|
||||
"region": self.region,
|
||||
}
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
|
||||
"""List vectors matching the given filters.
|
||||
|
||||
Args:
|
||||
filters: Optional filters to apply
|
||||
limit: Optional maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List[List[OutputData]]: List of matching vectors wrapped in an extra array
|
||||
to match the interface
|
||||
"""
|
||||
logger.debug("Starting list operation")
|
||||
logger.debug("Filters: %s", filters)
|
||||
logger.debug("Limit: %s", limit)
|
||||
|
||||
        try:
            # Use a zero vector for the search
            dimension = 768  # This should be configurable based on the embedding model
            zero_vector = [0.0] * dimension

            # Use a large limit if none specified
            search_limit = limit if limit is not None else 10000

            # search() takes the query text and the query vector as separate arguments
            results = self.search(query="", vectors=zero_vector, limit=search_limit, filters=filters)

            logger.debug("Found %d results", len(results))
            return [results]  # Wrap in extra array to match interface
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in list operation: %s", str(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
def create_col(self, name=None, vector_size=None, distance=None):
|
||||
"""
|
||||
Create a new collection. For Google Matching Engine, collections (indexes)
|
||||
are created through the Google Cloud Console or API separately.
|
||||
This method is a no-op since indexes are pre-created.
|
||||
|
||||
Args:
|
||||
name: Ignored for Google Matching Engine
|
||||
vector_size: Ignored for Google Matching Engine
|
||||
distance: Ignored for Google Matching Engine
|
||||
"""
|
||||
# Google Matching Engine indexes are created through Google Cloud Console
|
||||
# This method is included only to satisfy the abstract base class
|
||||
pass
|
||||
|
||||
def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
|
||||
logger.debug("Starting add operation")
|
||||
logger.debug("Text: %s", text)
|
||||
logger.debug("Metadata: %s", metadata)
|
||||
logger.debug("User ID: %s", user_id)
|
||||
|
||||
try:
|
||||
# Generate a unique ID for this entry
|
||||
vector_id = str(uuid.uuid4())
|
||||
|
||||
# Create the payload with all necessary fields
|
||||
payload = {
|
||||
"data": text, # Store the text in the data field
|
||||
"user_id": user_id,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
# Get the embedding
|
||||
vector = self.embedder.embed_query(text)
|
||||
|
||||
# Insert using the insert method
|
||||
self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
|
||||
|
||||
return vector_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error occurred: %s", str(e))
|
||||
raise
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: List[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""Add texts to the vector store.
|
||||
|
||||
Args:
|
||||
texts: List of texts to add
|
||||
metadatas: Optional list of metadata dicts
|
||||
ids: Optional list of IDs to use
|
||||
|
||||
Returns:
|
||||
List[str]: List of IDs of the added texts
|
||||
|
||||
Raises:
|
||||
ValueError: If texts is empty or lengths don't match
|
||||
"""
|
||||
if not texts:
|
||||
raise ValueError("No texts provided")
|
||||
|
||||
if metadatas and len(metadatas) != len(texts):
|
||||
raise ValueError(
|
||||
f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
|
||||
)
|
||||
|
||||
if ids and len(ids) != len(texts):
|
||||
raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
|
||||
|
||||
logger.debug("Starting add_texts operation")
|
||||
logger.debug("Number of texts: %d", len(texts))
|
||||
logger.debug("Has metadatas: %s", metadatas is not None)
|
||||
logger.debug("Has ids: %s", ids is not None)
|
||||
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid4()) for _ in texts]
|
||||
|
||||
try:
|
||||
# Get embeddings
|
||||
embeddings = self.embedder.embed_documents(texts)
|
||||
|
||||
# Add to store
|
||||
self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
|
||||
return ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in add_texts: %s", str(e))
|
||||
logger.error("Stack trace: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Any,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "GoogleMatchingEngine":
|
||||
"""Create an instance from texts."""
|
||||
logger.debug("Creating instance from texts")
|
||||
store = cls(**kwargs)
|
||||
store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
return store
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
filter: Optional[Dict] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return documents most similar to query with scores."""
|
||||
logger.debug("Starting similarity search with score")
|
||||
logger.debug("Query: %s", query)
|
||||
logger.debug("k: %d", k)
|
||||
logger.debug("Filter: %s", filter)
|
||||
|
||||
        embedding = self.embedder.embed_query(query)
        # search() takes the query text and the embedding as separate arguments
        results = self.search(query=query, vectors=embedding, limit=k, filters=filter)
|
||||
|
||||
docs_and_scores = [
|
||||
(Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
|
||||
for result in results
|
||||
]
|
||||
logger.debug("Found %d results", len(docs_and_scores))
|
||||
return docs_and_scores
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
filter: Optional[Dict] = None,
|
||||
) -> List[Document]:
|
||||
"""Return documents most similar to query."""
|
||||
logger.debug("Starting similarity search")
|
||||
docs_and_scores = self.similarity_search_with_score(query, k, filter)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the Google Matching Engine index.
|
||||
"""
|
||||
logger.warning("Reset operation is not supported for Google Matching Engine")
|
||||
pass
|
||||
343
neomem/neomem/vector_stores/weaviate.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Mapping, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'weaviate' library is required. Please install it using 'pip install weaviate-client weaviate'."
|
||||
)
|
||||
|
||||
import weaviate.classes.config as wvcc
|
||||
from weaviate.classes.init import AdditionalConfig, Auth, Timeout
|
||||
from weaviate.classes.query import Filter, MetadataQuery
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
payload: Dict
|
||||
|
||||
|
||||
class Weaviate(VectorStoreBase):
    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        cluster_url: str = None,
        auth_client_secret: str = None,
        additional_headers: dict = None,
    ):
        """
        Initialize the Weaviate vector store.

        Args:
            collection_name (str): Name of the collection/class in Weaviate.
            embedding_model_dims (int): Dimensions of the embedding model.
            cluster_url (str, optional): URL for the Weaviate server. Defaults to None.
            auth_client_secret (str, optional): API key used to authenticate against Weaviate Cloud. Defaults to None.
            additional_headers (dict, optional): Additional headers for requests. Defaults to None.
        """
if "localhost" in cluster_url:
|
||||
self.client = weaviate.connect_to_local(headers=additional_headers)
|
||||
elif auth_client_secret:
|
||||
self.client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=cluster_url,
|
||||
auth_credentials=Auth.api_key(auth_client_secret),
|
||||
headers=additional_headers,
|
||||
)
|
||||
else:
|
||||
parsed = urlparse(cluster_url) # e.g., http://mem0_store:8080
|
||||
http_host = parsed.hostname or "localhost"
|
||||
http_port = parsed.port or (443 if parsed.scheme == "https" else 8080)
|
||||
http_secure = parsed.scheme == "https"
|
||||
|
||||
# Weaviate gRPC defaults (inside Docker network)
|
||||
grpc_host = http_host
|
||||
grpc_port = 50051
|
||||
grpc_secure = False
|
||||
|
||||
self.client = weaviate.connect_to_custom(
|
||||
http_host,
|
||||
http_port,
|
||||
http_secure,
|
||||
grpc_host,
|
||||
grpc_port,
|
||||
grpc_secure,
|
||||
headers=additional_headers,
|
||||
skip_init_checks=True,
|
||||
additional_config=AdditionalConfig(timeout=Timeout(init=2.0)),
|
||||
)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.create_col(embedding_model_dims)
|
||||
|
||||
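    # Connection sketch (illustrative, not part of the upstream file; URLs and key are assumptions):
    # the constructor picks one of three modes based on its arguments.
    #   Weaviate("mem0", 1536, cluster_url="http://localhost:8080")
    #       -> weaviate.connect_to_local
    #   Weaviate("mem0", 1536, cluster_url="https://my-cluster.weaviate.network", auth_client_secret="<api-key>")
    #       -> weaviate.connect_to_weaviate_cloud
    #   Weaviate("mem0", 1536, cluster_url="http://mem0_store:8080")
    #       -> weaviate.connect_to_custom (Docker network, gRPC on 50051)
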
    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data.

        Returns:
            List[OutputData]: Parsed output data.
        """
        keys = ["ids", "distances", "metadatas"]
        values = []

        for key in keys:
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

        return result

    def create_col(self, vector_size, distance="cosine"):
        """
        Create a new collection with the specified schema.

        Args:
            vector_size (int): Size of the vectors to be stored.
            distance (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        if self.client.collections.exists(self.collection_name):
            logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
            return

        properties = [
            wvcc.Property(name="ids", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="hash", data_type=wvcc.DataType.TEXT),
            wvcc.Property(
                name="metadata",
                data_type=wvcc.DataType.TEXT,
                description="Additional metadata",
            ),
            wvcc.Property(name="data", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="created_at", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="category", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="updated_at", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="user_id", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="agent_id", data_type=wvcc.DataType.TEXT),
            wvcc.Property(name="run_id", data_type=wvcc.DataType.TEXT),
        ]

        vectorizer_config = wvcc.Configure.Vectorizer.none()
        vector_index_config = wvcc.Configure.VectorIndex.hnsw()

        self.client.collections.create(
            self.collection_name,
            vectorizer_config=vectorizer_config,
            vector_index_config=vector_index_config,
            properties=properties,
        )

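    # Illustrative note (assumption, not taken from this file): with the schema above, an inserted
    # object's payload typically looks like
    #   {"data": "memory text", "hash": "...", "created_at": "...", "user_id": "alice", "category": "..."}
    # while the embedding itself is stored separately as the object's vector.
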
    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        with self.client.batch.fixed_size(batch_size=100) as batch:
            for idx, vector in enumerate(vectors):
                object_id = ids[idx] if ids and idx < len(ids) else str(uuid.uuid4())
                object_id = get_valid_uuid(object_id)

                data_object = payloads[idx] if payloads and idx < len(payloads) else {}

                # Ensure the 'ids' key is not duplicated as a property; the identifier is stored as the Weaviate object UUID.
                if "ids" in data_object:
                    del data_object["ids"]

                batch.add_object(
                    collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector
                )

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.
        """
        collection = self.client.collections.get(str(self.collection_name))
        filter_conditions = []
        if filters:
            for key, value in filters.items():
                if value and key in ["user_id", "agent_id", "run_id"]:
                    filter_conditions.append(Filter.by_property(key).equal(value))
        combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
        # The keyword part of the hybrid query is left empty, so ranking is driven by the supplied vector.
        response = collection.query.hybrid(
            query="",
            vector=vectors,
            limit=limit,
            filters=combined_filter,
            return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
            return_metadata=MetadataQuery(score=True),
        )
        results = []
        for obj in response.objects:
            payload = obj.properties.copy()

            # Drop scope fields that are unset so they do not pollute the payload.
            for id_field in ["run_id", "agent_id", "user_id"]:
                if id_field in payload and payload[id_field] is None:
                    del payload[id_field]

            payload["id"] = str(obj.uuid).split("'")[0]  # Include the id in the payload
            if obj.metadata.distance is not None:
                score = 1 - obj.metadata.distance  # Convert distance to similarity score
            elif obj.metadata.score is not None:
                score = obj.metadata.score
            else:
                score = 1.0  # Default score if none provided
            results.append(
                OutputData(
                    id=str(obj.uuid),
                    score=score,
                    payload=payload,
                )
            )
        return results

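    # Illustrative call (assumption): the caller embeds the query text elsewhere, since this store
    # configures no server-side vectorizer, and passes the embedding as `vectors`:
    #   store.search(query="", vectors=query_embedding, limit=5, filters={"user_id": "alice"})
    # Each OutputData payload then carries "data", the timestamps, and any scope ids that were set.
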
    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id: ID of the vector to delete.
        """
        collection = self.client.collections.get(str(self.collection_name))
        collection.data.delete_by_id(vector_id)

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id: ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        collection = self.client.collections.get(str(self.collection_name))

        if payload:
            collection.data.update(uuid=vector_id, properties=payload)

        if vector:
            existing_data = self.get(vector_id)
            if existing_data:
                existing_data = dict(existing_data)
                if "id" in existing_data:
                    del existing_data["id"]
                existing_payload: Mapping[str, str] = existing_data
                collection.data.update(uuid=vector_id, properties=existing_payload, vector=vector)

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id: ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector and metadata, or None if the object does not exist.
        """
        vector_id = get_valid_uuid(vector_id)
        collection = self.client.collections.get(str(self.collection_name))

        response = collection.query.fetch_object_by_id(
            uuid=vector_id,
            return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
        )
        if response is None:
            return None

        payload = response.properties.copy()
        payload["id"] = str(response.uuid).split("'")[0]
        return OutputData(
            id=str(response.uuid).split("'")[0],
            score=1.0,
            payload=payload,
        )

    def list_cols(self):
        """
        List all collections.

        Returns:
            dict: Collection names under the "collections" key.
        """
        collections = self.client.collections.list_all()
        logger.debug(f"collections: {collections}")
        # list_all() returns a dict keyed by collection name.
        return {"collections": [{"name": name} for name in collections]}

    def delete_col(self):
        """Delete a collection."""
        self.client.collections.delete(self.collection_name)

    def col_info(self):
        """
        Get information about a collection.

        Returns:
            dict: Collection information.
        """
        schema = self.client.collections.get(self.collection_name)
        if schema:
            return schema
        return None

    def list(self, filters=None, limit=100) -> List[OutputData]:
        """
        List all vectors in a collection.
        """
        collection = self.client.collections.get(self.collection_name)
        filter_conditions = []
        if filters:
            for key, value in filters.items():
                if value and key in ["user_id", "agent_id", "run_id"]:
                    filter_conditions.append(Filter.by_property(key).equal(value))
        combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
        response = collection.query.fetch_objects(
            limit=limit,
            filters=combined_filter,
            return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
        )
        results = []
        for obj in response.objects:
            payload = obj.properties.copy()
            payload["id"] = str(obj.uuid).split("'")[0]
            results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims)
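
A minimal end-to-end sketch of the Weaviate store, assuming a running instance and treating the vector values below as placeholders rather than real embeddings:

# Hypothetical round trip; cluster URL and vector values are placeholders.
store = Weaviate(collection_name="mem0", embedding_model_dims=4, cluster_url="http://localhost:8080")
vec = [0.1, 0.2, 0.3, 0.4]
store.insert(vectors=[vec], payloads=[{"data": "hello", "user_id": "alice"}])
hits = store.search(query="", vectors=vec, limit=1, filters={"user_id": "alice"})
store.delete(hits[0].id)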