Initial clean commit - unified Lyra stack

This commit is contained in:
serversdwn
2025-11-16 03:17:32 -05:00
commit 94fb091e59
270 changed files with 74200 additions and 0 deletions

View File

@@ -0,0 +1,396 @@
import json
import logging
import re
from typing import List, Optional
from pydantic import BaseModel
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
try:
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
BinaryQuantizationCompression,
HnswAlgorithmConfiguration,
ScalarQuantizationCompression,
SearchField,
SearchFieldDataType,
SearchIndex,
SimpleField,
VectorSearch,
VectorSearchProfile,
)
from azure.search.documents.models import VectorizedQuery
except ImportError:
raise ImportError(
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
)
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class AzureAISearch(VectorStoreBase):
def __init__(
self,
service_name,
collection_name,
api_key,
embedding_model_dims,
compression_type: Optional[str] = None,
use_float16: bool = False,
hybrid_search: bool = False,
vector_filter_mode: Optional[str] = None,
):
"""
Initialize the Azure AI Search vector store.
Args:
service_name (str): Azure AI Search service name.
collection_name (str): Index name.
api_key (str): API key for the Azure AI Search service.
embedding_model_dims (int): Dimension of the embedding vector.
compression_type (Optional[str]): Specifies the type of quantization to use.
Allowed values are None (no quantization), "scalar", or "binary".
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
(Note: This flag is preserved from the initial implementation per feedback.)
hybrid_search (bool): Whether to use hybrid search. Default is False.
vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
"""
self.service_name = service_name
self.api_key = api_key
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
# If compression_type is None, treat it as "none".
self.compression_type = (compression_type or "none").lower()
self.use_float16 = use_float16
self.hybrid_search = hybrid_search
self.vector_filter_mode = vector_filter_mode
# If the API key is not provided or is a placeholder, use DefaultAzureCredential.
if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
credential = DefaultAzureCredential()
self.api_key = None
else:
credential = AzureKeyCredential(self.api_key)
self.search_client = SearchClient(
endpoint=f"https://{service_name}.search.windows.net",
index_name=self.index_name,
credential=credential,
)
self.index_client = SearchIndexClient(
endpoint=f"https://{service_name}.search.windows.net",
credential=credential,
)
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
collections = self.list_cols()
if collection_name not in collections:
self.create_col()
def create_col(self):
"""Create a new index in Azure AI Search."""
# Determine vector type based on use_float16 setting.
if self.use_float16:
vector_type = "Collection(Edm.Half)"
else:
vector_type = "Collection(Edm.Single)"
# Configure compression settings based on the specified compression_type.
compression_configurations = []
compression_name = None
if self.compression_type == "scalar":
compression_name = "myCompression"
# For SQ, rescoring defaults to True and oversampling defaults to 4.
compression_configurations = [
ScalarQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 4
)
]
elif self.compression_type == "binary":
compression_name = "myCompression"
# For BQ, rescoring defaults to True and oversampling defaults to 10.
compression_configurations = [
BinaryQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 10
)
]
# If no compression is desired, compression_configurations remains empty.
fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
SearchField(
name="vector",
type=vector_type,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
]
vector_search = VectorSearch(
profiles=[
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
compression_name=compression_name if self.compression_type != "none" else None,
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
compressions=compression_configurations,
)
index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
self.index_client.create_or_update_index(index)
def _generate_document(self, vector, payload, id):
document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
# Extract additional fields if they exist.
for field in ["user_id", "run_id", "agent_id"]:
if field in payload:
document[field] = payload[field]
return document
# Note: Explicit "insert" calls may later be decoupled from memory management decisions.
def insert(self, vectors, payloads=None, ids=None):
"""
Insert vectors into the index.
Args:
vectors (List[List[float]]): List of vectors to insert.
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
]
response = self.search_client.upload_documents(documents)
for doc in response:
# IndexingResult exposes `succeeded` and `key`; fail fast if any document was rejected.
if not getattr(doc, "succeeded", True):
raise Exception(f"Insert failed for document {getattr(doc, 'key', None)}: {doc}")
return response
def _sanitize_key(self, key: str) -> str:
return re.sub(r"[^\w]", "", key)
def _build_filter_expression(self, filters):
filter_conditions = []
for key, value in filters.items():
safe_key = self._sanitize_key(key)
if isinstance(value, str):
safe_value = value.replace("'", "''")
condition = f"{safe_key} eq '{safe_value}'"
else:
condition = f"{safe_key} eq {value}"
filter_conditions.append(condition)
filter_expression = " and ".join(filter_conditions)
return filter_expression
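# Example of the OData expression produced above (illustrative values):
# filters={"user_id": "alice", "agent_id": "agent-1"} yields
# "user_id eq 'alice' and agent_id eq 'agent-1'"; single quotes inside string
# values are escaped by doubling them, and keys are stripped to word characters.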
def search(self, query, vectors, limit=5, filters=None):
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results.
"""
filter_expression = None
if filters:
filter_expression = self._build_filter_expression(filters)
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
if self.hybrid_search:
search_results = self.search_client.search(
search_text=query,
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
search_fields=["payload"],
)
else:
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
)
results = []
for result in search_results:
payload = json.loads(extract_json(result["payload"]))
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results
def delete(self, vector_id):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
response = self.search_client.delete_documents(documents=[{"id": vector_id}])
for doc in response:
if not getattr(doc, "succeeded", True):
raise Exception(f"Delete failed for document {vector_id}: {doc}")
logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
return response
def update(self, vector_id, vector=None, payload=None):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
document = {"id": vector_id}
if vector:
document["vector"] = vector
if payload:
json_payload = json.dumps(payload)
document["payload"] = json_payload
for field in ["user_id", "run_id", "agent_id"]:
document[field] = payload.get(field)
response = self.search_client.merge_or_upload_documents(documents=[document])
for doc in response:
if not getattr(doc, "succeeded", True):
raise Exception(f"Update failed for document {vector_id}: {doc}")
return response
def get(self, vector_id) -> OutputData:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
try:
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
payload = json.loads(extract_json(result["payload"]))
return OutputData(id=result["id"], score=None, payload=payload)
def list_cols(self) -> List[str]:
"""
List all collections (indexes).
Returns:
List[str]: List of index names.
"""
try:
names = self.index_client.list_index_names()
except AttributeError:
names = [index.name for index in self.index_client.list_indexes()]
return names
def delete_col(self):
"""Delete the index."""
self.index_client.delete_index(self.index_name)
def col_info(self):
"""
Get information about the index.
Returns:
dict: Index information.
"""
index = self.index_client.get_index(self.index_name)
return {"name": index.name, "fields": index.fields}
def list(self, filters=None, limit=100):
"""
List all vectors in the index.
Args:
filters (dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
filter_expression = None
if filters:
filter_expression = self._build_filter_expression(filters)
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
results = []
for result in search_results:
payload = json.loads(extract_json(result["payload"]))
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results]
def __del__(self):
"""Close the search client when the object is deleted."""
self.search_client.close()
self.index_client.close()
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.index_name}...")
try:
# Close the existing clients
self.search_client.close()
self.index_client.close()
# Delete the collection
self.delete_col()
# If the API key is not provided or is a placeholder, use DefaultAzureCredential.
if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
credential = DefaultAzureCredential()
self.api_key = None
else:
credential = AzureKeyCredential(self.api_key)
# Reinitialize the clients
service_endpoint = f"https://{self.service_name}.search.windows.net"
self.search_client = SearchClient(
endpoint=service_endpoint,
index_name=self.index_name,
credential=credential,
)
self.index_client = SearchIndexClient(
endpoint=service_endpoint,
credential=credential,
)
# Add user agent
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
# Create the collection
self.create_col()
except Exception as e:
logger.error(f"Error resetting index {self.index_name}: {e}")
raise
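
A minimal usage sketch for the class above (illustrative only: the service name, API key, 4-dim vectors, and the module path mem0.vector_stores.azure_ai_search are placeholder assumptions; in practice the vector comes from whatever embedder the stack configures):

from mem0.vector_stores.azure_ai_search import AzureAISearch

store = AzureAISearch(
    service_name="my-search-service",   # assumed service name
    collection_name="mem0",
    api_key="<api-key>",                # None or "" falls back to DefaultAzureCredential
    embedding_model_dims=4,
)
vector = [0.1, 0.2, 0.3, 0.4]           # placeholder embedding
store.insert(vectors=[vector], payloads=[{"data": "hello", "user_id": "alice"}], ids=["mem-1"])
for hit in store.search(query="hello", vectors=vector, limit=3, filters={"user_id": "alice"}):
    print(hit.id, hit.score, hit.payload)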

View File

@@ -0,0 +1,463 @@
import json
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
try:
import pymysql
from pymysql.cursors import DictCursor
from dbutils.pooled_db import PooledDB
except ImportError:
raise ImportError(
"Azure MySQL vector store requires PyMySQL and DBUtils. "
"Please install them using 'pip install pymysql dbutils'"
)
try:
from azure.identity import DefaultAzureCredential
AZURE_IDENTITY_AVAILABLE = True
except ImportError:
AZURE_IDENTITY_AVAILABLE = False
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class AzureMySQL(VectorStoreBase):
def __init__(
self,
host: str,
port: int,
user: str,
password: Optional[str],
database: str,
collection_name: str,
embedding_model_dims: int,
use_azure_credential: bool = False,
ssl_ca: Optional[str] = None,
ssl_disabled: bool = False,
minconn: int = 1,
maxconn: int = 5,
connection_pool: Optional[Any] = None,
):
"""
Initialize the Azure MySQL vector store.
Args:
host (str): MySQL server host
port (int): MySQL server port
user (str): Database user
password (str, optional): Database password (not required if using Azure credential)
database (str): Database name
collection_name (str): Collection/table name
embedding_model_dims (int): Dimension of the embedding vector
use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
ssl_ca (str, optional): Path to SSL CA certificate
ssl_disabled (bool): Disable SSL connection
minconn (int): Minimum number of connections in the pool
maxconn (int): Maximum number of connections in the pool
connection_pool (Any, optional): Pre-configured connection pool
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.database = database
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.use_azure_credential = use_azure_credential
self.ssl_ca = ssl_ca
self.ssl_disabled = ssl_disabled
self.connection_pool = connection_pool
# Handle Azure authentication
if use_azure_credential:
if not AZURE_IDENTITY_AVAILABLE:
raise ImportError(
"Azure Identity is required for Azure credential authentication. "
"Please install it using 'pip install azure-identity'"
)
self._setup_azure_auth()
# Setup connection pool
if self.connection_pool is None:
self._setup_connection_pool(minconn, maxconn)
# Create collection if it doesn't exist
collections = self.list_cols()
if collection_name not in collections:
self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")
def _setup_azure_auth(self):
"""Setup Azure authentication using DefaultAzureCredential."""
try:
credential = DefaultAzureCredential()
# Get access token for Azure Database for MySQL
token = credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
# Use token as password
self.password = token.token
logger.info("Successfully authenticated using Azure DefaultAzureCredential")
except Exception as e:
logger.error(f"Failed to authenticate with Azure: {e}")
raise
def _setup_connection_pool(self, minconn: int, maxconn: int):
"""Setup MySQL connection pool."""
connect_kwargs = {
"host": self.host,
"port": self.port,
"user": self.user,
"password": self.password,
"database": self.database,
"charset": "utf8mb4",
"cursorclass": DictCursor,
"autocommit": False,
}
# SSL configuration
if not self.ssl_disabled:
ssl_config = {"ssl_verify_cert": True}
if self.ssl_ca:
ssl_config["ssl_ca"] = self.ssl_ca
connect_kwargs["ssl"] = ssl_config
try:
self.connection_pool = PooledDB(
creator=pymysql,
mincached=minconn,
maxcached=maxconn,
maxconnections=maxconn,
blocking=True,
**connect_kwargs
)
logger.info("Successfully created MySQL connection pool")
except Exception as e:
logger.error(f"Failed to create connection pool: {e}")
raise
@contextmanager
def _get_cursor(self, commit: bool = False):
"""
Context manager to get a cursor from the connection pool.
Auto-commits or rolls back based on exception.
"""
conn = self.connection_pool.connection()
cur = conn.cursor()
try:
yield cur
if commit:
conn.commit()
except Exception as exc:
conn.rollback()
logger.error(f"Database error: {exc}", exc_info=True)
raise
finally:
cur.close()
conn.close()
def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
"""
Create a new collection (table in MySQL).
Enables vector extension and creates appropriate indexes.
Args:
name (str, optional): Collection name (uses self.collection_name if not provided)
vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not provided)
distance (str): Distance metric (cosine, euclidean, dot_product)
"""
table_name = name or self.collection_name
dims = vector_size or self.embedding_model_dims
with self._get_cursor(commit=True) as cur:
# Create table with vector column
cur.execute(f"""
CREATE TABLE IF NOT EXISTS `{table_name}` (
id VARCHAR(255) PRIMARY KEY,
vector JSON,
payload JSON,
INDEX idx_payload_keys ((CAST(payload AS CHAR(255)) ARRAY))
)
""")
logger.info(f"Created collection '{table_name}' with vector dimension {dims}")
def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
"""
Insert vectors into the collection.
Args:
vectors (List[List[float]]): List of vectors to insert
payloads (List[Dict], optional): List of payloads corresponding to vectors
ids (List[str], optional): List of IDs corresponding to vectors
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
if payloads is None:
payloads = [{}] * len(vectors)
if ids is None:
import uuid
ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
data = []
for vector, payload, vec_id in zip(vectors, payloads, ids):
data.append((vec_id, json.dumps(vector), json.dumps(payload)))
with self._get_cursor(commit=True) as cur:
cur.executemany(
f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
data
)
def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
"""Generate SQL for cosine distance calculation."""
# For MySQL, we need to calculate cosine similarity manually
# This is a simplified version - in production, you'd use stored procedures or UDFs
return """
1 - (
(SELECT SUM(a.val * b.val) /
(SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
FROM (
SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
WHERE idx < JSON_LENGTH(vector)
) a,
(
SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
WHERE idx < JSON_LENGTH(%s)
) b
WHERE a.idx = b.idx
)
"""
def search(
self,
query: str,
vectors: List[float],
limit: int = 5,
filters: Optional[Dict] = None,
) -> List[OutputData]:
"""
Search for similar vectors using cosine similarity.
Args:
query (str): Query string (not used in vector search)
vectors (List[float]): Query vector
limit (int): Number of results to return
filters (Dict, optional): Filters to apply to the search
Returns:
List[OutputData]: Search results
"""
filter_conditions = []
filter_params = []
if filters:
for k, v in filters.items():
filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
filter_params.extend([f"$.{k}", json.dumps(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
# For simplicity, we'll compute cosine similarity in Python
# In production, you'd want to use MySQL stored procedures or UDFs
with self._get_cursor() as cur:
query_sql = f"""
SELECT id, vector, payload
FROM `{self.collection_name}`
{filter_clause}
"""
cur.execute(query_sql, filter_params)
results = cur.fetchall()
# Calculate cosine similarity in Python
import numpy as np
query_vec = np.array(vectors)
scored_results = []
for row in results:
vec = np.array(json.loads(row['vector']))
# Cosine similarity
similarity = np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec))
distance = 1 - similarity
scored_results.append((row['id'], distance, row['payload']))
# Sort by distance and limit
scored_results.sort(key=lambda x: x[1])
scored_results = scored_results[:limit]
return [
OutputData(id=r[0], score=float(r[1]), payload=json.loads(r[2]) if isinstance(r[2], str) else r[2])
for r in scored_results
]
def delete(self, vector_id: str):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete
"""
with self._get_cursor(commit=True) as cur:
cur.execute(f"DELETE FROM `{self.collection_name}` WHERE id = %s", (vector_id,))
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update
vector (List[float], optional): Updated vector
payload (Dict, optional): Updated payload
"""
with self._get_cursor(commit=True) as cur:
if vector is not None:
cur.execute(
f"UPDATE `{self.collection_name}` SET vector = %s WHERE id = %s",
(json.dumps(vector), vector_id),
)
if payload is not None:
cur.execute(
f"UPDATE `{self.collection_name}` SET payload = %s WHERE id = %s",
(json.dumps(payload), vector_id),
)
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve
Returns:
OutputData: Retrieved vector or None if not found
"""
with self._get_cursor() as cur:
cur.execute(
f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
(vector_id,),
)
result = cur.fetchone()
if not result:
return None
return OutputData(
id=result['id'],
score=None,
payload=json.loads(result['payload']) if isinstance(result['payload'], str) else result['payload']
)
def list_cols(self) -> List[str]:
"""
List all collections (tables).
Returns:
List[str]: List of collection names
"""
with self._get_cursor() as cur:
cur.execute("SHOW TABLES")
return [row[f"Tables_in_{self.database}"] for row in cur.fetchall()]
def delete_col(self):
"""Delete the collection (table)."""
with self._get_cursor(commit=True) as cur:
cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
logger.info(f"Deleted collection '{self.collection_name}'")
def col_info(self) -> Dict[str, Any]:
"""
Get information about the collection.
Returns:
Dict[str, Any]: Collection information
"""
with self._get_cursor() as cur:
cur.execute("""
SELECT
TABLE_NAME as name,
TABLE_ROWS as count,
ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
""", (self.database, self.collection_name))
result = cur.fetchone()
if result:
return {
"name": result['name'],
"count": result['count'],
"size": f"{result['size_mb']} MB"
}
return {}
def list(
self,
filters: Optional[Dict] = None,
limit: int = 100
) -> List[List[OutputData]]:
"""
List all vectors in the collection.
Args:
filters (Dict, optional): Filters to apply
limit (int): Number of vectors to return
Returns:
List[List[OutputData]]: List of vectors
"""
filter_conditions = []
filter_params = []
if filters:
for k, v in filters.items():
filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
filter_params.extend([f"$.{k}", json.dumps(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
with self._get_cursor() as cur:
cur.execute(
f"""
SELECT id, vector, payload
FROM `{self.collection_name}`
{filter_clause}
LIMIT %s
""",
(*filter_params, limit)
)
results = cur.fetchall()
return [[
OutputData(
id=r['id'],
score=None,
payload=json.loads(r['payload']) if isinstance(r['payload'], str) else r['payload']
) for r in results
]]
def reset(self):
"""Reset the collection by deleting and recreating it."""
logger.warning(f"Resetting collection {self.collection_name}...")
self.delete_col()
self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)
def __del__(self):
"""Close the connection pool when the object is deleted."""
try:
if hasattr(self, 'connection_pool') and self.connection_pool:
self.connection_pool.close()
except Exception:
pass
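
A minimal usage sketch, assuming a reachable MySQL server with placeholder credentials and the module path mem0.vector_stores.azure_mysql (both assumptions, not shown in the diff):

from mem0.vector_stores.azure_mysql import AzureMySQL

store = AzureMySQL(
    host="myserver.mysql.database.azure.com",   # placeholder host
    port=3306,
    user="mem0_user",
    password="<password>",                      # or use_azure_credential=True with azure-identity installed
    database="mem0",
    collection_name="memories",
    embedding_model_dims=4,                     # tiny dimension just for the example
)
store.insert(vectors=[[0.1, 0.2, 0.3, 0.4]], payloads=[{"data": "hi", "user_id": "alice"}], ids=["m1"])
print(store.search(query="", vectors=[0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"}))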

View File

@@ -0,0 +1,368 @@
import logging
import time
from typing import Dict, Optional
from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
try:
import pymochow
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import (
FieldType,
IndexType,
MetricType,
ServerErrCode,
TableState,
)
from pymochow.model.schema import (
AutoBuildRowCountIncrement,
Field,
FilteringIndex,
HNSWParams,
Schema,
VectorIndex,
)
from pymochow.model.table import (
FloatVector,
Partition,
Row,
VectorSearchConfig,
VectorTopkSearchRequest,
)
except ImportError:
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class BaiduDB(VectorStoreBase):
def __init__(
self,
endpoint: str,
account: str,
api_key: str,
database_name: str,
table_name: str,
embedding_model_dims: int,
metric_type: MetricType,
) -> None:
"""Initialize the BaiduDB database.
Args:
endpoint (str): Endpoint URL for Baidu VectorDB.
account (str): Account for Baidu VectorDB.
api_key (str): API Key for Baidu VectorDB.
database_name (str): Name of the database.
table_name (str): Name of the table.
embedding_model_dims (int): Dimensions of the embedding model.
metric_type (MetricType): Metric type for similarity search.
"""
self.endpoint = endpoint
self.account = account
self.api_key = api_key
self.database_name = database_name
self.table_name = table_name
self.embedding_model_dims = embedding_model_dims
self.metric_type = metric_type
# Initialize Mochow client
config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
self.client = pymochow.MochowClient(config)
# Ensure database and table exist
self._create_database_if_not_exists()
self.create_col(
name=self.table_name,
vector_size=self.embedding_model_dims,
distance=self.metric_type,
)
def _create_database_if_not_exists(self):
"""Create database if it doesn't exist."""
try:
# Check if database exists
databases = self.client.list_databases()
db_exists = any(db.database_name == self.database_name for db in databases)
if not db_exists:
self._database = self.client.create_database(self.database_name)
logger.info(f"Created database: {self.database_name}")
else:
self._database = self.client.database(self.database_name)
logger.info(f"Database {self.database_name} already exists")
except Exception as e:
logger.error(f"Error creating database: {e}")
raise
def create_col(self, name, vector_size, distance):
"""Create a new table.
Args:
name (str): Name of the table to create.
vector_size (int): Dimension of the vector.
distance (str): Metric type for similarity search.
"""
# Check if table already exists
try:
tables = self._database.list_table()
table_exists = any(table.table_name == name for table in tables)
if table_exists:
logger.info(f"Table {name} already exists. Skipping creation.")
self._table = self._database.describe_table(name)
return
# Convert distance string to MetricType enum
metric_type = None
for k, v in MetricType.__members__.items():
# Accept either the metric name (e.g. "L2") or the MetricType member itself.
if k == distance or v == distance:
metric_type = v
if metric_type is None:
raise ValueError(f"Unsupported metric_type: {distance}")
# Define table schema
fields = [
Field(
"id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
),
Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
Field("metadata", FieldType.JSON),
]
# Create vector index
indexes = [
VectorIndex(
index_name="vector_idx",
index_type=IndexType.HNSW,
field="vector",
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
auto_build=True,
auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
),
FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
]
schema = Schema(fields=fields, indexes=indexes)
# Create table
self._table = self._database.create_table(
table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
)
logger.info(f"Created table: {name}")
# Wait for table to be ready
while True:
time.sleep(2)
table = self._database.describe_table(name)
if table.state == TableState.NORMAL:
logger.info(f"Table {name} is ready.")
break
logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
self._table = table
except Exception as e:
logger.error(f"Error creating table: {e}")
raise
def insert(self, vectors, payloads=None, ids=None):
"""Insert vectors into the table.
Args:
vectors (List[List[float]]): List of vectors to insert.
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
# Prepare data for insertion
for idx, vector, metadata in zip(ids, vectors, payloads):
row = Row(id=idx, vector=vector, metadata=metadata)
self._table.upsert(rows=[row])
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
"""
Search for similar vectors.
Args:
query (str): Query string.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
"""
# Add filters if provided
search_filter = None
if filters:
search_filter = self._create_filter(filters)
# Create AnnSearch for vector search
request = VectorTopkSearchRequest(
vector_field="vector",
vector=FloatVector(vectors),
limit=limit,
filter=search_filter,
config=VectorSearchConfig(ef=200),
)
# Perform search
projections = ["id", "metadata"]
res = self._table.vector_search(request=request, projections=projections)
# Parse results
output = []
for row in res.rows:
row_data = row.get("row", {})
output_data = OutputData(
id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {})
)
output.append(output_data)
return output
def delete(self, vector_id):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
self._table.delete(primary_key={"id": vector_id})
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
row = Row(id=vector_id, vector=vector, metadata=payload)
self._table.upsert(rows=[row])
def get(self, vector_id):
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
projections = ["id", "metadata"]
result = self._table.query(primary_key={"id": vector_id}, projections=projections)
row = result.row
return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
def list_cols(self):
"""
List all tables (collections).
Returns:
List[str]: List of table names.
"""
tables = self._database.list_table()
return [table.table_name for table in tables]
def delete_col(self):
"""Delete the table."""
try:
tables = self._database.list_table()
# skip drop table if table not exists
table_exists = any(table.table_name == self.table_name for table in tables)
if not table_exists:
logger.info(f"Table {self.table_name} does not exist, skipping deletion")
return
# Delete the table
self._database.drop_table(self.table_name)
logger.info(f"Initiated deletion of table {self.table_name}")
# Wait for table to be completely deleted
while True:
time.sleep(2)
try:
self._database.describe_table(self.table_name)
logger.info(f"Waiting for table {self.table_name} to be deleted...")
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
logger.info(f"Table {self.table_name} has been completely deleted")
break
logger.error(f"Error checking table status: {e}")
raise
except Exception as e:
logger.error(f"Error deleting table: {e}")
raise
def col_info(self):
"""
Get information about the table.
Returns:
Dict[str, Any]: Table information.
"""
return self._table.stats()
def list(self, filters: dict = None, limit: int = 100) -> list:
"""
List all vectors in the table.
Args:
filters (Dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
projections = ["id", "metadata"]
list_filter = self._create_filter(filters) if filters else None
result = self._table.select(filter=list_filter, projections=projections, limit=limit)
memories = []
for row in result.rows:
obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
memories.append(obj)
return [memories]
def reset(self):
"""Reset the table by deleting and recreating it."""
logger.warning(f"Resetting table {self.table_name}...")
try:
self.delete_col()
self.create_col(
name=self.table_name,
vector_size=self.embedding_model_dims,
distance=self.metric_type,
)
except Exception as e:
logger.warning(f"Error resetting table: {e}")
raise
def _create_filter(self, filters: dict) -> str:
"""
Create filter expression for queries.
Args:
filters (dict): Filter conditions.
Returns:
str: Filter expression.
"""
conditions = []
for key, value in filters.items():
if isinstance(value, str):
conditions.append(f'metadata["{key}"] = "{value}"')
else:
conditions.append(f'metadata["{key}"] = {value}')
return " AND ".join(conditions)

View File

@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
class VectorStoreBase(ABC):
@abstractmethod
def create_col(self, name, vector_size, distance):
"""Create a new collection."""
pass
@abstractmethod
def insert(self, vectors, payloads=None, ids=None):
"""Insert vectors into a collection."""
pass
@abstractmethod
def search(self, query, vectors, limit=5, filters=None):
"""Search for similar vectors."""
pass
@abstractmethod
def delete(self, vector_id):
"""Delete a vector by ID."""
pass
@abstractmethod
def update(self, vector_id, vector=None, payload=None):
"""Update a vector and its payload."""
pass
@abstractmethod
def get(self, vector_id):
"""Retrieve a vector by ID."""
pass
@abstractmethod
def list_cols(self):
"""List all collections."""
pass
@abstractmethod
def delete_col(self):
"""Delete a collection."""
pass
@abstractmethod
def col_info(self):
"""Get information about a collection."""
pass
@abstractmethod
def list(self, filters=None, limit=None):
"""List all memories."""
pass
@abstractmethod
def reset(self):
"""Reset by delete the collection and recreate it."""
pass
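
For orientation, a minimal in-memory subclass sketch of this interface (illustrative only, not part of the commit; cosine scoring, the tuple return shape, and ignoring filters are simplifying assumptions):

import math
from mem0.vector_stores.base import VectorStoreBase

class InMemoryStore(VectorStoreBase):
    """Toy store: keeps (vector, payload) pairs in a dict and scores by cosine similarity."""

    def __init__(self):
        self.rows = {}

    def create_col(self, name, vector_size, distance):
        self.rows = {}

    def insert(self, vectors, payloads=None, ids=None):
        payloads = payloads or [{} for _ in vectors]
        for rid, vec, payload in zip(ids, vectors, payloads):
            self.rows[rid] = (vec, payload)

    def search(self, query, vectors, limit=5, filters=None):
        def cosine(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na, nb = math.sqrt(sum(x * x for x in a)), math.sqrt(sum(x * x for x in b))
            return dot / (na * nb) if na and nb else 0.0
        hits = [(rid, cosine(vec, vectors), payload) for rid, (vec, payload) in self.rows.items()]
        hits.sort(key=lambda h: h[1], reverse=True)
        return hits[:limit]  # (id, score, payload) tuples; filters ignored in this toy

    def delete(self, vector_id):
        self.rows.pop(vector_id, None)

    def update(self, vector_id, vector=None, payload=None):
        vec, pay = self.rows.get(vector_id, (None, {}))
        self.rows[vector_id] = (vector if vector is not None else vec,
                                payload if payload is not None else pay)

    def get(self, vector_id):
        return self.rows.get(vector_id)

    def list_cols(self):
        return ["default"]

    def delete_col(self):
        self.rows = {}

    def col_info(self):
        return {"name": "default", "count": len(self.rows)}

    def list(self, filters=None, limit=None):
        items = list(self.rows.items())
        return [items[: limit or len(items)]]

    def reset(self):
        self.rows = {}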

View File

@@ -0,0 +1,267 @@
import logging
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
try:
import chromadb
from chromadb.config import Settings
except ImportError:
raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class ChromaDB(VectorStoreBase):
def __init__(
self,
collection_name: str,
client: Optional[chromadb.Client] = None,
host: Optional[str] = None,
port: Optional[int] = None,
path: Optional[str] = None,
api_key: Optional[str] = None,
tenant: Optional[str] = None,
):
"""
Initialize the Chromadb vector store.
Args:
collection_name (str): Name of the collection.
client (chromadb.Client, optional): Existing chromadb client instance. Defaults to None.
host (str, optional): Host address for chromadb server. Defaults to None.
port (int, optional): Port for chromadb server. Defaults to None.
path (str, optional): Path for local chromadb database. Defaults to None.
api_key (str, optional): ChromaDB Cloud API key. Defaults to None.
tenant (str, optional): ChromaDB Cloud tenant ID. Defaults to None.
"""
if client:
self.client = client
elif api_key and tenant:
# Initialize ChromaDB Cloud client
logger.info("Initializing ChromaDB Cloud client")
self.client = chromadb.CloudClient(
api_key=api_key,
tenant=tenant,
database="mem0" # Use fixed database name for cloud
)
else:
# Initialize local or server client
self.settings = Settings(anonymized_telemetry=False)
if host and port:
self.settings.chroma_server_host = host
self.settings.chroma_server_http_port = port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
else:
if path is None:
path = "db"
self.settings.persist_directory = path
self.settings.is_persistent = True
self.client = chromadb.Client(self.settings)
self.collection_name = collection_name
self.collection = self.create_col(collection_name)
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data.
Args:
data (Dict): Output data.
Returns:
List[OutputData]: Parsed output data.
"""
keys = ["ids", "distances", "metadatas"]
values = []
for key in keys:
value = data.get(key, [])
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
values.append(value)
ids, distances, metadatas = values
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
)
result.append(entry)
return result
def create_col(self, name: str, embedding_fn: Optional[callable] = None):
"""
Create a new collection.
Args:
name (str): Name of the collection.
embedding_fn (Optional[callable]): Embedding function to use. Defaults to None.
Returns:
chromadb.Collection: The created or retrieved collection.
"""
collection = self.client.get_or_create_collection(
name=name,
embedding_function=embedding_fn,
)
return collection
def insert(
self,
vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None,
):
"""
Insert vectors into a collection.
Args:
vectors (List[list]): List of vectors to insert.
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search(
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[list]): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results.
"""
where_clause = self._generate_where_clause(filters) if filters else None
results = self.collection.query(query_embeddings=vectors, where=where_clause, n_results=limit)
final_results = self._parse_output(results)
return final_results
def delete(self, vector_id: str):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
self.collection.delete(ids=vector_id)
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (Optional[List[float]], optional): Updated vector. Defaults to None.
payload (Optional[Dict], optional): Updated payload. Defaults to None.
"""
self.collection.update(ids=vector_id, embeddings=vector, metadatas=payload)
def get(self, vector_id: str) -> OutputData:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
result = self.collection.get(ids=[vector_id])
return self._parse_output(result)[0]
def list_cols(self) -> List[chromadb.Collection]:
"""
List all collections.
Returns:
List[chromadb.Collection]: List of collections.
"""
return self.client.list_collections()
def delete_col(self):
"""
Delete a collection.
"""
self.client.delete_collection(name=self.collection_name)
def col_info(self) -> Dict:
"""
Get information about a collection.
Returns:
Dict: Collection information.
"""
return self.client.get_collection(name=self.collection_name)
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List all vectors in a collection.
Args:
filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
where_clause = self._generate_where_clause(filters) if filters else None
results = self.collection.get(where=where_clause, limit=limit)
return [self._parse_output(results)]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.collection = self.create_col(self.collection_name)
@staticmethod
def _generate_where_clause(where: Dict[str, Any]) -> Dict[str, Any]:
"""
Generate a properly formatted where clause for ChromaDB.
Args:
where (Dict[str, Any]): The filter conditions.
Returns:
Dict[str, Any]: Properly formatted where clause for ChromaDB.
"""
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if where is None:
return {}
if len(where.keys()) <= 1:
return where
where_filters = []
for k, v in where.items():
where_filters.append({k: v})
return {"$and": where_filters}

View File

@@ -0,0 +1,65 @@
from typing import Dict, Optional
from pydantic import BaseModel, Field, model_validator
class VectorStoreConfig(BaseModel):
provider: str = Field(
description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
default="qdrant",
)
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
_provider_configs: Dict[str, str] = {
"qdrant": "QdrantConfig",
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"pinecone": "PineconeConfig",
"mongodb": "MongoDBConfig",
"milvus": "MilvusDBConfig",
"baidu": "BaiduDBConfig",
"neptune": "NeptuneAnalyticsConfig",
"upstash_vector": "UpstashVectorConfig",
"azure_ai_search": "AzureAISearchConfig",
"azure_mysql": "AzureMySQLConfig",
"redis": "RedisDBConfig",
"valkey": "ValkeyConfig",
"databricks": "DatabricksConfig",
"elasticsearch": "ElasticsearchConfig",
"vertex_ai_vector_search": "GoogleMatchingEngineConfig",
"opensearch": "OpenSearchConfig",
"supabase": "SupabaseConfig",
"weaviate": "WeaviateConfig",
"faiss": "FAISSConfig",
"langchain": "LangchainConfig",
"s3_vectors": "S3VectorsConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> "VectorStoreConfig":
provider = self.provider
config = self.config
if provider not in self._provider_configs:
raise ValueError(f"Unsupported vector store provider: {provider}")
module = __import__(
f"mem0.configs.vector_stores.{provider}",
fromlist=[self._provider_configs[provider]],
)
config_class = getattr(module, self._provider_configs[provider])
if config is None:
config = {}
if not isinstance(config, dict):
if not isinstance(config, config_class):
raise ValueError(f"Invalid config type for provider {provider}")
return self
# also check if "path" is in the allowed keys for the pydantic model, and whether the config allows extra fields
if "path" not in config and "path" in config_class.__annotations__:
config["path"] = f"/tmp/{provider}"
self.config = config_class(**config)
return self
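
A quick sketch of how the validator above resolves a provider config (illustrative; the import path mem0.vector_stores.configs and the ChromaDbConfig field name are assumptions consistent with the provider table above):

from mem0.vector_stores.configs import VectorStoreConfig

cfg = VectorStoreConfig(provider="chroma", config={"collection_name": "mem0"})
# After validation, cfg.config is an instance of the provider's config class;
# providers whose config model declares a "path" field get a default of /tmp/<provider>.
print(cfg.provider, type(cfg.config).__name__)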

View File

@@ -0,0 +1,759 @@
import json
import logging
import uuid
from typing import Optional, List
from datetime import datetime, date
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, TableType, DataSourceFormat
from databricks.sdk.service.catalog import TableConstraint, PrimaryKeyConstraint
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.vectorsearch import (
VectorIndexType,
DeltaSyncVectorIndexSpecRequest,
DirectAccessVectorIndexSpec,
EmbeddingSourceColumn,
EmbeddingVectorColumn,
)
from pydantic import BaseModel
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class MemoryResult(BaseModel):
id: Optional[str] = None
score: Optional[float] = None
payload: Optional[dict] = None
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
class Databricks(VectorStoreBase):
def __init__(
self,
workspace_url: str,
access_token: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
azure_client_id: Optional[str] = None,
azure_client_secret: Optional[str] = None,
endpoint_name: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
collection_name: str = "mem0",
index_type: str = "DELTA_SYNC",
embedding_model_endpoint_name: Optional[str] = None,
embedding_dimension: int = 1536,
endpoint_type: str = "STANDARD",
pipeline_type: str = "TRIGGERED",
warehouse_name: Optional[str] = None,
query_type: str = "ANN",
):
"""
Initialize the Databricks Vector Search vector store.
Args:
workspace_url (str): Databricks workspace URL.
access_token (str, optional): Personal access token for authentication.
client_id (str, optional): Service principal client ID for authentication.
client_secret (str, optional): Service principal client secret for authentication.
azure_client_id (str, optional): Azure AD application client ID (for Azure Databricks).
azure_client_secret (str, optional): Azure AD application client secret (for Azure Databricks).
endpoint_name (str): Vector search endpoint name.
catalog (str): Unity Catalog catalog name.
schema (str): Unity Catalog schema name.
table_name (str): Source Delta table name.
collection_name (str, optional): Vector search index name (default: "mem0").
index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
endpoint_type (str, optional): Endpoint type, either "STANDARD" or "STORAGE_OPTIMIZED" (default: "STANDARD").
pipeline_type (str, optional): Sync pipeline type, either "TRIGGERED" or "CONTINUOUS" (default: "TRIGGERED").
warehouse_name (str, optional): Databricks SQL warehouse Name (if using SQL warehouse).
query_type (str, optional): Query type, either "ANN" or "HYBRID" (default: "ANN").
"""
# Basic identifiers
self.workspace_url = workspace_url
self.endpoint_name = endpoint_name
self.catalog = catalog
self.schema = schema
self.table_name = table_name
self.fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
self.index_name = collection_name
self.fully_qualified_index_name = f"{self.catalog}.{self.schema}.{self.index_name}"
# Configuration
# Normalize to the SDK enum so later comparisons against VectorIndexType members behave as intended.
self.index_type = VectorIndexType(index_type)
self.embedding_model_endpoint_name = embedding_model_endpoint_name
self.embedding_dimension = embedding_dimension
self.endpoint_type = endpoint_type
self.pipeline_type = pipeline_type
self.query_type = query_type
# Schema
self.columns = [
ColumnInfo(
name="memory_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
nullable=False,
comment="Primary key",
position=0,
),
ColumnInfo(
name="hash",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Hash of the memory content",
position=1,
),
ColumnInfo(
name="agent_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the agent",
position=2,
),
ColumnInfo(
name="run_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the run",
position=3,
),
ColumnInfo(
name="user_id",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="ID of the user",
position=4,
),
ColumnInfo(
name="memory",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Memory content",
position=5,
),
ColumnInfo(
name="metadata",
type_name=ColumnTypeName.STRING,
type_text="string",
type_json='{"type":"string"}',
comment="Additional metadata",
position=6,
),
ColumnInfo(
name="created_at",
type_name=ColumnTypeName.TIMESTAMP,
type_text="timestamp",
type_json='{"type":"timestamp"}',
comment="Creation timestamp",
position=7,
),
ColumnInfo(
name="updated_at",
type_name=ColumnTypeName.TIMESTAMP,
type_text="timestamp",
type_json='{"type":"timestamp"}',
comment="Last update timestamp",
position=8,
),
]
if self.index_type == VectorIndexType.DIRECT_ACCESS:
self.columns.append(
ColumnInfo(
name="embedding",
type_name=ColumnTypeName.ARRAY,
type_text="array<float>",
type_json='{"type":"array","element":"float","element_nullable":false}',
nullable=True,
comment="Embedding vector",
position=9,
)
)
self.column_names = [col.name for col in self.columns]
# Initialize Databricks workspace client
client_config = {}
if client_id and client_secret:
client_config.update(
{
"host": workspace_url,
"client_id": client_id,
"client_secret": client_secret,
}
)
elif azure_client_id and azure_client_secret:
client_config.update(
{
"host": workspace_url,
"azure_client_id": azure_client_id,
"azure_client_secret": azure_client_secret,
}
)
elif access_token:
client_config.update({"host": workspace_url, "token": access_token})
else:
# Try automatic authentication
client_config["host"] = workspace_url
try:
self.client = WorkspaceClient(**client_config)
logger.info("Initialized Databricks workspace client")
except Exception as e:
logger.error(f"Failed to initialize Databricks workspace client: {e}")
raise
# Get the warehouse ID by name
self.warehouse_id = next((w.id for w in self.client.warehouses.list() if w.name == warehouse_name), None)
# Initialize endpoint (required in Databricks)
self._ensure_endpoint_exists()
# Check if index exists and create if needed
collections = self.list_cols()
if self.fully_qualified_index_name not in collections:
self.create_col()
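# Illustrative instantiation (placeholder workspace URL, token, Unity Catalog names,
# embedding endpoint, and warehouse name; all must already exist in the workspace):
#   store = Databricks(
#       workspace_url="https://adb-1234567890123456.7.azuredatabricks.net",
#       access_token="<token>",
#       endpoint_name="mem0-endpoint",
#       catalog="main", schema="mem0", table_name="memories",
#       embedding_model_endpoint_name="<embedding-endpoint>",
#       warehouse_name="<sql-warehouse-name>",
#   )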
def _ensure_endpoint_exists(self):
"""Ensure the vector search endpoint exists, create if it doesn't."""
try:
self.client.vector_search_endpoints.get_endpoint(endpoint_name=self.endpoint_name)
logger.info(f"Vector search endpoint '{self.endpoint_name}' already exists")
except Exception:
# Endpoint doesn't exist, create it
try:
logger.info(f"Creating vector search endpoint '{self.endpoint_name}' with type '{self.endpoint_type}'")
self.client.vector_search_endpoints.create_endpoint_and_wait(
name=self.endpoint_name, endpoint_type=self.endpoint_type
)
logger.info(f"Successfully created vector search endpoint '{self.endpoint_name}'")
except Exception as e:
logger.error(f"Failed to create vector search endpoint '{self.endpoint_name}': {e}")
raise
def _ensure_source_table_exists(self):
"""Ensure the source Delta table exists with the proper schema."""
check = self.client.tables.exists(self.fully_qualified_table_name)
if check.table_exists:
logger.info(f"Source table '{self.fully_qualified_table_name}' already exists")
else:
logger.info(f"Source table '{self.fully_qualified_table_name}' does not exist, creating it...")
self.client.tables.create(
name=self.table_name,
catalog_name=self.catalog,
schema_name=self.schema,
table_type=TableType.MANAGED,
data_source_format=DataSourceFormat.DELTA,
storage_location=None, # Use default storage location
columns=self.columns,
properties={"delta.enableChangeDataFeed": "true"},
)
logger.info(f"Successfully created source table '{self.fully_qualified_table_name}'")
self.client.table_constraints.create(
full_name_arg="logistics_dev.ai.dev_memory",
constraint=TableConstraint(
primary_key_constraint=PrimaryKeyConstraint(
name="pk_dev_memory", # Name of the primary key constraint
child_columns=["memory_id"], # Columns that make up the primary key
)
),
)
logger.info(
f"Successfully created primary key constraint on 'memory_id' for table '{self.fully_qualified_table_name}'"
)
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index).
Args:
name (str, optional): Index name. If provided, will create a new index using the provided source_table_name.
vector_size (int, optional): Vector dimension size.
distance (str, optional): Distance metric (not directly applicable for Databricks).
Returns:
The index object.
"""
# Determine index configuration
embedding_dims = vector_size or self.embedding_dimension
embedding_source_columns = [
EmbeddingSourceColumn(
name="memory",
embedding_model_endpoint_name=self.embedding_model_endpoint_name,
)
]
logger.info(f"Creating vector search index '{self.fully_qualified_index_name}'")
# First, ensure the source Delta table exists
self._ensure_source_table_exists()
if self.index_type not in [VectorIndexType.DELTA_SYNC, VectorIndexType.DIRECT_ACCESS]:
raise ValueError("index_type must be either 'DELTA_SYNC' or 'DIRECT_ACCESS'")
try:
if self.index_type == VectorIndexType.DELTA_SYNC:
index = self.client.vector_search_indexes.create_index(
name=self.fully_qualified_index_name,
endpoint_name=self.endpoint_name,
primary_key="memory_id",
index_type=self.index_type,
delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
source_table=self.fully_qualified_table_name,
pipeline_type=self.pipeline_type,
columns_to_sync=self.column_names,
embedding_source_columns=embedding_source_columns,
),
)
logger.info(
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DELTA_SYNC type"
)
return index
elif self.index_type == VectorIndexType.DIRECT_ACCESS:
index = self.client.vector_search_indexes.create_index(
name=self.fully_qualified_index_name,
endpoint_name=self.endpoint_name,
primary_key="memory_id",
index_type=self.index_type,
direct_access_index_spec=DirectAccessVectorIndexSpec(
embedding_source_columns=embedding_source_columns,
embedding_vector_columns=[
EmbeddingVectorColumn(name="embedding", embedding_dimension=embedding_dims)
],
),
)
logger.info(
f"Successfully created vector search index '{self.fully_qualified_index_name}' with DIRECT_ACCESS type"
)
return index
except Exception as e:
logger.error(f"Failed to create index '{self.fully_qualified_index_name}' (index_type={self.index_type}): {e}")
raise
def _format_sql_value(self, v):
"""
Format a Python value into a safe SQL literal for Databricks.
"""
if v is None:
return "NULL"
if isinstance(v, bool):
return "TRUE" if v else "FALSE"
if isinstance(v, (int, float)):
return str(v)
if isinstance(v, (datetime, date)):
return f"'{v.isoformat()}'"
if isinstance(v, list):
# Render arrays (assume numeric or string elements)
elems = []
for x in v:
if x is None:
elems.append("NULL")
elif isinstance(x, (int, float)):
elems.append(str(x))
else:
s = str(x).replace("'", "''")
elems.append(f"'{s}'")
return f"array({', '.join(elems)})"
if isinstance(v, dict):
try:
s = json.dumps(v)
except Exception:
s = str(v)
s = s.replace("'", "''")
return f"'{s}'"
# Fallback: treat as string
s = str(v).replace("'", "''")
return f"'{s}'"
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors into the index.
Args:
vectors (List[List[float]]): List of vectors to insert.
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
# Determine the number of items to process
num_items = len(payloads) if payloads else len(vectors) if vectors else 0
value_tuples = []
for i in range(num_items):
values = []
for col in self.columns:
if col.name == "memory_id":
val = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
elif col.name == "embedding":
val = vectors[i] if vectors and i < len(vectors) else []
elif col.name == "memory":
val = payloads[i].get("data") if payloads and i < len(payloads) else None
else:
val = payloads[i].get(col.name) if payloads and i < len(payloads) else None
values.append(val)
formatted = [self._format_sql_value(v) for v in values]
value_tuples.append(f"({', '.join(formatted)})")
insert_sql = f"INSERT INTO {self.fully_qualified_table_name} ({', '.join(self.column_names)}) VALUES {', '.join(value_tuples)}"
# Execute the insert
try:
response = self.client.statement_execution.execute_statement(
statement=insert_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(
f"Successfully inserted {num_items} items into Delta table {self.fully_qualified_table_name}"
)
return
else:
logger.error(f"Failed to insert items: {response.status.error}")
raise Exception(f"Insert operation failed: {response.status.error}")
except Exception as e:
logger.error(f"Insert operation failed: {e}")
raise
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> List[MemoryResult]:
"""
Search for similar vectors or text using the Databricks Vector Search index.
Args:
query (str): Search query text (for text-based search).
vectors (list): Query vector (for vector-based search).
limit (int): Maximum number of results.
filters (dict): Filters to apply.
Returns:
List of MemoryResult objects.
"""
try:
filters_json = json.dumps(filters) if filters else None
# Choose query type
if self.index_type == VectorIndexType.DELTA_SYNC and query:
# Text-based search
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_text=query,
num_results=limit,
query_type=self.query_type,
filters_json=filters_json,
)
elif self.index_type == VectorIndexType.DIRECT_ACCESS and vectors:
# Vector-based search
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_vector=vectors,
num_results=limit,
query_type=self.query_type,
filters_json=filters_json,
)
else:
raise ValueError("Must provide query text for DELTA_SYNC or vectors for DIRECT_ACCESS.")
# Parse results
result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
data_array = result_data.data_array if getattr(result_data, "data_array", None) else []
memory_results = []
for row in data_array:
# Map columns to values
row_dict = dict(zip(self.column_names, row)) if isinstance(row, (list, tuple)) else row
score = row_dict.get("score") or (
row[-1] if isinstance(row, (list, tuple)) and len(row) > len(self.column_names) else None
)
payload = {k: row_dict.get(k) for k in self.column_names}
payload["data"] = payload.get("memory", "")
memory_id = row_dict.get("memory_id") or row_dict.get("id")
memory_results.append(MemoryResult(id=memory_id, score=score, payload=payload))
return memory_results
except Exception as e:
logger.error(f"Search failed: {e}")
raise
def delete(self, vector_id):
"""
Delete a vector by ID from the Delta table.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
logger.info(f"Deleting vector with ID {vector_id} from Delta table {self.fully_qualified_table_name}")
delete_sql = f"DELETE FROM {self.fully_qualified_table_name} WHERE memory_id = '{vector_id}'"
response = self.client.statement_execution.execute_statement(
statement=delete_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(f"Successfully deleted vector with ID {vector_id}")
else:
logger.error(f"Failed to delete vector with ID {vector_id}: {response.status.error}")
except Exception as e:
logger.error(f"Delete operation failed for vector ID {vector_id}: {e}")
raise
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector and its payload in the Delta table.
Args:
vector_id (str): ID of the vector to update.
vector (list, optional): New vector values.
payload (dict, optional): New payload data.
"""
update_sql = f"UPDATE {self.fully_qualified_table_name} SET "
set_clauses = []
if not vector_id:
logger.error("vector_id is required for update operation")
return
if vector is not None:
if not isinstance(vector, list):
logger.error("vector must be a list of float values")
return
set_clauses.append(f"embedding = {vector}")
if payload:
if not isinstance(payload, dict):
logger.error("payload must be a dictionary")
return
for key, value in payload.items():
if key not in excluded_keys:
set_clauses.append(f"{key} = '{value}'")
if not set_clauses:
logger.error("No fields to update")
return
update_sql += ", ".join(set_clauses)
update_sql += f" WHERE memory_id = '{vector_id}'"
try:
logger.info(f"Updating vector with ID {vector_id} in Delta table {self.fully_qualified_table_name}")
response = self.client.statement_execution.execute_statement(
statement=update_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
if response.status.state.value == "SUCCEEDED":
logger.info(f"Successfully updated vector with ID {vector_id}")
else:
logger.error(f"Failed to update vector with ID {vector_id}: {response.status.error}")
except Exception as e:
logger.error(f"Update operation failed for vector ID {vector_id}: {e}")
raise
def get(self, vector_id) -> MemoryResult:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
MemoryResult: The retrieved vector.
"""
try:
# Use query with ID filter to retrieve the specific vector
filters = {"memory_id": vector_id}
filters_json = json.dumps(filters)
results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=self.column_names,
query_text=" ", # Empty query, rely on filters
num_results=1,
query_type=self.query_type,
filters_json=filters_json,
)
# Process results
result_data = results.result if hasattr(results, "result") else results
data_array = result_data.data_array if hasattr(result_data, "data_array") else []
if not data_array:
raise KeyError(f"Vector with ID {vector_id} not found")
result = data_array[0]
row_data = result if isinstance(result, dict) else result.__dict__
# Build payload following the standard schema
payload = {
"hash": row_data.get("hash", "unknown"),
"data": row_data.get("memory", row_data.get("data", "unknown")),
"created_at": row_data.get("created_at"),
}
# Add updated_at if available
if "updated_at" in row_data:
payload["updated_at"] = row_data.get("updated_at")
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in row_data:
payload[field] = row_data[field]
# Add metadata
if "metadata" in row_data:
try:
metadata = json.loads(extract_json(row_data["metadata"]))
payload.update(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse metadata: {row_data.get('metadata')}")
            memory_id = row_data.get("memory_id", row_data.get("id", vector_id))
return MemoryResult(id=memory_id, payload=payload)
except Exception as e:
logger.error(f"Failed to get vector with ID {vector_id}: {e}")
raise
def list_cols(self) -> List[str]:
"""
List all collections (indexes).
Returns:
List of index names.
"""
try:
indexes = self.client.vector_search_indexes.list_indexes(endpoint_name=self.endpoint_name)
return [idx.name for idx in indexes]
except Exception as e:
logger.error(f"Failed to list collections: {e}")
raise
def delete_col(self):
"""
Delete the current collection (index).
"""
try:
# Try fully qualified first
try:
self.client.vector_search_indexes.delete_index(index_name=self.fully_qualified_index_name)
logger.info(f"Successfully deleted index '{self.fully_qualified_index_name}'")
except Exception:
self.client.vector_search_indexes.delete_index(index_name=self.index_name)
logger.info(f"Successfully deleted index '{self.index_name}' (short name)")
except Exception as e:
logger.error(f"Failed to delete index '{self.index_name}': {e}")
raise
def col_info(self, name=None):
"""
Get information about a collection (index).
Args:
name (str, optional): Index name. Defaults to current index.
Returns:
Dict: Index information.
"""
try:
index_name = name or self.index_name
index = self.client.vector_search_indexes.get_index(index_name=index_name)
return {"name": index.name, "fields": self.columns}
except Exception as e:
logger.error(f"Failed to get info for index '{name or self.index_name}': {e}")
raise
def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
"""
        List all recently created memories from the vector store.
Args:
filters (dict, optional): Filters to apply.
limit (int, optional): Maximum number of results.
Returns:
List containing list of MemoryResult objects.
"""
try:
filters_json = json.dumps(filters) if filters else None
num_results = limit or 100
columns = self.column_names
sdk_results = self.client.vector_search_indexes.query_index(
index_name=self.fully_qualified_index_name,
columns=columns,
query_text=" ",
num_results=num_results,
query_type=self.query_type,
filters_json=filters_json,
)
result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
data_array = result_data.data_array if hasattr(result_data, "data_array") else []
memory_results = []
for row in data_array:
row_dict = dict(zip(columns, row)) if isinstance(row, (list, tuple)) else row
payload = {k: row_dict.get(k) for k in columns}
# Parse metadata if present
if "metadata" in payload and payload["metadata"]:
try:
payload.update(json.loads(payload["metadata"]))
except Exception:
pass
memory_id = row_dict.get("memory_id") or row_dict.get("id")
memory_results.append(MemoryResult(id=memory_id, payload=payload))
return [memory_results]
except Exception as e:
logger.error(f"Failed to list memories: {e}")
return []
def reset(self):
"""Reset the vector search index and underlying source table.
This will attempt to delete the existing index (both fully qualified and short name forms
for robustness), drop the backing Delta table, recreate the table with the expected schema,
and finally recreate the index. Use with caution as all existing data will be removed.
"""
fq_index = self.fully_qualified_index_name
logger.warning(f"Resetting Databricks vector search index '{fq_index}'...")
try:
# Try deleting via fully qualified name first
try:
self.client.vector_search_indexes.delete_index(index_name=fq_index)
logger.info(f"Deleted index '{fq_index}'")
except Exception as e_fq:
logger.debug(f"Failed deleting fully qualified index name '{fq_index}': {e_fq}. Trying short name...")
try:
# Fallback to existing helper which may use short name
self.delete_col()
except Exception as e_short:
logger.debug(f"Failed deleting short index name '{self.index_name}': {e_short}")
# Drop the backing table (if it exists)
try:
drop_sql = f"DROP TABLE IF EXISTS {self.fully_qualified_table_name}"
resp = self.client.statement_execution.execute_statement(
statement=drop_sql, warehouse_id=self.warehouse_id, wait_timeout="30s"
)
                if getattr(getattr(resp.status, "state", None), "value", None) == "SUCCEEDED":
logger.info(f"Dropped table '{self.fully_qualified_table_name}'")
else:
logger.warning(
f"Attempted to drop table '{self.fully_qualified_table_name}' but state was {getattr(resp.status, 'state', 'UNKNOWN')}: {getattr(resp.status, 'error', None)}"
)
except Exception as e_drop:
logger.warning(f"Failed to drop table '{self.fully_qualified_table_name}': {e_drop}")
# Recreate table & index
self._ensure_source_table_exists()
self.create_col()
logger.info(f"Successfully reset index '{fq_index}'")
except Exception as e:
logger.error(f"Error resetting index '{fq_index}': {e}")
raise
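# Example usage (illustrative sketch, not executed): assumes `store` is an already-initialized
# instance of this Databricks-backed vector store with a reachable workspace, SQL warehouse, and
# vector search endpoint, and a 1536-dimensional embedding model.
#
#   store.insert(vectors=[[0.1] * 1536], payloads=[{"data": "User likes tea", "user_id": "alice"}])
#   hits = store.search(query="tea", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"})
#   memory = store.get(hits[0].id) if hits else None
#   store.delete(vector_id=memory.id) if memory else None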

View File

@@ -0,0 +1,237 @@
import logging
from typing import Any, Dict, List, Optional
try:
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
except ImportError:
raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None
from pydantic import BaseModel
from mem0.configs.vector_stores.elasticsearch import ElasticsearchConfig
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: str
score: float
payload: Dict
class ElasticsearchDB(VectorStoreBase):
def __init__(self, **kwargs):
config = ElasticsearchConfig(**kwargs)
# Initialize Elasticsearch client
if config.cloud_id:
self.client = Elasticsearch(
cloud_id=config.cloud_id,
api_key=config.api_key,
verify_certs=config.verify_certs,
                headers=config.headers or {},
)
else:
self.client = Elasticsearch(
hosts=[f"{config.host}" if config.port is None else f"{config.host}:{config.port}"],
basic_auth=(config.user, config.password) if (config.user and config.password) else None,
verify_certs=config.verify_certs,
                headers=config.headers or {},
)
self.collection_name = config.collection_name
self.embedding_model_dims = config.embedding_model_dims
# Create index only if auto_create_index is True
if config.auto_create_index:
self.create_index()
if config.custom_search_query:
self.custom_search_query = config.custom_search_query
else:
self.custom_search_query = None
def create_index(self) -> None:
"""Create Elasticsearch index with proper mappings if it doesn't exist"""
index_settings = {
"settings": {"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s"}},
"mappings": {
"properties": {
"text": {"type": "text"},
"vector": {
"type": "dense_vector",
"dims": self.embedding_model_dims,
"index": True,
"similarity": "cosine",
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
},
}
if not self.client.indices.exists(index=self.collection_name):
self.client.indices.create(index=self.collection_name, body=index_settings)
logger.info(f"Created index {self.collection_name}")
else:
logger.info(f"Index {self.collection_name} already exists")
def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None:
"""Create a new collection (index in Elasticsearch)."""
index_settings = {
"mappings": {
"properties": {
"vector": {"type": "dense_vector", "dims": vector_size, "index": True, "similarity": "cosine"},
"payload": {"type": "object"},
"id": {"type": "keyword"},
}
}
}
if not self.client.indices.exists(index=name):
self.client.indices.create(index=name, body=index_settings)
logger.info(f"Created index {name}")
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> List[OutputData]:
"""Insert vectors into the index."""
if not ids:
ids = [str(i) for i in range(len(vectors))]
if payloads is None:
payloads = [{} for _ in range(len(vectors))]
actions = []
for i, (vec, id_) in enumerate(zip(vectors, ids)):
action = {
"_index": self.collection_name,
"_id": id_,
"_source": {
"vector": vec,
"metadata": payloads[i], # Store all metadata in the metadata field
},
}
actions.append(action)
bulk(self.client, actions)
results = []
for i, id_ in enumerate(ids):
results.append(
OutputData(
id=id_,
score=1.0, # Default score for inserts
payload=payloads[i],
)
)
return results
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search with two options:
1. Use custom search query if provided
2. Use KNN search on vectors with pre-filtering if no custom search query is provided
"""
if self.custom_search_query:
search_query = self.custom_search_query(vectors, limit, filters)
else:
search_query = {
"knn": {"field": "vector", "query_vector": vectors, "k": limit, "num_candidates": limit * 2}
}
if filters:
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"term": {f"metadata.{key}": value}})
search_query["knn"]["filter"] = {"bool": {"must": filter_conditions}}
response = self.client.search(index=self.collection_name, body=search_query)
results = []
for hit in response["hits"]["hits"]:
results.append(
OutputData(id=hit["_id"], score=hit["_score"], payload=hit.get("_source", {}).get("metadata", {}))
)
return results
def delete(self, vector_id: str) -> None:
"""Delete a vector by ID."""
self.client.delete(index=self.collection_name, id=vector_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload."""
doc = {}
if vector is not None:
doc["vector"] = vector
if payload is not None:
doc["metadata"] = payload
self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
response = self.client.get(index=self.collection_name, id=vector_id)
return OutputData(
id=response["_id"],
score=1.0, # Default score for direct get
payload=response["_source"].get("metadata", {}),
)
except KeyError as e:
logger.warning(f"Missing key in Elasticsearch response: {e}")
return None
except TypeError as e:
logger.warning(f"Invalid response type from Elasticsearch: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error while parsing Elasticsearch response: {e}")
return None
def list_cols(self) -> List[str]:
"""List all collections (indices)."""
return list(self.client.indices.get_alias().keys())
def delete_col(self) -> None:
"""Delete a collection (index)."""
self.client.indices.delete(index=self.collection_name)
def col_info(self, name: str) -> Any:
"""Get information about a collection (index)."""
return self.client.indices.get(index=name)
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
"""List all memories."""
query: Dict[str, Any] = {"query": {"match_all": {}}}
if filters:
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"term": {f"metadata.{key}": value}})
query["query"] = {"bool": {"must": filter_conditions}}
if limit:
query["size"] = limit
response = self.client.search(index=self.collection_name, body=query)
results = []
for hit in response["hits"]["hits"]:
results.append(
OutputData(
id=hit["_id"],
score=1.0, # Default score for list operation
payload=hit.get("_source", {}).get("metadata", {}),
)
)
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_index()
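# Example usage (illustrative sketch, not executed): assumes a reachable Elasticsearch cluster;
# the connection values below are hypothetical placeholders.
#
#   store = ElasticsearchDB(
#       host="https://localhost", port=9200, user="elastic", password="changeme",
#       collection_name="mem0", embedding_model_dims=1536,
#   )
#   store.insert(vectors=[[0.1] * 1536], payloads=[{"user_id": "alice", "data": "likes tea"}], ids=["m1"])
#   hits = store.search(query="tea", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"})
#
# A custom ranking strategy can be plugged in via `custom_search_query`: a callable that receives
# (vectors, limit, filters) and returns a complete Elasticsearch query body.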

View File

@@ -0,0 +1,479 @@
import logging
import os
import pickle
import uuid
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from pydantic import BaseModel
import warnings
try:
# Suppress SWIG deprecation warnings from FAISS
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*swigvarlink.*")
logging.getLogger("faiss").setLevel(logging.WARNING)
logging.getLogger("faiss.loader").setLevel(logging.WARNING)
import faiss
except ImportError:
raise ImportError(
"Could not import faiss python package. "
"Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
"or `pip install faiss-cpu` (depending on Python version)."
)
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class FAISS(VectorStoreBase):
def __init__(
self,
collection_name: str,
path: Optional[str] = None,
distance_strategy: str = "euclidean",
normalize_L2: bool = False,
embedding_model_dims: int = 1536,
):
"""
Initialize the FAISS vector store.
Args:
collection_name (str): Name of the collection.
path (str, optional): Path for local FAISS database. Defaults to None.
distance_strategy (str, optional): Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'.
Defaults to "euclidean".
            normalize_L2 (bool, optional): Whether to normalize L2 vectors. Only applicable for euclidean distance.
                Defaults to False.
            embedding_model_dims (int, optional): Dimension of the embedding vectors. Defaults to 1536.
"""
self.collection_name = collection_name
self.path = path or f"/tmp/faiss/{collection_name}"
self.distance_strategy = distance_strategy
self.normalize_L2 = normalize_L2
self.embedding_model_dims = embedding_model_dims
# Initialize storage structures
self.index = None
self.docstore = {}
self.index_to_id = {}
# Create directory if it doesn't exist
if self.path:
os.makedirs(os.path.dirname(self.path), exist_ok=True)
# Try to load existing index if available
index_path = f"{self.path}/{collection_name}.faiss"
docstore_path = f"{self.path}/{collection_name}.pkl"
if os.path.exists(index_path) and os.path.exists(docstore_path):
self._load(index_path, docstore_path)
else:
self.create_col(collection_name)
def _load(self, index_path: str, docstore_path: str):
"""
Load FAISS index and docstore from disk.
Args:
index_path (str): Path to FAISS index file.
docstore_path (str): Path to docstore pickle file.
"""
try:
self.index = faiss.read_index(index_path)
with open(docstore_path, "rb") as f:
self.docstore, self.index_to_id = pickle.load(f)
logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors")
except Exception as e:
logger.warning(f"Failed to load FAISS index: {e}")
self.docstore = {}
self.index_to_id = {}
def _save(self):
"""Save FAISS index and docstore to disk."""
if not self.path or not self.index:
return
try:
os.makedirs(self.path, exist_ok=True)
index_path = f"{self.path}/{self.collection_name}.faiss"
docstore_path = f"{self.path}/{self.collection_name}.pkl"
faiss.write_index(self.index, index_path)
with open(docstore_path, "wb") as f:
pickle.dump((self.docstore, self.index_to_id), f)
except Exception as e:
logger.warning(f"Failed to save FAISS index: {e}")
def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
"""
Parse the output data.
Args:
scores: Similarity scores from FAISS.
ids: Indices from FAISS.
limit: Maximum number of results to return.
Returns:
List[OutputData]: Parsed output data.
"""
if limit is None:
limit = len(ids)
results = []
for i in range(min(len(ids), limit)):
if ids[i] == -1: # FAISS returns -1 for empty results
continue
index_id = int(ids[i])
vector_id = self.index_to_id.get(index_id)
if vector_id is None:
continue
payload = self.docstore.get(vector_id)
if payload is None:
continue
payload_copy = payload.copy()
score = float(scores[i])
entry = OutputData(
id=vector_id,
score=score,
payload=payload_copy,
)
results.append(entry)
return results
def create_col(self, name: str, distance: str = None):
"""
Create a new collection.
Args:
name (str): Name of the collection.
distance (str, optional): Distance metric to use. Overrides the distance_strategy
passed during initialization. Defaults to None.
Returns:
self: The FAISS instance.
"""
distance_strategy = distance or self.distance_strategy
# Create index based on distance strategy
if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine":
self.index = faiss.IndexFlatIP(self.embedding_model_dims)
else:
self.index = faiss.IndexFlatL2(self.embedding_model_dims)
self.collection_name = name
self._save()
return self
def insert(
self,
vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None,
):
"""
Insert vectors into a collection.
Args:
vectors (List[list]): List of vectors to insert.
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
"""
if self.index is None:
raise ValueError("Collection not initialized. Call create_col first.")
if ids is None:
ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
if payloads is None:
payloads = [{} for _ in range(len(vectors))]
if len(vectors) != len(ids) or len(vectors) != len(payloads):
raise ValueError("Vectors, payloads, and IDs must have the same length")
vectors_np = np.array(vectors, dtype=np.float32)
if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
faiss.normalize_L2(vectors_np)
        # FAISS assigns new vectors sequential positions starting at the current ntotal,
        # so capture that offset before adding to keep the position-to-id map correct
        # even after prior deletions.
        starting_idx = self.index.ntotal
        self.index.add(vectors_np)
        for i, (vector_id, payload) in enumerate(zip(ids, payloads)):
            self.docstore[vector_id] = payload.copy()
            self.index_to_id[starting_idx + i] = vector_id
self._save()
logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}")
def search(
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query (not used, kept for API compatibility).
vectors (List[list]): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results.
"""
if self.index is None:
raise ValueError("Collection not initialized. Call create_col first.")
query_vectors = np.array(vectors, dtype=np.float32)
if len(query_vectors.shape) == 1:
query_vectors = query_vectors.reshape(1, -1)
if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
faiss.normalize_L2(query_vectors)
fetch_k = limit * 2 if filters else limit
scores, indices = self.index.search(query_vectors, fetch_k)
results = self._parse_output(scores[0], indices[0], limit)
if filters:
filtered_results = []
for result in results:
if self._apply_filters(result.payload, filters):
filtered_results.append(result)
if len(filtered_results) >= limit:
break
results = filtered_results[:limit]
return results
def _apply_filters(self, payload: Dict, filters: Dict) -> bool:
"""
Apply filters to a payload.
Args:
payload (Dict): Payload to filter.
filters (Dict): Filters to apply.
Returns:
bool: True if payload passes filters, False otherwise.
"""
if not filters or not payload:
return True
for key, value in filters.items():
if key not in payload:
return False
if isinstance(value, list):
if payload[key] not in value:
return False
elif payload[key] != value:
return False
return True
def delete(self, vector_id: str):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
if self.index is None:
raise ValueError("Collection not initialized. Call create_col first.")
index_to_delete = None
for idx, vid in self.index_to_id.items():
if vid == vector_id:
index_to_delete = idx
break
if index_to_delete is not None:
self.docstore.pop(vector_id, None)
self.index_to_id.pop(index_to_delete, None)
self._save()
logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}")
else:
logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}")
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (Optional[List[float]], optional): Updated vector. Defaults to None.
payload (Optional[Dict], optional): Updated payload. Defaults to None.
"""
if self.index is None:
raise ValueError("Collection not initialized. Call create_col first.")
if vector_id not in self.docstore:
raise ValueError(f"Vector {vector_id} not found")
current_payload = self.docstore[vector_id].copy()
if payload is not None:
self.docstore[vector_id] = payload.copy()
current_payload = self.docstore[vector_id].copy()
if vector is not None:
self.delete(vector_id)
self.insert([vector], [current_payload], [vector_id])
else:
self._save()
logger.info(f"Updated vector {vector_id} in collection {self.collection_name}")
def get(self, vector_id: str) -> OutputData:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
if self.index is None:
raise ValueError("Collection not initialized. Call create_col first.")
if vector_id not in self.docstore:
return None
payload = self.docstore[vector_id].copy()
return OutputData(
id=vector_id,
score=None,
payload=payload,
)
def list_cols(self) -> List[str]:
"""
List all collections.
Returns:
List[str]: List of collection names.
"""
if not self.path:
return [self.collection_name] if self.index else []
try:
collections = []
path = Path(self.path).parent
for file in path.glob("*.faiss"):
collections.append(file.stem)
return collections
except Exception as e:
logger.warning(f"Failed to list collections: {e}")
return [self.collection_name] if self.index else []
def delete_col(self):
"""
Delete a collection.
"""
if self.path:
try:
index_path = f"{self.path}/{self.collection_name}.faiss"
docstore_path = f"{self.path}/{self.collection_name}.pkl"
if os.path.exists(index_path):
os.remove(index_path)
if os.path.exists(docstore_path):
os.remove(docstore_path)
logger.info(f"Deleted collection {self.collection_name}")
except Exception as e:
logger.warning(f"Failed to delete collection: {e}")
self.index = None
self.docstore = {}
self.index_to_id = {}
def col_info(self) -> Dict:
"""
Get information about a collection.
Returns:
Dict: Collection information.
"""
if self.index is None:
return {"name": self.collection_name, "count": 0}
return {
"name": self.collection_name,
"count": self.index.ntotal,
"dimension": self.index.d,
"distance": self.distance_strategy,
}
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List all vectors in a collection.
Args:
filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
if self.index is None:
return []
results = []
count = 0
for vector_id, payload in self.docstore.items():
if filters and not self._apply_filters(payload, filters):
continue
payload_copy = payload.copy()
results.append(
OutputData(
id=vector_id,
score=None,
payload=payload_copy,
)
)
count += 1
if count >= limit:
break
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name)
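if __name__ == "__main__":
    # Minimal local demo (illustrative sketch): builds a tiny 8-dimensional index under /tmp.
    # Assumes faiss-cpu (or faiss-gpu) and numpy are installed; the collection name, path, and
    # sample vectors below are arbitrary demo values.
    demo = FAISS(collection_name="faiss_demo", path="/tmp/faiss/faiss_demo", embedding_model_dims=8)
    demo.insert(
        vectors=[[0.1] * 8, [0.9] * 8],
        payloads=[{"user_id": "alice", "data": "likes tea"}, {"user_id": "bob", "data": "likes coffee"}],
    )
    # Query with a vector close to the first entry and filter on its payload.
    results = demo.search(query="", vectors=[0.1] * 8, limit=1, filters={"user_id": "alice"})
    for r in results:
        print(r.id, r.score, r.payload)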

View File

@@ -0,0 +1,180 @@
import logging
from typing import Dict, List, Optional
from pydantic import BaseModel
try:
from langchain_community.vectorstores import VectorStore
except ImportError:
raise ImportError(
"The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
)
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class Langchain(VectorStoreBase):
def __init__(self, client: VectorStore, collection_name: str = "mem0"):
self.client = client
self.collection_name = collection_name
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data.
Args:
data (Dict): Output data or list of Document objects.
Returns:
List[OutputData]: Parsed output data.
"""
# Check if input is a list of Document objects
if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
result = []
for doc in data:
entry = OutputData(
id=getattr(doc, "id", None),
score=None, # Document objects typically don't include scores
payload=getattr(doc, "metadata", {}),
)
result.append(entry)
return result
# Original format handling
keys = ["ids", "distances", "metadatas"]
values = []
for key in keys:
value = data.get(key, [])
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
values.append(value)
ids, distances, metadatas = values
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
)
result.append(entry)
return result
def create_col(self, name, vector_size=None, distance=None):
self.collection_name = name
return self.client
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
):
"""
Insert vectors into the LangChain vectorstore.
"""
# Check if client has add_embeddings method
if hasattr(self.client, "add_embeddings"):
# Some LangChain vectorstores have a direct add_embeddings method
self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
else:
# Fallback to add_texts method
texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)
def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
"""
Search for similar vectors in LangChain.
"""
# For each vector, perform a similarity search
if filters:
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
else:
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)
final_results = self._parse_output(results)
return final_results
def delete(self, vector_id):
"""
Delete a vector by ID.
"""
self.client.delete(ids=[vector_id])
def update(self, vector_id, vector=None, payload=None):
"""
Update a vector and its payload.
"""
self.delete(vector_id)
        self.insert([vector], [payload] if payload is not None else None, [vector_id])
def get(self, vector_id):
"""
Retrieve a vector by ID.
"""
docs = self.client.get_by_ids([vector_id])
if docs and len(docs) > 0:
doc = docs[0]
return self._parse_output([doc])[0]
return None
def list_cols(self):
"""
List all collections.
"""
# LangChain doesn't have collections
return [self.collection_name]
def delete_col(self):
"""
Delete a collection.
"""
logger.warning("Deleting collection")
if hasattr(self.client, "delete_collection"):
self.client.delete_collection()
elif hasattr(self.client, "reset_collection"):
self.client.reset_collection()
else:
self.client.delete(ids=None)
def col_info(self):
"""
Get information about a collection.
"""
return {"name": self.collection_name}
def list(self, filters=None, limit=None):
"""
List all vectors in a collection.
"""
try:
if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
# Convert mem0 filters to Chroma where clause if needed
where_clause = None
if filters:
# Handle all filters, not just user_id
where_clause = filters
result = self.client._collection.get(where=where_clause, limit=limit)
# Convert the result to the expected format
if result and isinstance(result, dict):
return [self._parse_output(result)]
return []
except Exception as e:
logger.error(f"Error listing vectors from Chroma: {e}")
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting collection: {self.collection_name}")
self.delete_col()
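# Example usage (illustrative sketch, not executed): `vectorstore` stands for any pre-configured
# LangChain vector store instance (for example a Chroma collection with its embedding function
# already set); the values below are hypothetical.
#
#   store = Langchain(client=vectorstore, collection_name="mem0")
#   store.insert(vectors=[[0.1] * 1536], payloads=[{"data": "likes tea", "user_id": "alice"}], ids=["m1"])
#   hits = store.search(query="tea", vectors=[0.1] * 1536, limit=5)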

View File

@@ -0,0 +1,247 @@
import logging
from typing import Dict, Optional
from pydantic import BaseModel
from mem0.configs.vector_stores.milvus import MetricType
from mem0.vector_stores.base import VectorStoreBase
try:
import pymilvus # noqa: F401
except ImportError:
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class MilvusDB(VectorStoreBase):
def __init__(
self,
url: str,
token: str,
collection_name: str,
embedding_model_dims: int,
metric_type: MetricType,
db_name: str,
) -> None:
"""Initialize the MilvusDB database.
Args:
url (str): Full URL for Milvus/Zilliz server.
            token (str): Token/API key for the Zilliz server (defaults to None for a local setup).
collection_name (str): Name of the collection (defaults to mem0).
embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
metric_type (MetricType): Metric type for similarity search (defaults to L2).
db_name (str): Name of the database (defaults to "").
"""
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.metric_type = metric_type
self.client = MilvusClient(uri=url, token=token, db_name=db_name)
self.create_col(
collection_name=self.collection_name,
vector_size=self.embedding_model_dims,
metric_type=self.metric_type,
)
def create_col(
self,
collection_name: str,
        vector_size: int,
metric_type: MetricType = MetricType.COSINE,
) -> None:
"""Create a new collection with index_type AUTOINDEX.
Args:
collection_name (str): Name of the collection (defaults to mem0).
            vector_size (int): Dimensions of the embedding model (defaults to 1536).
            metric_type (MetricType, optional): Metric type for similarity search. Defaults to MetricType.COSINE.
"""
if self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} already exists. Skipping creation.")
else:
fields = [
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="metadata", dtype=DataType.JSON),
]
schema = CollectionSchema(fields, enable_dynamic_field=True)
index = self.client.prepare_index_params(
field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
)
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
    def insert(self, ids, vectors, payloads, **kwargs):
        """Insert vectors into a collection.
        Args:
            ids (List[str]): List of IDs corresponding to vectors.
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict]): List of payloads corresponding to vectors.
        """
for idx, embedding, metadata in zip(ids, vectors, payloads):
data = {"id": idx, "vectors": embedding, "metadata": metadata}
self.client.insert(collection_name=self.collection_name, data=data, **kwargs)
def _create_filter(self, filters: dict):
"""Prepare filters for efficient query.
Args:
filters (dict): filters [user_id, agent_id, run_id]
Returns:
            str: formatted filter.
"""
operands = []
for key, value in filters.items():
if isinstance(value, str):
operands.append(f'(metadata["{key}"] == "{value}")')
else:
operands.append(f'(metadata["{key}"] == {value})')
return " and ".join(operands)
def _parse_output(self, data: list):
"""
Parse the output data.
Args:
data (Dict): Output data.
Returns:
List[OutputData]: Parsed output data.
"""
memory = []
for value in data:
uid, score, metadata = (
value.get("id"),
value.get("distance"),
value.get("entity", {}).get("metadata"),
)
memory_obj = OutputData(id=uid, score=score, payload=metadata)
memory.append(memory_obj)
return memory
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.search(
collection_name=self.collection_name,
data=[vectors],
limit=limit,
filter=query_filter,
output_fields=["*"],
)
result = self._parse_output(data=hits[0])
return result
def delete(self, vector_id):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
self.client.delete(collection_name=self.collection_name, ids=vector_id)
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
schema = {"id": vector_id, "vectors": vector, "metadata": payload}
self.client.upsert(collection_name=self.collection_name, data=schema)
def get(self, vector_id):
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
result = self.client.get(collection_name=self.collection_name, ids=vector_id)
output = OutputData(
id=result[0].get("id", None),
score=None,
payload=result[0].get("metadata", None),
)
return output
def list_cols(self):
"""
List all collections.
Returns:
List[str]: List of collection names.
"""
return self.client.list_collections()
def delete_col(self):
"""Delete a collection."""
return self.client.drop_collection(collection_name=self.collection_name)
def col_info(self):
"""
Get information about a collection.
Returns:
Dict[str, Any]: Collection information.
"""
return self.client.get_collection_stats(collection_name=self.collection_name)
def list(self, filters: dict = None, limit: int = 100) -> list:
"""
List all vectors in a collection.
Args:
filters (Dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
query_filter = self._create_filter(filters) if filters else None
result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
memories = []
for data in result:
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
return [memories]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)
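# Example usage (illustrative sketch, not executed): assumes a reachable Milvus/Zilliz deployment;
# the URI, token, and database name below are placeholders.
#
#   store = MilvusDB(
#       url="http://localhost:19530", token="", collection_name="mem0",
#       embedding_model_dims=1536, metric_type=MetricType.COSINE, db_name="",
#   )
#   store.insert(ids=["m1"], vectors=[[0.1] * 1536], payloads=[{"user_id": "alice", "data": "likes tea"}])
#   hits = store.search(query="tea", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"})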

View File

@@ -0,0 +1,310 @@
import logging
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
try:
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from pymongo.operations import SearchIndexModel
except ImportError:
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine"
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
"""
Initialize the MongoDB vector store with vector search capabilities.
Args:
db_name (str): Database name
collection_name (str): Collection name
embedding_model_dims (int): Dimension of the embedding vector
mongo_uri (str): MongoDB connection URI
"""
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.db_name = db_name
self.client = MongoClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.create_col()
def create_col(self):
"""Create new collection with vector search index."""
try:
database = self.client[self.db_name]
collection_names = database.list_collection_names()
if self.collection_name not in collection_names:
logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
collection = database[self.collection_name]
# Insert and remove a placeholder document to create the collection
collection.insert_one({"_id": 0, "placeholder": True})
collection.delete_one({"_id": 0})
logger.info(f"Collection '{self.collection_name}' created successfully.")
else:
collection = database[self.collection_name]
self.index_name = f"{self.collection_name}_vector_index"
found_indexes = list(collection.list_search_indexes(name=self.index_name))
if found_indexes:
logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
else:
search_index_model = SearchIndexModel(
name=self.index_name,
definition={
"mappings": {
"dynamic": False,
"fields": {
"embedding": {
"type": self.VECTOR_TYPE,
"dimensions": self.embedding_model_dims,
"similarity": self.SIMILARITY_METRIC,
}
},
}
},
)
collection.create_search_index(search_index_model)
logger.info(
f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
)
return collection
except PyMongoError as e:
logger.error(f"Error creating collection and search index: {e}")
return None
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> None:
"""
Insert vectors into the collection.
Args:
vectors (List[List[float]]): List of vectors to insert.
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")
data = []
for vector, payload, _id in zip(vectors, payloads or [{}] * len(vectors), ids or [None] * len(vectors)):
document = {"_id": _id, "embedding": vector, "payload": payload}
data.append(document)
try:
self.collection.insert_many(data)
logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
except PyMongoError as e:
logger.error(f"Error inserting data: {e}")
def search(self, query: str, vectors: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
"""
Search for similar vectors using the vector search index.
Args:
query (str): Query string
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search.
Returns:
List[OutputData]: Search results.
"""
found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
if not found_indexes:
logger.error(f"Index '{self.index_name}' does not exist.")
return []
results = []
try:
collection = self.client[self.db_name][self.collection_name]
pipeline = [
{
"$vectorSearch": {
"index": self.index_name,
"limit": limit,
"numCandidates": limit,
"queryVector": vectors,
"path": "embedding",
}
},
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
{"$project": {"embedding": 0}},
]
# Add filter stage if filters are provided
if filters:
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"payload." + key: value})
if filter_conditions:
# Add a $match stage after vector search to apply filters
pipeline.insert(1, {"$match": {"$and": filter_conditions}})
results = list(collection.aggregate(pipeline))
logger.info(f"Vector search completed. Found {len(results)} documents.")
except Exception as e:
logger.error(f"Error during vector search for query {query}: {e}")
return []
output = [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]
return output
def delete(self, vector_id: str) -> None:
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
result = self.collection.delete_one({"_id": vector_id})
if result.deleted_count > 0:
logger.info(f"Deleted document with ID '{vector_id}'.")
else:
logger.warning(f"No document found with ID '{vector_id}' to delete.")
except PyMongoError as e:
logger.error(f"Error deleting document: {e}")
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
update_fields = {}
if vector is not None:
update_fields["embedding"] = vector
if payload is not None:
update_fields["payload"] = payload
if update_fields:
try:
result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
if result.matched_count > 0:
logger.info(f"Updated document with ID '{vector_id}'.")
else:
logger.warning(f"No document found with ID '{vector_id}' to update.")
except PyMongoError as e:
logger.error(f"Error updating document: {e}")
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
Optional[OutputData]: Retrieved vector or None if not found.
"""
try:
doc = self.collection.find_one({"_id": vector_id})
if doc:
logger.info(f"Retrieved document with ID '{vector_id}'.")
return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
else:
logger.warning(f"Document with ID '{vector_id}' not found.")
return None
except PyMongoError as e:
logger.error(f"Error retrieving document: {e}")
return None
def list_cols(self) -> List[str]:
"""
List all collections in the database.
Returns:
List[str]: List of collection names.
"""
try:
collections = self.db.list_collection_names()
logger.info(f"Listing collections in database '{self.db_name}': {collections}")
return collections
except PyMongoError as e:
logger.error(f"Error listing collections: {e}")
return []
def delete_col(self) -> None:
"""Delete the collection."""
try:
self.collection.drop()
logger.info(f"Deleted collection '{self.collection_name}'.")
except PyMongoError as e:
logger.error(f"Error deleting collection: {e}")
def col_info(self) -> Dict[str, Any]:
"""
Get information about the collection.
Returns:
Dict[str, Any]: Collection information.
"""
try:
stats = self.db.command("collstats", self.collection_name)
info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
logger.info(f"Collection info: {info}")
return info
except PyMongoError as e:
logger.error(f"Error getting collection info: {e}")
return {}
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List vectors in the collection.
Args:
filters (Dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return.
Returns:
List[OutputData]: List of vectors.
"""
try:
query = {}
if filters:
# Apply filters to the payload field
filter_conditions = []
for key, value in filters.items():
filter_conditions.append({"payload." + key: value})
if filter_conditions:
query = {"$and": filter_conditions}
cursor = self.collection.find(query).limit(limit)
results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
return results
except PyMongoError as e:
logger.error(f"Error listing documents: {e}")
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
        self.collection = self.create_col()
def __del__(self) -> None:
"""Close the database connection when the object is deleted."""
if hasattr(self, "client"):
self.client.close()
logger.info("MongoClient connection closed.")

View File

@@ -0,0 +1,467 @@
import logging
import time
import uuid
from typing import Dict, List, Optional
from pydantic import BaseModel
try:
from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
raise ImportError("langchain_aws is not installed. Please install it using pip install langchain_aws")
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class NeptuneAnalyticsVector(VectorStoreBase):
"""
Neptune Analytics vector store implementation for Mem0.
Provides vector storage and similarity search capabilities using Amazon Neptune Analytics,
a serverless graph analytics service that supports vector operations.
"""
_COLLECTION_PREFIX = "MEM0_VECTOR_"
_FIELD_N = 'n'
_FIELD_ID = '~id'
_FIELD_PROP = '~properties'
_FIELD_SCORE = 'score'
_FIELD_LABEL = 'label'
_TIMEZONE = "UTC"
def __init__(
self,
endpoint: str,
collection_name: str,
):
"""
Initialize the Neptune Analytics vector store.
Args:
endpoint (str): Neptune Analytics endpoint in format 'neptune-graph://<graphid>'.
collection_name (str): Name of the collection to store vectors.
Raises:
ValueError: If endpoint format is invalid.
ImportError: If langchain_aws is not installed.
"""
if not endpoint.startswith("neptune-graph://"):
raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
graph_id = endpoint.replace("neptune-graph://", "")
self.graph = NeptuneAnalyticsGraph(graph_id)
self.collection_name = self._COLLECTION_PREFIX + collection_name
def create_col(self, name, vector_size, distance):
"""
Create a collection (no-op for Neptune Analytics).
Neptune Analytics supports dynamic indices that are created implicitly
when vectors are inserted, so this method performs no operation.
Args:
name: Collection name (unused).
vector_size: Vector dimension (unused).
distance: Distance metric (unused).
"""
pass
def insert(self, vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None):
"""
Insert vectors into the collection.
Creates or updates nodes in Neptune Analytics with vector embeddings and metadata.
Uses MERGE operation to handle both creation and updates.
Args:
vectors (List[list]): List of embedding vectors to insert.
payloads (Optional[List[Dict]]): Optional metadata for each vector.
ids (Optional[List[str]]): Optional IDs for vectors. Generated if not provided.
"""
para_list = []
for index, data_vector in enumerate(vectors):
if payloads:
payload = payloads[index]
payload[self._FIELD_LABEL] = self.collection_name
payload["updated_at"] = str(int(time.time()))
else:
payload = {}
para_list.append(dict(
node_id=ids[index] if ids else str(uuid.uuid4()),
properties=payload,
embedding=data_vector,
))
para_map_to_insert = {"rows": para_list}
query_string = (f"""
UNWIND $rows AS row
MERGE (n :{self.collection_name} {{`~id`: row.node_id}})
ON CREATE SET n = row.properties
ON MATCH SET n += row.properties
"""
)
self.execute_query(query_string, para_map_to_insert)
query_string_vector = (f"""
UNWIND $rows AS row
MATCH (n
:{self.collection_name}
{{`~id`: row.node_id}})
WITH n, row.embedding AS embedding
CALL neptune.algo.vectors.upsert(n, embedding)
YIELD success
RETURN success
"""
)
result = self.execute_query(query_string_vector, para_map_to_insert)
self._process_success_message(result, "Vector store - Insert")
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors using embedding similarity.
Performs vector similarity search using Neptune Analytics' topKByEmbeddingWithFiltering
algorithm to find the most similar vectors.
Args:
query (str): Search query text (unused in vector search).
vectors (List[float]): Query embedding vector.
limit (int, optional): Maximum number of results to return. Defaults to 5.
filters (Optional[Dict]): Optional filters to apply to search results.
Returns:
List[OutputData]: List of similar vectors with scores and metadata.
"""
if not filters:
filters = {}
filters[self._FIELD_LABEL] = self.collection_name
filter_clause = self._get_node_filter_clause(filters)
query_string = f"""
CALL neptune.algo.vectors.topKByEmbeddingWithFiltering({{
topK: {limit},
embedding: {vectors}
{filter_clause}
}}
)
YIELD node, score
RETURN node as n, score
"""
query_response = self.execute_query(query_string)
if len(query_response) > 0:
return self._parse_query_responses(query_response, with_score=True)
        else:
return []
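    # Example usage (illustrative sketch, not executed): assumes AWS credentials and an existing
    # Neptune Analytics graph; the graph id below is a placeholder.
    #
    #   store = NeptuneAnalyticsVector(endpoint="neptune-graph://g-example123", collection_name="mem0")
    #   store.insert(vectors=[[0.1] * 1536], payloads=[{"user_id": "alice", "data": "likes tea"}], ids=["m1"])
    #   hits = store.search(query="tea", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"})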
def delete(self, vector_id: str):
"""
Delete a vector by its ID.
Removes the node and all its relationships from the Neptune Analytics graph.
Args:
vector_id (str): ID of the vector to delete.
"""
params = dict(node_id=vector_id)
query_string = f"""
MATCH (n :{self.collection_name})
WHERE id(n) = $node_id
DETACH DELETE n
"""
self.execute_query(query_string, params)
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
):
"""
Update a vector's embedding and/or metadata.
Updates the node properties and/or vector embedding for an existing vector.
Can update either the payload, the vector, or both.
Args:
vector_id (str): ID of the vector to update.
vector (Optional[List[float]]): New embedding vector.
payload (Optional[Dict]): New metadata to replace existing payload.
"""
if payload:
# Replace payload
payload[self._FIELD_LABEL] = self.collection_name
payload["updated_at"] = str(int(time.time()))
para_payload = {
"properties": payload,
"vector_id": vector_id
}
query_string_embedding = f"""
MATCH (n :{self.collection_name})
WHERE id(n) = $vector_id
SET n = $properties
"""
self.execute_query(query_string_embedding, para_payload)
if vector:
para_embedding = {
"embedding": vector,
"vector_id": vector_id
}
query_string_embedding = f"""
MATCH (n :{self.collection_name})
WHERE id(n) = $vector_id
WITH $embedding as embedding, n as n
CALL neptune.algo.vectors.upsert(n, embedding)
YIELD success
RETURN success
"""
self.execute_query(query_string_embedding, para_embedding)
def get(self, vector_id: str):
"""
Retrieve a vector by its ID.
Fetches the node data including metadata for the specified vector ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Vector data with metadata, or None if not found.
"""
params = dict(node_id=vector_id)
query_string = f"""
MATCH (n :{self.collection_name})
WHERE id(n) = $node_id
RETURN n
"""
        # Execute the lookup query.
result = self.execute_query(query_string, params)
if len(result) != 0:
return self._parse_query_responses(result)[0]
def list_cols(self):
"""
List all collections with the Mem0 prefix.
Queries the Neptune Analytics schema to find all node labels that start
with the Mem0 collection prefix.
Returns:
List[str]: List of collection names.
"""
query_string = f"""
CALL neptune.graph.pg_schema()
YIELD schema
RETURN [ label IN schema.nodeLabels WHERE label STARTS WITH '{self.collection_name}'] AS result
"""
result = self.execute_query(query_string)
if len(result) == 1 and "result" in result[0]:
return result[0]["result"]
else:
return []
def delete_col(self):
"""
Delete the entire collection.
Removes all nodes with the collection label and their relationships
from the Neptune Analytics graph.
"""
self.execute_query(f"MATCH (n :{self.collection_name}) DETACH DELETE n")
def col_info(self):
"""
Get collection information (no-op for Neptune Analytics).
Collections are created dynamically in Neptune Analytics, so no
collection-specific metadata is available.
"""
pass
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List all vectors in the collection with optional filtering.
Retrieves vectors from the collection, optionally filtered by metadata properties.
Args:
filters (Optional[Dict]): Optional filters to apply based on metadata.
limit (int, optional): Maximum number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors with their metadata.
"""
where_clause = self._get_where_clause(filters) if filters else ""
para = {
"limit": limit,
}
query_string = f"""
MATCH (n :{self.collection_name})
{where_clause}
RETURN n
LIMIT $limit
"""
query_response = self.execute_query(query_string, para)
        if len(query_response) > 0:
            return [self._parse_query_responses(query_response)]
        # No matching nodes found.
        return [[]]
def reset(self):
"""
Reset the collection by deleting all vectors.
Removes all vectors from the collection, effectively resetting it to empty state.
"""
self.delete_col()
def _parse_query_responses(self, response: dict, with_score: bool = False):
"""
Parse Neptune Analytics query responses into OutputData objects.
Args:
response (dict): Raw query response from Neptune Analytics.
with_score (bool, optional): Whether to include similarity scores. Defaults to False.
Returns:
List[OutputData]: Parsed response data.
"""
result = []
        # Each response item carries the matched node and, when requested, its score.
for item in response:
id = item[self._FIELD_N][self._FIELD_ID]
properties = item[self._FIELD_N][self._FIELD_PROP]
properties.pop("label", None)
if with_score:
score = item[self._FIELD_SCORE]
else:
score = None
result.append(OutputData(
id=id,
score=score,
payload=properties,
))
return result
def execute_query(self, query_string: str, params=None):
"""
Execute an openCypher query on Neptune Analytics.
This is a wrapper method around the Neptune Analytics graph query execution
that provides debug logging for query monitoring and troubleshooting.
Args:
query_string (str): The openCypher query string to execute.
params (dict): Parameters to bind to the query.
Returns:
Query result from Neptune Analytics graph execution.
"""
if params is None:
params = {}
logger.debug(f"Executing openCypher query:[{query_string}], with parameters:[{params}].")
return self.graph.query(query_string, params)
@staticmethod
def _get_where_clause(filters: dict):
"""
Build WHERE clause for Cypher queries from filters.
Args:
filters (dict): Filter conditions as key-value pairs.
Returns:
str: Formatted WHERE clause for Cypher query.
"""
where_clause = ""
for i, (k, v) in enumerate(filters.items()):
if i == 0:
where_clause += f"WHERE n.{k} = '{v}' "
else:
where_clause += f"AND n.{k} = '{v}' "
return where_clause
@staticmethod
def _get_node_filter_clause(filters: dict):
"""
Build node filter clause for vector search operations.
Creates filter conditions for Neptune Analytics vector search operations
using the nodeFilter parameter format.
Args:
filters (dict): Filter conditions as key-value pairs.
Returns:
str: Formatted node filter clause for vector search.
"""
conditions = []
for k, v in filters.items():
conditions.append(f"{{equals:{{property: '{k}', value: '{v}'}}}}")
if len(conditions) == 1:
filter_clause = f", nodeFilter: {conditions[0]}"
else:
filter_clause = f"""
, nodeFilter: {{andAll: [ {", ".join(conditions)} ]}}
"""
return filter_clause
@staticmethod
def _process_success_message(response, context):
"""
Process and validate success messages from Neptune Analytics operations.
Checks the response from vector operations (insert/update) to ensure they
completed successfully. Logs errors if operations fail.
Args:
response: Response from Neptune Analytics vector operation.
context (str): Context description for logging (e.g., "Vector store - Insert").
"""
for success_message in response:
if "success" not in success_message:
logger.error(f"Query execution status is absent on action: [{context}]")
break
if success_message["success"] is not True:
logger.error(f"Abnormal response status on action: [{context}] with message: [{success_message['success']}] ")
break
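
# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal, self-contained demonstration of the nodeFilter fragment that
# search() hands to neptune.algo.vectors.topKByEmbeddingWithFiltering. The
# sample filter keys/values below are assumptions for illustration; only the
# clause format mirrors _get_node_filter_clause above.
if __name__ == "__main__":
    sample_filters = {"user_id": "alice", "memory_type": "semantic"}
    conditions = [
        f"{{equals:{{property: '{k}', value: '{v}'}}}}" for k, v in sample_filters.items()
    ]
    node_filter = f", nodeFilter: {{andAll: [ {', '.join(conditions)} ]}}"
    print(node_filter)
    # -> , nodeFilter: {andAll: [ {equals:{property: 'user_id', value: 'alice'}},
    #                             {equals:{property: 'memory_type', value: 'semantic'}} ]}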

View File

@@ -0,0 +1,281 @@
import logging
import time
from typing import Any, Dict, List, Optional
try:
from opensearchpy import OpenSearch, RequestsHttpConnection
except ImportError:
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
from pydantic import BaseModel
from mem0.configs.vector_stores.opensearch import OpenSearchConfig
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: str
score: float
payload: Dict
class OpenSearchDB(VectorStoreBase):
def __init__(self, **kwargs):
config = OpenSearchConfig(**kwargs)
# Initialize OpenSearch client
self.client = OpenSearch(
hosts=[{"host": config.host, "port": config.port or 9200}],
http_auth=config.http_auth
if config.http_auth
else ((config.user, config.password) if (config.user and config.password) else None),
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=20,
)
self.collection_name = config.collection_name
self.embedding_model_dims = config.embedding_model_dims
self.create_col(self.collection_name, self.embedding_model_dims)
def create_index(self) -> None:
"""Create OpenSearch index with proper mappings if it doesn't exist."""
index_settings = {
"settings": {
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
},
"mappings": {
"properties": {
"text": {"type": "text"},
"vector_field": {
"type": "knn_vector",
"dimension": self.embedding_model_dims,
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
},
}
if not self.client.indices.exists(index=self.collection_name):
self.client.indices.create(index=self.collection_name, body=index_settings)
logger.info(f"Created index {self.collection_name}")
else:
logger.info(f"Index {self.collection_name} already exists")
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
"settings": {"index.knn": True},
"mappings": {
"properties": {
"vector_field": {
"type": "knn_vector",
"dimension": vector_size,
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"payload": {"type": "object"},
"id": {"type": "keyword"},
}
},
}
if not self.client.indices.exists(index=name):
logger.warning(f"Creating index {name}, it might take 1-2 minutes...")
self.client.indices.create(index=name, body=index_settings)
# Wait for index to be ready
max_retries = 180 # 3 minutes timeout
retry_count = 0
while retry_count < max_retries:
try:
# Check if index is ready by attempting a simple search
self.client.search(index=name, body={"query": {"match_all": {}}})
time.sleep(1)
logger.info(f"Index {name} is ready")
return
except Exception:
retry_count += 1
if retry_count == max_retries:
raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
time.sleep(0.5)
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> List[OutputData]:
"""Insert vectors into the index."""
if not ids:
ids = [str(i) for i in range(len(vectors))]
if payloads is None:
payloads = [{} for _ in range(len(vectors))]
for i, (vec, id_) in enumerate(zip(vectors, ids)):
body = {
"vector_field": vec,
"payload": payloads[i],
"id": id_,
}
self.client.index(index=self.collection_name, body=body)
results = []
return results
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""Search for similar vectors using OpenSearch k-NN search with optional filters."""
# Base KNN query
knn_query = {
"knn": {
"vector_field": {
"vector": vectors,
"k": limit * 2,
}
}
}
# Start building the full query
query_body = {"size": limit * 2, "query": None}
# Prepare filter conditions if applicable
filter_clauses = []
if filters:
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
# Combine knn with filters if needed
if filter_clauses:
query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
else:
query_body["query"] = knn_query
# Execute search
response = self.client.search(index=self.collection_name, body=query_body)
hits = response["hits"]["hits"]
results = [
OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
for hit in hits
]
return results
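
    # Note: when metadata filters are supplied, the k-NN clause above is wrapped in a
    # bool query so the term filters on payload.<field>.keyword are applied alongside
    # the vector search. The resulting body has roughly this shape (illustrative):
    #   {"size": limit * 2, "query": {"bool": {"must": {"knn": {...}}, "filter": [{"term": {...}}]}}}
    # Both "k" and "size" are set to limit * 2 to over-fetch candidates, which helps
    # compensate for hits that the filters remove.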
def delete(self, vector_id: str) -> None:
"""Delete a vector by custom ID."""
# First, find the document by custom ID
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"]
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"] # The actual document ID in OpenSearch
# Prepare updated fields
doc = {}
if vector is not None:
doc["vector_field"] = vector
if payload is not None:
doc["payload"] = payload
        if doc:
            try:
                self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
            except Exception as e:
                logger.error(f"Error updating vector {vector_id}: {e}")
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
# First check if index exists
if not self.client.indices.exists(index=self.collection_name):
logger.info(f"Index {self.collection_name} does not exist, creating it...")
self.create_col(self.collection_name, self.embedding_model_dims)
return None
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response["hits"]["hits"]
if not hits:
return None
return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
def list_cols(self) -> List[str]:
"""List all collections (indices)."""
return list(self.client.indices.get_alias().keys())
def delete_col(self) -> None:
"""Delete a collection (index)."""
self.client.indices.delete(index=self.collection_name)
def col_info(self, name: str) -> Any:
"""Get information about a collection (index)."""
return self.client.indices.get(index=name)
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
        """List all memories with optional filters."""
        try:
            query: Dict = {"query": {"match_all": {}}}
filter_clauses = []
if filters:
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
if filter_clauses:
query["query"] = {"bool": {"filter": filter_clauses}}
if limit:
query["size"] = limit
response = self.client.search(index=self.collection_name, body=query)
hits = response["hits"]["hits"]
return [
[
OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
for hit in hits
]
]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name, self.embedding_model_dims)
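
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow against a locally running OpenSearch node. The keyword arguments
# are forwarded to OpenSearchConfig; the host, credentials, dimension and index
# name below are placeholders, and the required config fields may differ.
if __name__ == "__main__":
    store = OpenSearchDB(
        host="localhost",
        port=9200,
        collection_name="mem0_demo",
        embedding_model_dims=8,
        use_ssl=False,
        verify_certs=False,
    )
    store.insert(
        vectors=[[0.1] * 8],
        payloads=[{"user_id": "alice", "text": "hello world"}],
        ids=["v1"],
    )
    for hit in store.search(query="hello", vectors=[0.1] * 8, limit=3, filters={"user_id": "alice"}):
        print(hit.id, hit.score, hit.payload)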

View File

@@ -0,0 +1,404 @@
import json
import logging
from contextlib import contextmanager
from typing import Any, List, Optional
from pydantic import BaseModel
# Try to import psycopg (psycopg3) first, then fall back to psycopg2
try:
from psycopg.types.json import Json
from psycopg_pool import ConnectionPool
PSYCOPG_VERSION = 3
logger = logging.getLogger(__name__)
logger.info("Using psycopg (psycopg3) with ConnectionPool for PostgreSQL connections")
except ImportError:
try:
from psycopg2.extras import Json, execute_values
from psycopg2.pool import ThreadedConnectionPool as ConnectionPool
PSYCOPG_VERSION = 2
logger = logging.getLogger(__name__)
logger.info("Using psycopg2 with ThreadedConnectionPool for PostgreSQL connections")
except ImportError:
raise ImportError(
"Neither 'psycopg' nor 'psycopg2' library is available. "
"Please install one of them using 'pip install psycopg[pool]' or 'pip install psycopg2'"
)
from neomem.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class PGVector(VectorStoreBase):
def __init__(
self,
dbname,
collection_name,
embedding_model_dims,
user,
password,
host,
port,
diskann,
hnsw,
minconn=1,
maxconn=5,
sslmode=None,
connection_string=None,
connection_pool=None,
):
"""
Initialize the PGVector database.
Args:
dbname (str): Database name
collection_name (str): Collection name
embedding_model_dims (int): Dimension of the embedding vector
user (str): Database user
password (str): Database password
host (str, optional): Database host
port (int, optional): Database port
diskann (bool, optional): Use DiskANN for faster search
hnsw (bool, optional): Use HNSW for faster search
minconn (int): Minimum number of connections to keep in the connection pool
maxconn (int): Maximum number of connections allowed in the connection pool
sslmode (str, optional): SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')
connection_string (str, optional): PostgreSQL connection string (overrides individual connection parameters)
connection_pool (Any, optional): psycopg2 connection pool object (overrides connection string and individual parameters)
"""
self.collection_name = collection_name
self.use_diskann = diskann
self.use_hnsw = hnsw
self.embedding_model_dims = embedding_model_dims
self.connection_pool = None
# Connection setup with priority: connection_pool > connection_string > individual parameters
if connection_pool is not None:
# Use provided connection pool
self.connection_pool = connection_pool
elif connection_string:
if sslmode:
# Append sslmode to connection string if provided
if 'sslmode=' in connection_string:
# Replace existing sslmode
import re
connection_string = re.sub(r'sslmode=[^ ]*', f'sslmode={sslmode}', connection_string)
else:
# Add sslmode to connection string
connection_string = f"{connection_string} sslmode={sslmode}"
else:
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
if sslmode:
connection_string = f"{connection_string} sslmode={sslmode}"
if self.connection_pool is None:
if PSYCOPG_VERSION == 3:
# psycopg3 ConnectionPool
self.connection_pool = ConnectionPool(conninfo=connection_string, min_size=minconn, max_size=maxconn, open=True)
else:
# psycopg2 ThreadedConnectionPool
self.connection_pool = ConnectionPool(minconn=minconn, maxconn=maxconn, dsn=connection_string)
collections = self.list_cols()
if collection_name not in collections:
self.create_col()
@contextmanager
def _get_cursor(self, commit: bool = False):
"""
Unified context manager to get a cursor from the appropriate pool.
Auto-commits or rolls back based on exception, and returns the connection to the pool.
"""
if PSYCOPG_VERSION == 3:
# psycopg3 auto-manages commit/rollback and pool return
with self.connection_pool.connection() as conn:
with conn.cursor() as cur:
try:
yield cur
if commit:
conn.commit()
except Exception:
conn.rollback()
logger.error("Error in cursor context (psycopg3)", exc_info=True)
raise
else:
# psycopg2 manual getconn/putconn
conn = self.connection_pool.getconn()
cur = conn.cursor()
try:
yield cur
if commit:
conn.commit()
except Exception as exc:
conn.rollback()
logger.error(f"Error occurred: {exc}")
raise exc
finally:
cur.close()
self.connection_pool.putconn(conn)
def create_col(self) -> None:
"""
Create a new collection (table in PostgreSQL).
Will also initialize vector search index if specified.
"""
with self._get_cursor(commit=True) as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.collection_name} (
id UUID PRIMARY KEY,
vector vector({self.embedding_model_dims}),
payload JSONB
);
"""
)
if self.use_diskann and self.embedding_model_dims < 2000:
cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
if cur.fetchone():
# Create DiskANN index if extension is installed for faster search
cur.execute(
f"""
CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
ON {self.collection_name}
USING diskann (vector);
"""
)
elif self.use_hnsw:
cur.execute(
f"""
CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
ON {self.collection_name}
USING hnsw (vector vector_cosine_ops)
"""
)
def insert(self, vectors: list[list[float]], payloads=None, ids=None) -> None:
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
json_payloads = [json.dumps(payload) for payload in payloads]
data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
if PSYCOPG_VERSION == 3:
with self._get_cursor(commit=True) as cur:
cur.executemany(
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES (%s, %s, %s)",
data,
)
else:
with self._get_cursor(commit=True) as cur:
execute_values(
cur,
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
data,
)
def search(
self,
query: str,
vectors: list[float],
limit: Optional[int] = 5,
filters: Optional[dict] = None,
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
"""
filter_conditions = []
filter_params = []
if filters:
for k, v in filters.items():
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
with self._get_cursor() as cur:
cur.execute(
f"""
SELECT id, vector <=> %s::vector AS distance, payload
FROM {self.collection_name}
{filter_clause}
ORDER BY distance
LIMIT %s
""",
(vectors, *filter_params, limit),
)
results = cur.fetchall()
return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
def delete(self, vector_id: str) -> None:
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
with self._get_cursor(commit=True) as cur:
cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
def update(
self,
vector_id: str,
vector: Optional[list[float]] = None,
payload: Optional[dict] = None,
) -> None:
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
with self._get_cursor(commit=True) as cur:
if vector:
cur.execute(
f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
(vector, vector_id),
)
            if payload:
                # Json is imported from psycopg.types.json (psycopg3) or psycopg2.extras
                # (psycopg2) above; the adapter call is identical for both drivers.
                cur.execute(
                    f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
                    (Json(payload), vector_id),
                )
def get(self, vector_id: str) -> OutputData:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
OutputData: Retrieved vector.
"""
with self._get_cursor() as cur:
cur.execute(
f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
(vector_id,),
)
result = cur.fetchone()
if not result:
return None
return OutputData(id=str(result[0]), score=None, payload=result[2])
def list_cols(self) -> List[str]:
"""
List all collections.
Returns:
List[str]: List of collection names.
"""
with self._get_cursor() as cur:
cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
return [row[0] for row in cur.fetchall()]
def delete_col(self) -> None:
"""Delete a collection."""
with self._get_cursor(commit=True) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
def col_info(self) -> dict[str, Any]:
"""
Get information about a collection.
Returns:
Dict[str, Any]: Collection information.
"""
with self._get_cursor() as cur:
cur.execute(
f"""
SELECT
table_name,
(SELECT COUNT(*) FROM {self.collection_name}) as row_count,
(SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = %s
""",
(self.collection_name,),
)
result = cur.fetchone()
return {"name": result[0], "count": result[1], "size": result[2]}
def list(
self,
filters: Optional[dict] = None,
limit: Optional[int] = 100
) -> List[OutputData]:
"""
List all vectors in a collection.
Args:
filters (Dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors.
"""
filter_conditions = []
filter_params = []
if filters:
for k, v in filters.items():
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
query = f"""
SELECT id, vector, payload
FROM {self.collection_name}
{filter_clause}
LIMIT %s
"""
with self._get_cursor() as cur:
cur.execute(query, (*filter_params, limit))
results = cur.fetchall()
return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]
def __del__(self) -> None:
"""
Close the database connection pool when the object is deleted.
"""
try:
# Close pool appropriately
if PSYCOPG_VERSION == 3:
self.connection_pool.close()
else:
self.connection_pool.closeall()
except Exception:
pass
def reset(self) -> None:
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col()
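
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow against a local PostgreSQL instance with the pgvector extension
# available. Connection details and the collection name are placeholders.
if __name__ == "__main__":
    import uuid

    store = PGVector(
        dbname="memdb",
        collection_name="mem0_demo",
        embedding_model_dims=8,
        user="postgres",
        password="postgres",
        host="localhost",
        port=5432,
        diskann=False,
        hnsw=True,
    )
    vector_id = str(uuid.uuid4())
    store.insert(vectors=[[0.1] * 8], payloads=[{"user_id": "alice"}], ids=[vector_id])
    for hit in store.search(query="", vectors=[0.1] * 8, limit=3, filters={"user_id": "alice"}):
        print(hit.id, hit.score, hit.payload)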

View File

@@ -0,0 +1,382 @@
import logging
import os
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
try:
from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector
except ImportError:
raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
) from None
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class PineconeDB(VectorStoreBase):
def __init__(
self,
collection_name: str,
embedding_model_dims: int,
client: Optional["Pinecone"],
api_key: Optional[str],
environment: Optional[str],
serverless_config: Optional[Dict[str, Any]],
pod_config: Optional[Dict[str, Any]],
hybrid_search: bool,
metric: str,
batch_size: int,
extra_params: Optional[Dict[str, Any]],
namespace: Optional[str] = None,
):
"""
Initialize the Pinecone vector store.
Args:
collection_name (str): Name of the index/collection.
embedding_model_dims (int): Dimensions of the embedding model.
client (Pinecone, optional): Existing Pinecone client instance. Defaults to None.
api_key (str, optional): API key for Pinecone. Defaults to None.
environment (str, optional): Pinecone environment. Defaults to None.
serverless_config (Dict, optional): Configuration for serverless deployment. Defaults to None.
pod_config (Dict, optional): Configuration for pod-based deployment. Defaults to None.
hybrid_search (bool, optional): Whether to enable hybrid search. Defaults to False.
metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
batch_size (int, optional): Batch size for operations. Defaults to 100.
extra_params (Dict, optional): Additional parameters for Pinecone client. Defaults to None.
namespace (str, optional): Namespace for the collection. Defaults to None.
"""
if client:
self.client = client
else:
api_key = api_key or os.environ.get("PINECONE_API_KEY")
if not api_key:
raise ValueError(
"Pinecone API key must be provided either as a parameter or as an environment variable"
)
params = extra_params or {}
self.client = Pinecone(api_key=api_key, **params)
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.environment = environment
self.serverless_config = serverless_config
self.pod_config = pod_config
self.hybrid_search = hybrid_search
self.metric = metric
self.batch_size = batch_size
self.namespace = namespace
self.sparse_encoder = None
if self.hybrid_search:
try:
from pinecone_text.sparse import BM25Encoder
logger.info("Initializing BM25Encoder for sparse vectors...")
self.sparse_encoder = BM25Encoder.default()
except ImportError:
logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
self.hybrid_search = False
self.create_col(embedding_model_dims, metric)
def create_col(self, vector_size: int, metric: str = "cosine"):
"""
Create a new index/collection.
Args:
vector_size (int): Size of the vectors to be stored.
metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
"""
existing_indexes = self.list_cols().names()
if self.collection_name in existing_indexes:
logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
self.index = self.client.Index(self.collection_name)
return
if self.serverless_config:
spec = ServerlessSpec(**self.serverless_config)
elif self.pod_config:
spec = PodSpec(**self.pod_config)
else:
spec = ServerlessSpec(cloud="aws", region="us-west-2")
self.client.create_index(
name=self.collection_name,
dimension=vector_size,
metric=metric,
spec=spec,
)
self.index = self.client.Index(self.collection_name)
def insert(
self,
vectors: List[List[float]],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[Union[str, int]]] = None,
):
"""
Insert vectors into an index.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
items = []
for idx, vector in enumerate(vectors):
item_id = str(ids[idx]) if ids is not None else str(idx)
payload = payloads[idx] if payloads else {}
vector_record = {"id": item_id, "values": vector, "metadata": payload}
if self.hybrid_search and self.sparse_encoder and "text" in payload:
sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
vector_record["sparse_values"] = sparse_vector
items.append(vector_record)
if len(items) >= self.batch_size:
self.index.upsert(vectors=items, namespace=self.namespace)
items = []
if items:
self.index.upsert(vectors=items, namespace=self.namespace)
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data from Pinecone search results.
Args:
data (Dict): Output data from Pinecone query.
Returns:
List[OutputData]: Parsed output data.
"""
if isinstance(data, Vector):
result = OutputData(
id=data.id,
score=0.0,
payload=data.metadata,
)
return result
else:
result = []
for match in data:
entry = OutputData(
id=match.get("id"),
score=match.get("score"),
payload=match.get("metadata"),
)
result.append(entry)
return result
def _create_filter(self, filters: Optional[Dict]) -> Dict:
"""
Create a filter dictionary from the provided filters.
"""
if not filters:
return {}
pinecone_filter = {}
for key, value in filters.items():
if isinstance(value, dict) and "gte" in value and "lte" in value:
pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
else:
pinecone_filter[key] = {"$eq": value}
return pinecone_filter
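
    # Example (illustrative): {"user_id": "alice", "created_at": {"gte": 1, "lte": 9}}
    # becomes {"user_id": {"$eq": "alice"}, "created_at": {"$gte": 1, "$lte": 9}},
    # i.e. plain values map to $eq and gte/lte dicts map to a range condition.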
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (list): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
"""
filter_dict = self._create_filter(filters) if filters else None
query_params = {
"vector": vectors,
"top_k": limit,
"include_metadata": True,
"include_values": False,
}
if filter_dict:
query_params["filter"] = filter_dict
        if self.hybrid_search and self.sparse_encoder and filters and "text" in filters:
query_text = filters.get("text")
if query_text:
sparse_vector = self.sparse_encoder.encode_queries(query_text)
query_params["sparse_vector"] = sparse_vector
response = self.index.query(**query_params, namespace=self.namespace)
results = self._parse_output(response.matches)
return results
def delete(self, vector_id: Union[str, int]):
"""
Delete a vector by ID.
Args:
vector_id (Union[str, int]): ID of the vector to delete.
"""
self.index.delete(ids=[str(vector_id)], namespace=self.namespace)
def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
"""
Update a vector and its payload.
Args:
vector_id (Union[str, int]): ID of the vector to update.
vector (list, optional): Updated vector. Defaults to None.
payload (dict, optional): Updated payload. Defaults to None.
"""
item = {
"id": str(vector_id),
}
if vector is not None:
item["values"] = vector
if payload is not None:
item["metadata"] = payload
if self.hybrid_search and self.sparse_encoder and "text" in payload:
sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
item["sparse_values"] = sparse_vector
self.index.upsert(vectors=[item], namespace=self.namespace)
def get(self, vector_id: Union[str, int]) -> OutputData:
"""
Retrieve a vector by ID.
Args:
vector_id (Union[str, int]): ID of the vector to retrieve.
Returns:
dict: Retrieved vector or None if not found.
"""
try:
response = self.index.fetch(ids=[str(vector_id)], namespace=self.namespace)
if str(vector_id) in response.vectors:
return self._parse_output(response.vectors[str(vector_id)])
return None
except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {e}")
return None
def list_cols(self):
"""
List all indexes/collections.
Returns:
list: List of index information.
"""
return self.client.list_indexes()
def delete_col(self):
"""Delete an index/collection."""
try:
self.client.delete_index(self.collection_name)
logger.info(f"Index {self.collection_name} deleted successfully")
except Exception as e:
logger.error(f"Error deleting index {self.collection_name}: {e}")
def col_info(self) -> Dict:
"""
Get information about an index/collection.
Returns:
dict: Index information.
"""
return self.client.describe_index(self.collection_name)
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List vectors in an index with optional filtering.
Args:
filters (dict, optional): Filters to apply to the list. Defaults to None.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
dict: List of vectors with their metadata.
"""
filter_dict = self._create_filter(filters) if filters else None
stats = self.index.describe_index_stats()
dimension = stats.dimension
zero_vector = [0.0] * dimension
query_params = {
"vector": zero_vector,
"top_k": limit,
"include_metadata": True,
"include_values": True,
}
if filter_dict:
query_params["filter"] = filter_dict
try:
response = self.index.query(**query_params, namespace=self.namespace)
response = response.to_dict()
results = self._parse_output(response["matches"])
return [results]
except Exception as e:
logger.error(f"Error listing vectors: {e}")
return {"points": [], "next_page_token": None}
def count(self) -> int:
"""
Count number of vectors in the index.
Returns:
int: Total number of vectors.
"""
stats = self.index.describe_index_stats()
if self.namespace:
# Safely get the namespace stats and return vector_count, defaulting to 0 if not found
namespace_summary = (stats.namespaces or {}).get(self.namespace)
if namespace_summary:
return namespace_summary.vector_count or 0
return 0
return stats.total_vector_count or 0
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims, self.metric)
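
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow against Pinecone serverless. The API key is read from
# PINECONE_API_KEY when not passed explicitly; the index name, region and
# dimension below are placeholders.
if __name__ == "__main__":
    store = PineconeDB(
        collection_name="mem0-demo",
        embedding_model_dims=8,
        client=None,
        api_key=None,
        environment=None,
        serverless_config={"cloud": "aws", "region": "us-east-1"},
        pod_config=None,
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )
    store.insert(vectors=[[0.1] * 8], payloads=[{"user_id": "alice"}], ids=["v1"])
    print(store.count())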

View File

@@ -0,0 +1,270 @@
import logging
import os
import shutil
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
PointIdsList,
PointStruct,
Range,
VectorParams,
)
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class Qdrant(VectorStoreBase):
def __init__(
self,
collection_name: str,
embedding_model_dims: int,
client: QdrantClient = None,
host: str = None,
port: int = None,
path: str = None,
url: str = None,
api_key: str = None,
on_disk: bool = False,
):
"""
Initialize the Qdrant vector store.
Args:
collection_name (str): Name of the collection.
embedding_model_dims (int): Dimensions of the embedding model.
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
host (str, optional): Host address for Qdrant server. Defaults to None.
port (int, optional): Port for Qdrant server. Defaults to None.
path (str, optional): Path for local Qdrant database. Defaults to None.
url (str, optional): Full URL for Qdrant server. Defaults to None.
api_key (str, optional): API key for Qdrant server. Defaults to None.
on_disk (bool, optional): Enables persistent storage. Defaults to False.
"""
if client:
self.client = client
self.is_local = False
else:
params = {}
if api_key:
params["api_key"] = api_key
if url:
params["url"] = url
if host and port:
params["host"] = host
params["port"] = port
if not params:
params["path"] = path
self.is_local = True
if not on_disk:
if os.path.exists(path) and os.path.isdir(path):
shutil.rmtree(path)
else:
self.is_local = False
self.client = QdrantClient(**params)
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.on_disk = on_disk
self.create_col(embedding_model_dims, on_disk)
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
"""
Create a new collection.
Args:
vector_size (int): Size of the vectors to be stored.
on_disk (bool): Enables persistent storage.
distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
"""
# Skip creating collection if already exists
response = self.list_cols()
for collection in response.collections:
if collection.name == self.collection_name:
logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
self._create_filter_indexes()
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
)
self._create_filter_indexes()
def _create_filter_indexes(self):
"""Create indexes for commonly used filter fields to enable filtering."""
# Only create payload indexes for remote Qdrant servers
if self.is_local:
logger.debug("Skipping payload index creation for local Qdrant (not supported)")
return
common_fields = ["user_id", "agent_id", "run_id", "actor_id"]
for field in common_fields:
try:
self.client.create_payload_index(
collection_name=self.collection_name,
field_name=field,
field_schema="keyword"
)
logger.info(f"Created index for {field} in collection {self.collection_name}")
except Exception as e:
logger.debug(f"Index for {field} might already exist: {e}")
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors into a collection.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
points = [
PointStruct(
id=idx if ids is None else ids[idx],
vector=vector,
payload=payloads[idx] if payloads else {},
)
for idx, vector in enumerate(vectors)
]
self.client.upsert(collection_name=self.collection_name, points=points)
def _create_filter(self, filters: dict) -> Filter:
"""
Create a Filter object from the provided filters.
Args:
filters (dict): Filters to apply.
Returns:
Filter: The created Filter object.
"""
if not filters:
return None
conditions = []
for key, value in filters.items():
if isinstance(value, dict) and "gte" in value and "lte" in value:
conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
else:
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
return Filter(must=conditions) if conditions else None
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (list): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.query_points(
collection_name=self.collection_name,
query=vectors,
query_filter=query_filter,
limit=limit,
)
return hits.points
def delete(self, vector_id: int):
"""
Delete a vector by ID.
Args:
vector_id (int): ID of the vector to delete.
"""
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(
points=[vector_id],
),
)
def update(self, vector_id: int, vector: list = None, payload: dict = None):
"""
Update a vector and its payload.
Args:
vector_id (int): ID of the vector to update.
vector (list, optional): Updated vector. Defaults to None.
payload (dict, optional): Updated payload. Defaults to None.
"""
point = PointStruct(id=vector_id, vector=vector, payload=payload)
self.client.upsert(collection_name=self.collection_name, points=[point])
def get(self, vector_id: int) -> dict:
"""
Retrieve a vector by ID.
Args:
vector_id (int): ID of the vector to retrieve.
Returns:
dict: Retrieved vector.
"""
result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
return result[0] if result else None
def list_cols(self) -> list:
"""
List all collections.
Returns:
list: List of collection names.
"""
return self.client.get_collections()
def delete_col(self):
"""Delete a collection."""
self.client.delete_collection(collection_name=self.collection_name)
def col_info(self) -> dict:
"""
Get information about a collection.
Returns:
dict: Collection information.
"""
return self.client.get_collection(collection_name=self.collection_name)
def list(self, filters: dict = None, limit: int = 100) -> list:
"""
List all vectors in a collection.
Args:
filters (dict, optional): Filters to apply to the list. Defaults to None.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
list: List of vectors.
"""
query_filter = self._create_filter(filters) if filters else None
result = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=query_filter,
limit=limit,
with_payload=True,
with_vectors=False,
)
return result
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims, self.on_disk)
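
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow using a local, file-backed Qdrant instance (no server needed).
# The path and collection name are placeholders.
if __name__ == "__main__":
    store = Qdrant(
        collection_name="mem0_demo",
        embedding_model_dims=8,
        path="/tmp/qdrant_demo",
        on_disk=False,
    )
    store.insert(vectors=[[0.1] * 8], payloads=[{"user_id": "alice"}], ids=[1])
    for point in store.search(query="", vectors=[0.1] * 8, limit=3, filters={"user_id": "alice"}):
        print(point.id, point.score, point.payload)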

View File

@@ -0,0 +1,295 @@
import json
import logging
from datetime import datetime
from functools import reduce
import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
# TODO: Revisit these fields; they are not ideal from Redis's perspective and may be dropped.
DEFAULT_FIELDS = [
{"name": "memory_id", "type": "tag"},
{"name": "hash", "type": "tag"},
{"name": "agent_id", "type": "tag"},
{"name": "run_id", "type": "tag"},
{"name": "user_id", "type": "tag"},
{"name": "memory", "type": "text"},
{"name": "metadata", "type": "text"},
    # TODO: created_at/updated_at are numeric, though string values are also accepted.
{"name": "created_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{
"name": "embedding",
"type": "vector",
"attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
},
]
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
class MemoryResult:
def __init__(self, id: str, payload: dict, score: float = None):
self.id = id
self.payload = payload
self.score = score
class RedisDB(VectorStoreBase):
def __init__(
self,
redis_url: str,
collection_name: str,
embedding_model_dims: int,
):
"""
Initialize the Redis vector store.
Args:
redis_url (str): Redis URL.
collection_name (str): Collection name.
embedding_model_dims (int): Embedding model dimensions.
"""
self.embedding_model_dims = embedding_model_dims
index_schema = {
"name": collection_name,
"prefix": f"mem0:{collection_name}",
}
fields = DEFAULT_FIELDS.copy()
fields[-1]["attrs"]["dims"] = embedding_model_dims
self.schema = {"index": index_schema, "fields": fields}
self.client = redis.Redis.from_url(redis_url)
self.index = SearchIndex.from_dict(self.schema)
self.index.set_client(self.client)
self.index.create(overwrite=True)
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index) in Redis.
Args:
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
Returns:
The created index object.
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.schema["index"]["name"]
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "cosine"
# Create a new schema with the specified parameters
index_schema = {
"name": collection_name,
"prefix": f"mem0:{collection_name}",
}
# Copy the default fields and update the vector field with the specified dimensions
fields = DEFAULT_FIELDS.copy()
fields[-1]["attrs"]["dims"] = embedding_dims
fields[-1]["attrs"]["distance_metric"] = distance_metric
# Create the schema
schema = {"index": index_schema, "fields": fields}
# Create the index
index = SearchIndex.from_dict(schema)
index.set_client(self.client)
index.create(overwrite=True)
# Update instance attributes if creating a new collection
if name:
self.schema = schema
self.index = index
return index
def insert(self, vectors: list, payloads: list = None, ids: list = None):
data = []
for vector, payload, id in zip(vectors, payloads, ids):
# Start with required fields
entry = {
"memory_id": id,
"hash": payload["hash"],
"memory": payload["data"],
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Conditionally add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
entry[field] = payload[field]
# Add metadata excluding specific keys
entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
data.append(entry)
self.index.load(data, id_field="memory_id")
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
filter = reduce(lambda x, y: x & y, conditions)
v = VectorQuery(
vector=np.array(vectors, dtype=np.float32).tobytes(),
vector_field_name="embedding",
return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
filter_expression=filter,
num_results=limit,
)
results = self.index.query(v)
return [
MemoryResult(
id=result["memory_id"],
score=result["vector_distance"],
payload={
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds"),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if "updated_at" in result
else {}
),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
},
)
for result in results
]
def delete(self, vector_id):
self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")
def update(self, vector_id=None, vector=None, payload=None):
data = {
"memory_id": vector_id,
"hash": payload["hash"],
"memory": payload["data"],
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
data[field] = payload[field]
data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")
def get(self, vector_id):
result = self.index.fetch(vector_id)
payload = {
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
timespec="microseconds"
),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if "updated_at" in result
else {}
),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
}
return MemoryResult(id=result["memory_id"], payload=payload)
def list_cols(self):
return self.index.listall()
def delete_col(self):
self.index.delete()
def col_info(self, name):
return self.index.info()
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
collection_name = self.schema["index"]["name"]
logger.warning(f"Resetting index {collection_name}...")
self.delete_col()
        # Recreate the index with the same schema and parameters.
        self.create_col(collection_name, self.embedding_model_dims)
def list(self, filters: dict = None, limit: int = None) -> list:
"""
List all recent created memories from the vector store.
"""
conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
filter = reduce(lambda x, y: x & y, conditions)
query = Query(str(filter)).sort_by("created_at", asc=False)
if limit is not None:
query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)
results = self.index.search(query)
return [
[
MemoryResult(
id=result["memory_id"],
payload={
"hash": result["hash"],
"data": result["memory"],
"created_at": datetime.fromtimestamp(
int(result["created_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds"),
**(
{
"updated_at": datetime.fromtimestamp(
int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
).isoformat(timespec="microseconds")
}
if result.__dict__.get("updated_at")
else {}
),
**{
field: result[field]
for field in ["agent_id", "run_id", "user_id"]
if field in result.__dict__
},
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
},
)
for result in results.docs
]
]
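
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow against a Redis instance with the RediSearch module (e.g. Redis
# Stack). The URL and payload values are placeholders; note that insert()
# expects "hash", "data" and an ISO-8601 "created_at" in each payload.
if __name__ == "__main__":
    from datetime import datetime, timezone

    store = RedisDB(
        redis_url="redis://localhost:6379",
        collection_name="mem0_demo",
        embedding_model_dims=8,
    )
    now = datetime.now(timezone.utc).isoformat()
    store.insert(
        vectors=[[0.1] * 8],
        payloads=[{"hash": "abc123", "data": "hello world", "created_at": now, "user_id": "alice"}],
        ids=["v1"],
    )
    for r in store.search(query="hello", vectors=[0.1] * 8, limit=3, filters={"user_id": "alice"}):
        print(r.id, r.score, r.payload)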

View File

@@ -0,0 +1,176 @@
import json
import logging
from typing import Dict, List, Optional
from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[Dict]
class S3Vectors(VectorStoreBase):
def __init__(
self,
vector_bucket_name: str,
collection_name: str,
embedding_model_dims: int,
distance_metric: str = "cosine",
region_name: Optional[str] = None,
):
self.client = boto3.client("s3vectors", region_name=region_name)
self.vector_bucket_name = vector_bucket_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.distance_metric = distance_metric
self._ensure_bucket_exists()
self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
def _ensure_bucket_exists(self):
try:
self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
logger.info(f"Vector bucket '{self.vector_bucket_name}' already exists.")
except ClientError as e:
if e.response["Error"]["Code"] == "NotFoundException":
logger.info(f"Vector bucket '{self.vector_bucket_name}' not found. Creating it.")
self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
logger.info(f"Vector bucket '{self.vector_bucket_name}' created.")
else:
raise
def create_col(self, name, vector_size, distance="cosine"):
try:
self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
logger.info(f"Index '{name}' already exists in bucket '{self.vector_bucket_name}'.")
except ClientError as e:
if e.response["Error"]["Code"] == "NotFoundException":
logger.info(f"Index '{name}' not found in bucket '{self.vector_bucket_name}'. Creating it.")
self.client.create_index(
vectorBucketName=self.vector_bucket_name,
indexName=name,
dataType="float32",
dimension=vector_size,
distanceMetric=distance,
)
logger.info(f"Index '{name}' created.")
else:
raise
def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
results = []
for v in vectors:
payload = v.get("metadata", {})
# Boto3 might return metadata as a JSON string
if isinstance(payload, str):
try:
payload = json.loads(payload)
except json.JSONDecodeError:
logger.warning(f"Failed to parse metadata for key {v.get('key')}")
payload = {}
results.append(OutputData(id=v.get("key"), score=v.get("distance"), payload=payload))
return results
def insert(self, vectors, payloads=None, ids=None):
vectors_to_put = []
for i, vec in enumerate(vectors):
vectors_to_put.append(
{
"key": ids[i],
"data": {"float32": vec},
"metadata": payloads[i] if payloads else {},
}
)
self.client.put_vectors(
vectorBucketName=self.vector_bucket_name,
indexName=self.collection_name,
vectors=vectors_to_put,
)
def search(self, query, vectors, limit=5, filters=None):
params = {
"vectorBucketName": self.vector_bucket_name,
"indexName": self.collection_name,
"queryVector": {"float32": vectors},
"topK": limit,
"returnMetadata": True,
"returnDistance": True,
}
if filters:
params["filter"] = filters
response = self.client.query_vectors(**params)
return self._parse_output(response.get("vectors", []))
def delete(self, vector_id):
self.client.delete_vectors(
vectorBucketName=self.vector_bucket_name,
indexName=self.collection_name,
keys=[vector_id],
)
def update(self, vector_id, vector=None, payload=None):
# S3 Vectors uses put_vectors for updates (overwrite)
self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
def get(self, vector_id) -> Optional[OutputData]:
response = self.client.get_vectors(
vectorBucketName=self.vector_bucket_name,
indexName=self.collection_name,
keys=[vector_id],
returnData=False,
returnMetadata=True,
)
vectors = response.get("vectors", [])
if not vectors:
return None
return self._parse_output(vectors)[0]
def list_cols(self):
response = self.client.list_indexes(vectorBucketName=self.vector_bucket_name)
return [idx["indexName"] for idx in response.get("indexes", [])]
def delete_col(self):
self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
def col_info(self):
response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
return response.get("index", {})
def list(self, filters=None, limit=None):
# Note: list_vectors does not support metadata filtering.
if filters:
logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")
params = {
"vectorBucketName": self.vector_bucket_name,
"indexName": self.collection_name,
"returnData": False,
"returnMetadata": True,
}
if limit:
params["maxResults"] = limit
paginator = self.client.get_paginator("list_vectors")
pages = paginator.paginate(**params)
all_vectors = []
for page in pages:
all_vectors.extend(page.get("vectors", []))
return [self._parse_output(all_vectors)]
def reset(self):
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
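
# --- Usage sketch (illustrative, not part of the module) ---------------------
# Minimal flow against Amazon S3 Vectors using default AWS credentials. The
# bucket name, index name and region below are placeholders.
if __name__ == "__main__":
    store = S3Vectors(
        vector_bucket_name="my-demo-vector-bucket",
        collection_name="mem0-demo",
        embedding_model_dims=8,
        region_name="us-east-1",
    )
    store.insert(vectors=[[0.1] * 8], payloads=[{"user_id": "alice"}], ids=["v1"])
    hit = store.get("v1")
    if hit:
        print(hit.id, hit.payload)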

View File

@@ -0,0 +1,237 @@
import logging
import uuid
from typing import List, Optional
from pydantic import BaseModel
try:
import vecs
except ImportError:
raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class Supabase(VectorStoreBase):
def __init__(
self,
connection_string: str,
collection_name: str,
embedding_model_dims: int,
index_method: IndexMethod = IndexMethod.AUTO,
index_measure: IndexMeasure = IndexMeasure.COSINE,
):
"""
Initialize the Supabase vector store using vecs.
Args:
connection_string (str): PostgreSQL connection string
collection_name (str): Collection name
embedding_model_dims (int): Dimension of the embedding vector
index_method (IndexMethod): Index method to use. Defaults to AUTO.
index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
"""
self.db = vecs.create_client(connection_string)
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.index_method = index_method
self.index_measure = index_measure
collections = self.list_cols()
if collection_name not in collections:
self.create_col(embedding_model_dims)
def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
"""
Preprocess filters to be compatible with vecs.
Args:
filters (Dict, optional): Filters to preprocess. Multiple filters will be
combined with AND logic.
"""
if filters is None:
return None
if len(filters) == 1:
# For single filter, keep the simple format
key, value = next(iter(filters.items()))
return {key: {"$eq": value}}
# For multiple filters, use $and clause
return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}
def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
"""
Create a new collection with vector support.
Will also initialize vector search index.
Args:
embedding_model_dims (int, optional): Dimension of the embedding vector.
If not provided, uses the dimension specified in initialization.
"""
dims = embedding_model_dims or self.embedding_model_dims
if not dims:
raise ValueError(
"embedding_model_dims must be provided either during initialization or when creating collection"
)
logger.info(f"Creating new collection: {self.collection_name}")
try:
self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
except Exception as e:
logger.error(f"Failed to create collection: {str(e)}")
raise
def insert(
self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
):
"""
Insert vectors into the collection.
Args:
vectors (List[List[float]]): List of vectors to insert
payloads (List[Dict], optional): List of payloads corresponding to vectors
ids (List[str], optional): List of IDs corresponding to vectors
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
if not ids:
ids = [str(uuid.uuid4()) for _ in vectors]
if not payloads:
payloads = [{} for _ in vectors]
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
self.collection.upsert(records)
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results
"""
filters = self._preprocess_filters(filters)
results = self.collection.query(
data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
)
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
def delete(self, vector_id: str):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete
"""
self.collection.delete(ids=[vector_id])
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
"""
Update a vector and/or its payload.
Args:
vector_id (str): ID of the vector to update
vector (List[float], optional): Updated vector
payload (Dict, optional): Updated payload
"""
if vector is None:
# If only updating the payload, fetch the stored vector so the upsert keeps it intact.
existing = self.collection.fetch(ids=[vector_id])
if existing:
# vecs records are (id, vector, metadata) tuples.
vector = list(existing[0][1])
if vector:
self.collection.upsert([(vector_id, vector, payload or {})])
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve
Returns:
Optional[OutputData]: Retrieved vector data or None if not found
"""
result = self.collection.fetch(ids=[vector_id])
if not result:
return None
# vecs returns fetched records as (id, vector, metadata) tuples.
record = result[0]
return OutputData(id=str(record[0]), score=None, payload=record[2])
def list_cols(self) -> List[str]:
"""
List all collections.
Returns:
List[str]: List of collection names
"""
return self.db.list_collections()
def delete_col(self):
"""Delete the collection."""
self.db.delete_collection(self.collection_name)
def col_info(self) -> dict:
"""
Get information about the collection.
Returns:
Dict: Collection information including name and configuration
"""
info = self.collection.describe()
return {
"name": info.name,
"count": info.vectors,
"dimension": info.dimension,
"index": {"method": info.index_method, "metric": info.distance_metric},
}
def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
"""
List vectors in the collection.
Args:
filters (Dict, optional): Filters to apply
limit (int, optional): Maximum number of results to return. Defaults to 100.
Returns:
List[OutputData]: List of vectors
"""
filters = self._preprocess_filters(filters)
query = [0] * self.embedding_model_dims
ids = self.collection.query(
data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
)
ids = [id[0] for id in ids]
records = self.collection.fetch(ids=ids)
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims)
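# Usage sketch (illustrative only; the connection string, dimension, and texts below are
# placeholders, not values shipped with this repository):
#
#   store = Supabase(
#       connection_string="postgresql://user:password@localhost:5432/postgres",
#       collection_name="mem0_memories",
#       embedding_model_dims=1536,
#   )
#   store.insert(vectors=[[0.1] * 1536], payloads=[{"data": "hello"}], ids=["mem-1"])
#   hits = store.search(query="hello", vectors=[0.1] * 1536, limit=1)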

View File

@@ -0,0 +1,293 @@
import logging
from typing import Dict, List, Optional
from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
try:
from upstash_vector import Index
except ImportError:
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # is None for `get` method
payload: Optional[Dict] # metadata
class UpstashVector(VectorStoreBase):
def __init__(
self,
collection_name: str,
url: Optional[str] = None,
token: Optional[str] = None,
client: Optional[Index] = None,
enable_embeddings: bool = False,
):
"""
Initialize the UpstashVector vector store.
Args:
url (str, optional): URL for Upstash Vector index. Defaults to None.
token (int, optional): Token for Upstash Vector index. Defaults to None.
client (Index, optional): Existing `upstash_vector.Index` client instance. Defaults to None.
namespace (str, optional): Default namespace for the index. Defaults to None.
"""
if client:
self.client = client
elif url and token:
self.client = Index(url, token)
else:
raise ValueError("Either a client or URL and token must be provided.")
self.collection_name = collection_name
self.enable_embeddings = enable_embeddings
def insert(
self,
vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None,
):
"""
Insert vectors
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to vectors. These will be passed as metadatas to the Upstash Vector client. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into namespace {self.collection_name}")
if self.enable_embeddings:
if not payloads or any("data" not in m or m["data"] is None for m in payloads):
raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field.")
processed_vectors = [
{
"id": ids[i] if ids else None,
"data": payloads[i]["data"],
"metadata": payloads[i],
}
for i, v in enumerate(vectors)
]
else:
processed_vectors = [
{
"id": ids[i] if ids else None,
"vector": vectors[i],
"metadata": payloads[i] if payloads else None,
}
for i, v in enumerate(vectors)
]
self.client.upsert(
vectors=processed_vectors,
namespace=self.collection_name,
)
def _stringify(self, x):
return f'"{x}"' if isinstance(x, str) else x
def search(
self,
query: str,
vectors: List[list],
limit: int = 5,
filters: Optional[Dict] = None,
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (list): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search.
Returns:
List[OutputData]: Search results.
"""
filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
response = []
if self.enable_embeddings:
response = self.client.query(
data=query,
top_k=limit,
filter=filters_str or "",
include_metadata=True,
namespace=self.collection_name,
)
else:
queries = [
{
"vector": v,
"top_k": limit,
"filter": filters_str or "",
"include_metadata": True,
"namespace": self.collection_name,
}
for v in vectors
]
responses = self.client.query_many(queries=queries)
# flatten
response = [res for res_list in responses for res in res_list]
return [
OutputData(
id=res.id,
score=res.score,
payload=res.metadata,
)
for res in response
]
def delete(self, vector_id: int):
"""
Delete a vector by ID.
Args:
vector_id (int): ID of the vector to delete.
"""
self.client.delete(
ids=[str(vector_id)],
namespace=self.collection_name,
)
def update(
self,
vector_id: int,
vector: Optional[list] = None,
payload: Optional[dict] = None,
):
"""
Update a vector and its payload.
Args:
vector_id (int): ID of the vector to update.
vector (list, optional): Updated vector. Defaults to None.
payload (dict, optional): Updated payload. Defaults to None.
"""
self.client.update(
id=str(vector_id),
vector=vector,
data=payload.get("data") if payload else None,
metadata=payload,
namespace=self.collection_name,
)
def get(self, vector_id: int) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (int): ID of the vector to retrieve.
Returns:
Optional[OutputData]: Retrieved vector, or None if not found.
"""
response = self.client.fetch(
ids=[str(vector_id)],
namespace=self.collection_name,
include_metadata=True,
)
if len(response) == 0:
return None
vector = response[0]
if not vector:
return None
return OutputData(id=vector.id, score=None, payload=vector.metadata)
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
"""
List all memories.
Args:
filters (Dict, optional): Filters to apply to the search. Defaults to None.
limit (int, optional): Number of results to return. Defaults to 100.
Returns:
List[OutputData]: Search results.
"""
filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
info = self.client.info()
ns_info = info.namespaces.get(self.collection_name)
if not ns_info or ns_info.vector_count == 0:
return [[]]
random_vector = [1.0] * info.dimension
results, query = self.client.resumable_query(
vector=random_vector,
filter=filters_str or "",
include_metadata=True,
namespace=self.collection_name,
top_k=100,
)
with query:
while True:
if len(results) >= limit:
break
res = query.fetch_next(100)
if not res:
break
results.extend(res)
parsed_result = [
OutputData(
id=res.id,
score=res.score,
payload=res.metadata,
)
for res in results
]
return [parsed_result]
def create_col(self, name, vector_size, distance):
"""
Upstash Vector has namespaces instead of collections. A namespace is created when the first vector is inserted.
This method is a placeholder to maintain the interface.
"""
pass
def list_cols(self) -> List[str]:
"""
Lists all namespaces in the Upstash Vector index.
Returns:
List[str]: List of namespaces.
"""
return self.client.list_namespaces()
def delete_col(self):
"""
Delete the namespace and all vectors in it.
"""
self.client.reset(namespace=self.collection_name)
def col_info(self):
"""
Return general information about the Upstash Vector index.
- Total number of vectors across all namespaces
- Total number of vectors waiting to be indexed across all namespaces
- Total size of the index on disk in bytes
- Vector dimension
- Similarity function used
- Per-namespace vector and pending vector counts
"""
return self.client.info()
def reset(self):
"""
Reset the Upstash Vector index.
"""
self.delete_col()
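# Usage sketch (illustrative only; the URL, token, and vectors below are placeholders):
#
#   store = UpstashVector(
#       collection_name="mem0",
#       url="https://example-vector.upstash.io",
#       token="<UPSTASH_VECTOR_TOKEN>",
#   )
#   store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"data": "hello", "user_id": "alice"}], ids=["mem-1"])
#   hits = store.search(query="hello", vectors=[[0.1, 0.2, 0.3]], limit=1, filters={"user_id": "alice"})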

View File

@@ -0,0 +1,824 @@
import json
import logging
from datetime import datetime
from typing import Dict
import numpy as np
import pytz
import valkey
from pydantic import BaseModel
from valkey.exceptions import ResponseError
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
# Default fields for the Valkey index
DEFAULT_FIELDS = [
{"name": "memory_id", "type": "tag"},
{"name": "hash", "type": "tag"},
{"name": "agent_id", "type": "tag"},
{"name": "run_id", "type": "tag"},
{"name": "user_id", "type": "tag"},
{"name": "memory", "type": "tag"}, # Using TAG instead of TEXT for Valkey compatibility
{"name": "metadata", "type": "tag"}, # Using TAG instead of TEXT for Valkey compatibility
{"name": "created_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{
"name": "embedding",
"type": "vector",
"attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
},
]
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
class OutputData(BaseModel):
id: str
score: float
payload: Dict
class ValkeyDB(VectorStoreBase):
def __init__(
self,
valkey_url: str,
collection_name: str,
embedding_model_dims: int,
timezone: str = "UTC",
index_type: str = "hnsw",
hnsw_m: int = 16,
hnsw_ef_construction: int = 200,
hnsw_ef_runtime: int = 10,
):
"""
Initialize the Valkey vector store.
Args:
valkey_url (str): Valkey URL.
collection_name (str): Collection name.
embedding_model_dims (int): Embedding model dimensions.
timezone (str, optional): Timezone for timestamps. Defaults to "UTC".
index_type (str, optional): Index type ('hnsw' or 'flat'). Defaults to "hnsw".
hnsw_m (int, optional): HNSW M parameter (connections per node). Defaults to 16.
hnsw_ef_construction (int, optional): HNSW ef_construction parameter. Defaults to 200.
hnsw_ef_runtime (int, optional): HNSW ef_runtime parameter. Defaults to 10.
"""
self.embedding_model_dims = embedding_model_dims
self.collection_name = collection_name
self.prefix = f"mem0:{collection_name}"
self.timezone = timezone
self.index_type = index_type.lower()
self.hnsw_m = hnsw_m
self.hnsw_ef_construction = hnsw_ef_construction
self.hnsw_ef_runtime = hnsw_ef_runtime
# Validate index type
if self.index_type not in ["hnsw", "flat"]:
raise ValueError(f"Invalid index_type: {index_type}. Must be 'hnsw' or 'flat'")
# Connect to Valkey
try:
self.client = valkey.from_url(valkey_url)
logger.debug(f"Successfully connected to Valkey at {valkey_url}")
except Exception as e:
logger.exception(f"Failed to connect to Valkey at {valkey_url}: {e}")
raise
# Create the index schema
self._create_index(embedding_model_dims)
def _build_index_schema(self, collection_name, embedding_dims, distance_metric, prefix):
"""
Build the FT.CREATE command for index creation.
Args:
collection_name (str): Name of the collection/index
embedding_dims (int): Vector embedding dimensions
distance_metric (str): Distance metric (e.g., "COSINE", "L2", "IP")
prefix (str): Key prefix for the index
Returns:
list: Complete FT.CREATE command as list of arguments
"""
# Build the vector field configuration based on index type
if self.index_type == "hnsw":
vector_config = [
"embedding",
"VECTOR",
"HNSW",
"12", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric, M, m, EF_CONSTRUCTION, ef_construction, EF_RUNTIME, ef_runtime
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
"M",
str(self.hnsw_m),
"EF_CONSTRUCTION",
str(self.hnsw_ef_construction),
"EF_RUNTIME",
str(self.hnsw_ef_runtime),
]
elif self.index_type == "flat":
vector_config = [
"embedding",
"VECTOR",
"FLAT",
"6", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
]
else:
# This should never happen due to constructor validation, but be defensive
raise ValueError(f"Unsupported index_type: {self.index_type}. Must be 'hnsw' or 'flat'")
# Build the complete command (comma is default separator for TAG fields)
cmd = [
"FT.CREATE",
collection_name,
"ON",
"HASH",
"PREFIX",
"1",
prefix,
"SCHEMA",
"memory_id",
"TAG",
"hash",
"TAG",
"agent_id",
"TAG",
"run_id",
"TAG",
"user_id",
"TAG",
"memory",
"TAG",
"metadata",
"TAG",
"created_at",
"NUMERIC",
"updated_at",
"NUMERIC",
] + vector_config
return cmd
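# Example of the command this produces for an HNSW index (illustrative values: collection
# "mem0", 1536 dimensions, default HNSW parameters):
#   FT.CREATE mem0 ON HASH PREFIX 1 mem0:mem0 SCHEMA memory_id TAG hash TAG agent_id TAG
#   run_id TAG user_id TAG memory TAG metadata TAG created_at NUMERIC updated_at NUMERIC
#   embedding VECTOR HNSW 12 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
#   M 16 EF_CONSTRUCTION 200 EF_RUNTIME 10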
def _create_index(self, embedding_model_dims):
"""
Create the search index with the specified schema.
Args:
embedding_model_dims (int): Dimensions for the vector embeddings.
Raises:
ValueError: If the search module is not available.
Exception: For other errors during index creation.
"""
# Check if the search module is available
try:
# Try to execute a search command
self.client.execute_command("FT._LIST")
except ResponseError as e:
if "unknown command" in str(e).lower():
raise ValueError(
"Valkey search module is not available. Please ensure Valkey is running with the search module enabled. "
"The search module can be loaded using the --loadmodule option with the valkey-search library. "
"For installation and setup instructions, refer to the Valkey Search documentation."
)
else:
logger.exception(f"Error checking search module: {e}")
raise
# Check if the index already exists
try:
self.client.ft(self.collection_name).info()
return
except ResponseError as e:
if "not found" not in str(e).lower():
logger.exception(f"Error checking index existence: {e}")
raise
# Build and execute the index creation command
cmd = self._build_index_schema(
self.collection_name,
embedding_model_dims,
"COSINE", # Fixed distance metric for initialization
self.prefix,
)
try:
self.client.execute_command(*cmd)
logger.info(f"Successfully created {self.index_type.upper()} index {self.collection_name}")
except Exception as e:
logger.exception(f"Error creating index {self.collection_name}: {e}")
raise
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index) in Valkey.
Args:
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
Returns:
The created index object.
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.collection_name
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "COSINE"
prefix = f"mem0:{collection_name}"
# Try to drop the index if it exists (cleanup before creation)
self._drop_index(collection_name, log_level="silent")
# Build and execute the index creation command
cmd = self._build_index_schema(
collection_name,
embedding_dims,
distance_metric, # Configurable distance metric
prefix,
)
try:
self.client.execute_command(*cmd)
logger.info(f"Successfully created {self.index_type.upper()} index {collection_name}")
# Update instance attributes if creating a new collection
if name:
self.collection_name = collection_name
self.prefix = prefix
return self.client.ft(collection_name)
except Exception as e:
logger.exception(f"Error creating collection {collection_name}: {e}")
raise
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors and their payloads into the index.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to the vectors.
ids (list, optional): List of IDs for the vectors.
"""
for vector, payload, id in zip(vectors, payloads, ids):
try:
# Create the key for the hash
key = f"{self.prefix}:{id}"
# Check for required fields and provide defaults if missing
if "data" not in payload:
# Silently use default value for missing 'data' field
pass
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": id,
"hash": payload.get("hash", f"hash_{id}"), # Use a default hash if not provided
"memory": payload.get("data", f"data_{id}"), # Use a default data if not provided
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
# Store in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully inserted vector with ID {id}")
except KeyError as e:
logger.error(f"Error inserting vector with ID {id}: Missing required field {e}")
except Exception as e:
logger.exception(f"Error inserting vector with ID {id}: {e}")
raise
def _build_search_query(self, knn_part, filters=None):
"""
Build a search query string with filters.
Args:
knn_part (str): The KNN part of the query.
filters (dict, optional): Filters to apply to the search. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search. Multiple filters are
combined with AND logic (space-separated).
Returns:
str: The complete search query string in format "filter_expr =>[KNN...]"
or "*=>[KNN...]" if no valid filters.
"""
# No filters, just use the KNN search
if not filters or not any(value is not None for key, value in filters.items()):
return f"*=>{knn_part}"
# Build filter expression
filter_parts = []
for key, value in filters.items():
if value is not None:
# Use the correct filter syntax for Valkey
filter_parts.append(f"@{key}:{{{value}}}")
# No valid filter parts
if not filter_parts:
return f"*=>{knn_part}"
# Combine filter parts with proper syntax
filter_expr = " ".join(filter_parts)
return f"{filter_expr} =>{knn_part}"
def _execute_search(self, query, params):
"""
Execute a search query.
Args:
query (str): The search query to execute.
params (dict): The query parameters.
Returns:
The search results.
"""
try:
return self.client.ft(self.collection_name).search(query, query_params=params)
except ResponseError as e:
logger.error(f"Search failed with query '{query}': {e}")
raise
def _process_search_results(self, results):
"""
Process search results into OutputData objects.
Args:
results: The search results from Valkey.
Returns:
list: List of OutputData objects.
"""
memory_results = []
for doc in results.docs:
# Extract the score
score = float(doc.vector_score) if hasattr(doc, "vector_score") else None
# Create the payload
payload = {
"hash": doc.hash,
"data": doc.memory,
"created_at": self._format_timestamp(int(doc.created_at), self.timezone),
}
# Add updated_at if available
if hasattr(doc, "updated_at"):
payload["updated_at"] = self._format_timestamp(int(doc.updated_at), self.timezone)
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if hasattr(doc, field):
payload[field] = getattr(doc, field)
# Add metadata
if hasattr(doc, "metadata"):
try:
metadata = json.loads(extract_json(doc.metadata))
payload.update(metadata)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse metadata: {e}")
# Create the result
memory_results.append(OutputData(id=doc.memory_id, score=score, payload=payload))
return memory_results
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None, ef_runtime: int = None):
"""
Search for similar vectors in the index.
Args:
query (str): The search query.
vectors (list): The vector to search for.
limit (int, optional): Maximum number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
ef_runtime (int, optional): HNSW ef_runtime parameter for this query. Only used with HNSW index. Defaults to None.
Returns:
list: List of OutputData objects.
"""
# Convert the vector to bytes
vector_bytes = np.array(vectors, dtype=np.float32).tobytes()
# Build the KNN part with optional EF_RUNTIME for HNSW
if self.index_type == "hnsw" and ef_runtime is not None:
knn_part = f"[KNN {limit} @embedding $vec_param EF_RUNTIME {ef_runtime} AS vector_score]"
else:
# For FLAT indexes or when ef_runtime is None, use basic KNN
knn_part = f"[KNN {limit} @embedding $vec_param AS vector_score]"
# Build the complete query
q = self._build_search_query(knn_part, filters)
# Log the query for debugging (only in debug mode)
logger.debug(f"Valkey search query: {q}")
# Set up the query parameters
params = {"vec_param": vector_bytes}
# Execute the search
results = self._execute_search(q, params)
# Process the results
return self._process_search_results(results)
def delete(self, vector_id):
"""
Delete a vector from the index.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
key = f"{self.prefix}:{vector_id}"
self.client.delete(key)
logger.debug(f"Successfully deleted vector with ID {vector_id}")
except Exception as e:
logger.exception(f"Error deleting vector with ID {vector_id}: {e}")
raise
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector in the index.
Args:
vector_id (str): ID of the vector to update.
vector (list, optional): New vector data.
payload (dict, optional): New payload data.
"""
try:
key = f"{self.prefix}:{vector_id}"
# Check for required fields and provide defaults if missing
if "data" not in payload:
# Silently use default value for missing 'data' field
pass
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": vector_id,
"hash": payload.get("hash", f"hash_{vector_id}"), # Use a default hash if not provided
"memory": payload.get("data", f"data_{vector_id}"), # Use a default data if not provided
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add updated_at if available
if "updated_at" in payload:
hash_data["updated_at"] = int(datetime.fromisoformat(payload["updated_at"]).timestamp())
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
# Update in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully updated vector with ID {vector_id}")
except KeyError as e:
logger.error(f"Error updating vector with ID {vector_id}: Missing required field {e}")
except Exception as e:
logger.exception(f"Error updating vector with ID {vector_id}: {e}")
raise
def _format_timestamp(self, timestamp, timezone=None):
"""
Format a timestamp with the specified timezone.
Args:
timestamp (int): The timestamp to format.
timezone (str, optional): The timezone to use. Defaults to UTC.
Returns:
str: The formatted timestamp.
"""
# Use UTC as default timezone if not specified
tz = pytz.timezone(timezone or "UTC")
return datetime.fromtimestamp(timestamp, tz=tz).isoformat(timespec="microseconds")
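# e.g. _format_timestamp(1700000000, "UTC") -> "2023-11-14T22:13:20.000000+00:00"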
def _process_document_fields(self, result, vector_id):
"""
Process document fields from a Valkey hash result.
Args:
result (dict): The hash result from Valkey.
vector_id (str): The vector ID.
Returns:
dict: The processed payload.
str: The memory ID.
"""
# Create the payload with error handling
payload = {}
# Convert bytes to string for text fields
for k in result:
if k not in ["embedding"]:
if isinstance(result[k], bytes):
try:
result[k] = result[k].decode("utf-8")
except UnicodeDecodeError:
# If decoding fails, keep the bytes
pass
# Add required fields with error handling
for field in ["hash", "memory", "created_at"]:
if field in result:
if field == "created_at":
try:
payload[field] = self._format_timestamp(int(result[field]), self.timezone)
except (ValueError, TypeError):
payload[field] = result[field]
else:
payload[field] = result[field]
else:
# Use default values for missing fields
if field == "hash":
payload[field] = "unknown"
elif field == "memory":
payload[field] = "unknown"
elif field == "created_at":
payload[field] = self._format_timestamp(
int(datetime.now(tz=pytz.timezone(self.timezone)).timestamp()), self.timezone
)
# Rename memory to data for consistency
if "memory" in payload:
payload["data"] = payload.pop("memory")
# Add updated_at if available
if "updated_at" in result:
try:
payload["updated_at"] = self._format_timestamp(int(result["updated_at"]), self.timezone)
except (ValueError, TypeError):
payload["updated_at"] = result["updated_at"]
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in result:
payload[field] = result[field]
# Add metadata
if "metadata" in result:
try:
metadata = json.loads(extract_json(result["metadata"]))
payload.update(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse metadata: {result.get('metadata')}")
# Use memory_id from result if available, otherwise use vector_id
memory_id = result.get("memory_id", vector_id)
return payload, memory_id
def _convert_bytes(self, data):
"""Convert bytes data back to string"""
if isinstance(data, bytes):
try:
return data.decode("utf-8")
except UnicodeDecodeError:
return data
if isinstance(data, dict):
return {self._convert_bytes(key): self._convert_bytes(value) for key, value in data.items()}
if isinstance(data, list):
return [self._convert_bytes(item) for item in data]
if isinstance(data, tuple):
return tuple(self._convert_bytes(item) for item in data)
return data
def get(self, vector_id):
"""
Get a vector by ID.
Args:
vector_id (str): ID of the vector to get.
Returns:
OutputData: The retrieved vector.
"""
try:
key = f"{self.prefix}:{vector_id}"
result = self.client.hgetall(key)
if not result:
raise KeyError(f"Vector with ID {vector_id} not found")
# Convert bytes keys/values to strings
result = self._convert_bytes(result)
logger.debug(f"Retrieved result keys: {result.keys()}")
# Process the document fields
payload, memory_id = self._process_document_fields(result, vector_id)
return OutputData(id=memory_id, payload=payload, score=0.0)
except KeyError:
raise
except Exception as e:
logger.exception(f"Error getting vector with ID {vector_id}: {e}")
raise
def list_cols(self):
"""
List all collections (indices) in Valkey.
Returns:
list: List of collection names.
"""
try:
# Use the FT._LIST command to list all indices
return self.client.execute_command("FT._LIST")
except Exception as e:
logger.exception(f"Error listing collections: {e}")
raise
def _drop_index(self, collection_name, log_level="error"):
"""
Drop an index by name using the documented FT.DROPINDEX command.
Args:
collection_name (str): Name of the index to drop.
log_level (str): Logging level for missing index ("silent", "info", "error").
"""
try:
self.client.execute_command("FT.DROPINDEX", collection_name)
logger.info(f"Successfully deleted index {collection_name}")
return True
except ResponseError as e:
if "Unknown index name" in str(e):
# Index doesn't exist - handle based on context
if log_level == "silent":
pass # No logging in situations where this is expected such as initial index creation
elif log_level == "info":
logger.info(f"Index {collection_name} doesn't exist, skipping deletion")
return False
else:
# Real error - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
except Exception as e:
# Non-ResponseError exceptions - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
def delete_col(self):
"""
Delete the current collection (index).
"""
return self._drop_index(self.collection_name, log_level="info")
def col_info(self, name=None):
"""
Get information about a collection (index).
Args:
name (str, optional): Name of the collection. Defaults to None, which uses the current collection_name.
Returns:
dict: Information about the collection.
"""
try:
collection_name = name or self.collection_name
return self.client.ft(collection_name).info()
except Exception as e:
logger.exception(f"Error getting collection info for {collection_name}: {e}")
raise
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
try:
collection_name = self.collection_name
logger.warning(f"Resetting index {collection_name}...")
# Delete the index
self.delete_col()
# Recreate the index
self._create_index(self.embedding_model_dims)
return True
except Exception as e:
logger.exception(f"Error resetting index {self.collection_name}: {e}")
raise
def _build_list_query(self, filters=None):
"""
Build a query for listing vectors.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search.
Returns:
str: The query string. Returns "*" if no valid filters provided.
"""
# Default query
q = "*"
# Add filters if provided
if filters and any(value is not None for key, value in filters.items()):
filter_conditions = []
for key, value in filters.items():
if value is not None:
filter_conditions.append(f"@{key}:{{{value}}}")
if filter_conditions:
q = " ".join(filter_conditions)
return q
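# e.g. filters={"user_id": "alice", "agent_id": None} -> "@user_id:{alice}"; no filters -> "*"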
def list(self, filters: dict = None, limit: int = None) -> list:
"""
List all recent created memories from the vector store.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is without validation - wildcards, special characters,
lists, etc. are passed through literally to Valkey search.
Multiple filters are combined with AND logic.
limit (int, optional): Maximum number of results to return. Defaults to 1000
if not specified.
Returns:
list: Nested list format [[MemoryResult(), ...]] matching Redis implementation.
Each MemoryResult contains id and payload with hash, data, timestamps, etc.
"""
try:
# Since Valkey search requires vector format, use a dummy vector search
# that returns all documents by using a zero vector and large K
dummy_vector = [0.0] * self.embedding_model_dims
search_limit = limit if limit is not None else 1000 # Large default
# Use the existing search method which handles filters properly
search_results = self.search("", dummy_vector, limit=search_limit, filters=filters)
# Convert search results to list format (match Redis format)
class MemoryResult:
def __init__(self, id: str, payload: dict, score: float = None):
self.id = id
self.payload = payload
self.score = score
memory_results = []
for result in search_results:
# Create payload in the expected format
payload = {
"hash": result.payload.get("hash", ""),
"data": result.payload.get("data", ""),
"created_at": result.payload.get("created_at"),
"updated_at": result.payload.get("updated_at"),
}
# Add metadata (exclude system fields)
for key, value in result.payload.items():
if key not in ["data", "hash", "created_at", "updated_at"]:
payload[key] = value
# Create MemoryResult object (matching Redis format)
memory_results.append(MemoryResult(id=result.id, payload=payload))
# Return nested list format like Redis
return [memory_results]
except Exception as e:
logger.exception(f"Error in list method: {e}")
return [[]] # Return empty result on error
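# Usage sketch (illustrative only; the URL and dimension below are placeholders and assume a
# Valkey server running with the valkey-search module loaded):
#
#   store = ValkeyDB(
#       valkey_url="valkey://localhost:6379",
#       collection_name="mem0",
#       embedding_model_dims=1536,
#   )
#   store.insert(vectors=[[0.1] * 1536], payloads=[{"data": "hello", "user_id": "alice"}], ids=["mem-1"])
#   hits = store.search(query="hello", vectors=[0.1] * 1536, limit=1, filters={"user_id": "alice"})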

View File

@@ -0,0 +1,629 @@
import logging
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple
import google.api_core.exceptions
from google.cloud import aiplatform, aiplatform_v1
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
)
from google.oauth2 import service_account
from langchain.schema import Document
from pydantic import BaseModel
from mem0.configs.vector_stores.vertex_ai_vector_search import (
GoogleMatchingEngineConfig,
)
from mem0.vector_stores.base import VectorStoreBase
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class GoogleMatchingEngine(VectorStoreBase):
def __init__(self, **kwargs):
"""Initialize Google Matching Engine client."""
logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)
# If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
kwargs["deployment_index_id"] = kwargs["collection_name"]
logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
kwargs["collection_name"] = kwargs["deployment_index_id"]
logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])
try:
config = GoogleMatchingEngineConfig(**kwargs)
logger.debug("Config created: %s", config.model_dump())
logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
except Exception as e:
logger.error("Failed to validate config: %s", str(e))
raise
self.project_id = config.project_id
self.project_number = config.project_number
self.region = config.region
self.endpoint_id = config.endpoint_id
self.index_id = config.index_id # The actual index ID
self.deployment_index_id = config.deployment_index_id # The deployment-specific ID
self.collection_name = config.collection_name
self.vector_search_api_endpoint = config.vector_search_api_endpoint
logger.debug("Using project=%s, location=%s", self.project_id, self.region)
# Initialize Vertex AI with credentials if provided
init_args = {
"project": self.project_id,
"location": self.region,
}
if hasattr(config, "credentials_path") and config.credentials_path:
logger.debug("Using credentials from: %s", config.credentials_path)
credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
init_args["credentials"] = credentials
try:
aiplatform.init(**init_args)
logger.debug("Vertex AI initialized successfully")
except Exception as e:
logger.error("Failed to initialize Vertex AI: %s", str(e))
raise
try:
# Format the index path properly using the configured index_id
index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
logger.debug("Initializing index with path: %s", index_path)
self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
logger.debug("Index initialized successfully")
# Format the endpoint name properly
endpoint_name = self.endpoint_id
logger.debug("Initializing endpoint with name: %s", endpoint_name)
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
logger.debug("Endpoint initialized successfully")
except Exception as e:
logger.error("Failed to initialize Matching Engine components: %s", str(e))
raise ValueError(f"Invalid configuration: {str(e)}")
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data.
Args:
data (Dict): Output data.
Returns:
List[OutputData]: Parsed output data.
"""
results = data.get("nearestNeighbors", {}).get("neighbors", [])
output_data = []
for result in results:
output_data.append(
OutputData(
id=result.get("datapoint").get("datapointId"),
score=result.get("distance"),
payload=result.get("datapoint").get("metadata"),
)
)
return output_data
def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
"""Create a restriction object for the Matching Engine index.
Args:
key: The namespace/key for the restriction
value: The value to restrict on
Returns:
Restriction object for the index
"""
str_value = str(value) if value is not None else ""
return aiplatform_v1.types.index.IndexDatapoint.Restriction(namespace=key, allow_list=[str_value])
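# e.g. _create_restriction("user_id", "alice")
#   -> IndexDatapoint.Restriction(namespace="user_id", allow_list=["alice"])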
def _create_datapoint(
self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
) -> aiplatform_v1.types.index.IndexDatapoint:
"""Create a datapoint object for the Matching Engine index.
Args:
vector_id: The ID for the datapoint
vector: The vector to store
payload: Optional metadata to store with the vector
Returns:
IndexDatapoint object
"""
restrictions = []
if payload:
restrictions = [self._create_restriction(key, value) for key, value in payload.items()]
return aiplatform_v1.types.index.IndexDatapoint(
datapoint_id=vector_id, feature_vector=vector, restricts=restrictions
)
def insert(
self,
vectors: List[list],
payloads: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None,
) -> None:
"""Insert vectors into the Matching Engine index.
Args:
vectors: List of vectors to insert
payloads: Optional list of metadata dictionaries
ids: Optional list of IDs for the vectors
Raises:
ValueError: If vectors is empty or lengths don't match
GoogleAPIError: If the API call fails
"""
if not vectors:
raise ValueError("No vectors provided for insertion")
if payloads and len(payloads) != len(vectors):
raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")
if ids and len(ids) != len(vectors):
raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")
logger.debug("Starting insert of %d vectors", len(vectors))
try:
datapoints = [
self._create_datapoint(
vector_id=ids[i] if ids else str(uuid.uuid4()),
vector=vector,
payload=payloads[i] if payloads and i < len(payloads) else None,
)
for i, vector in enumerate(vectors)
]
logger.debug("Created %d datapoints", len(datapoints))
self.index.upsert_datapoints(datapoints=datapoints)
logger.debug("Successfully inserted datapoints")
except google.api_core.exceptions.GoogleAPIError as e:
logger.error("Failed to insert vectors: %s", str(e))
raise
except Exception as e:
logger.error("Unexpected error during insert: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results (unwrapped)
"""
logger.debug("Starting search")
logger.debug("Limit: %d, Filters: %s", limit, filters)
try:
filter_namespaces = []
if filters:
logger.debug("Processing filters")
for key, value in filters.items():
logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
if isinstance(value, (str, int, float)):
logger.debug("Adding simple filter for %s", key)
filter_namespaces.append(Namespace(key, [str(value)], []))
elif isinstance(value, dict):
logger.debug("Adding complex filter for %s", key)
includes = value.get("include", [])
excludes = value.get("exclude", [])
filter_namespaces.append(Namespace(key, includes, excludes))
logger.debug("Final filter_namespaces: %s", filter_namespaces)
response = self.index_endpoint.find_neighbors(
deployed_index_id=self.deployment_index_id,
queries=[vectors],
num_neighbors=limit,
filter=filter_namespaces if filter_namespaces else None,
return_full_datapoint=True,
)
if not response or len(response) == 0 or len(response[0]) == 0:
logger.debug("No results found")
return []
results = []
for neighbor in response[0]:
logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)
payload = {}
if hasattr(neighbor, "restricts"):
logger.debug("Processing restricts")
for restrict in neighbor.restricts:
if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
payload[restrict.name] = restrict.allow_tokens[0]
output_data = OutputData(id=neighbor.id, score=neighbor.distance, payload=payload)
results.append(output_data)
logger.debug("Returning %d results", len(results))
return results
except Exception as e:
logger.error("Error occurred: %s", str(e))
logger.error("Error type: %s", type(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
"""
Delete vectors from the Matching Engine index.
Args:
vector_id (Optional[str]): Single ID to delete (for backward compatibility)
ids (Optional[List[str]]): List of IDs of vectors to delete
Returns:
bool: True if vectors were deleted successfully or already deleted, False if error
"""
logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
try:
# Handle both single vector_id and list of ids
if vector_id:
datapoint_ids = [vector_id]
elif ids:
datapoint_ids = ids
else:
raise ValueError("Either vector_id or ids must be provided")
logger.debug("Deleting ids: %s", datapoint_ids)
try:
self.index.remove_datapoints(datapoint_ids=datapoint_ids)
logger.debug("Delete completed successfully")
return True
except google.api_core.exceptions.NotFound:
# If the datapoint is already deleted, consider it a success
logger.debug("Datapoint already deleted")
return True
except google.api_core.exceptions.PermissionDenied as e:
logger.error("Permission denied: %s", str(e))
return False
except google.api_core.exceptions.InvalidArgument as e:
logger.error("Invalid argument: %s", str(e))
return False
except Exception as e:
logger.error("Error occurred: %s", str(e))
logger.error("Error type: %s", type(e))
logger.error("Stack trace: %s", traceback.format_exc())
return False
def update(
self,
vector_id: str,
vector: Optional[List[float]] = None,
payload: Optional[Dict] = None,
) -> bool:
"""Update a vector and its payload.
Args:
vector_id: ID of the vector to update
vector: Optional new vector values
payload: Optional new metadata payload
Returns:
bool: True if update was successful
Raises:
ValueError: If neither vector nor payload is provided
GoogleAPIError: If the API call fails
"""
logger.debug("Starting update for vector_id: %s", vector_id)
if vector is None and payload is None:
raise ValueError("Either vector or payload must be provided for update")
# First check if the vector exists
try:
existing = self.get(vector_id)
if existing is None:
logger.error("Vector ID not found: %s", vector_id)
return False
datapoint = self._create_datapoint(
vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
)
logger.debug("Upserting datapoint: %s", datapoint)
self.index.upsert_datapoints(datapoints=[datapoint])
logger.debug("Update completed successfully")
return True
except google.api_core.exceptions.GoogleAPIError as e:
logger.error("API error during update: %s", str(e))
return False
except Exception as e:
logger.error("Unexpected error during update: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
Returns:
Optional[OutputData]: Retrieved vector or None if not found.
"""
logger.debug("Starting get for vector_id: %s", vector_id)
try:
if not self.vector_search_api_endpoint:
raise ValueError("vector_search_api_endpoint is required for get operation")
vector_search_client = aiplatform_v1.MatchServiceClient(
client_options={"api_endpoint": self.vector_search_api_endpoint},
)
datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)
query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
request = aiplatform_v1.FindNeighborsRequest(
index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
deployed_index_id=self.deployment_index_id,
queries=[query],
return_full_datapoint=True,
)
try:
response = vector_search_client.find_neighbors(request)
logger.debug("Got response")
if response and response.nearest_neighbors:
nearest = response.nearest_neighbors[0]
if nearest.neighbors:
neighbor = nearest.neighbors[0]
payload = {}
if hasattr(neighbor.datapoint, "restricts"):
for restrict in neighbor.datapoint.restricts:
if restrict.allow_list:
payload[restrict.namespace] = restrict.allow_list[0]
return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)
logger.debug("No results found")
return None
except google.api_core.exceptions.NotFound:
logger.debug("Datapoint not found")
return None
except google.api_core.exceptions.PermissionDenied as e:
logger.error("Permission denied: %s", str(e))
return None
except Exception as e:
logger.error("Error occurred: %s", str(e))
logger.error("Error type: %s", type(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def list_cols(self) -> List[str]:
"""
List all collections (indexes).
Returns:
List[str]: List of collection names.
"""
return [self.deployment_index_id]
def delete_col(self):
"""
Delete a collection (index).
Note: This operation is not supported through the API.
"""
logger.warning("Delete collection operation is not supported for Google Matching Engine")
pass
def col_info(self) -> Dict:
"""
Get information about a collection (index).
Returns:
Dict: Collection information.
"""
return {
"index_id": self.index_id,
"endpoint_id": self.endpoint_id,
"project_id": self.project_id,
"region": self.region,
}
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
"""List vectors matching the given filters.
Args:
filters: Optional filters to apply
limit: Optional maximum number of results to return
Returns:
List[List[OutputData]]: List of matching vectors wrapped in an extra array
to match the interface
"""
logger.debug("Starting list operation")
logger.debug("Filters: %s", filters)
logger.debug("Limit: %s", limit)
try:
# Use a zero vector for the search
dimension = 768 # This should be configurable based on the model
zero_vector = [0.0] * dimension
# Use a large limit if none specified
search_limit = limit if limit is not None else 10000
results = self.search(query="", vectors=zero_vector, limit=search_limit, filters=filters)
logger.debug("Found %d results", len(results))
return [results] # Wrap in extra array to match interface
except Exception as e:
logger.error("Error in list operation: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection. For Google Matching Engine, collections (indexes)
are created through the Google Cloud Console or API separately.
This method is a no-op since indexes are pre-created.
Args:
name: Ignored for Google Matching Engine
vector_size: Ignored for Google Matching Engine
distance: Ignored for Google Matching Engine
"""
# Google Matching Engine indexes are created through Google Cloud Console
# This method is included only to satisfy the abstract base class
pass
def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
logger.debug("Starting add operation")
logger.debug("Text: %s", text)
logger.debug("Metadata: %s", metadata)
logger.debug("User ID: %s", user_id)
try:
# Generate a unique ID for this entry
vector_id = str(uuid.uuid4())
# Create the payload with all necessary fields
payload = {
"data": text, # Store the text in the data field
"user_id": user_id,
**(metadata or {}),
}
# Get the embedding
vector = self.embedder.embed_query(text)
# Insert using the insert method
self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])
return vector_id
except Exception as e:
logger.error("Error occurred: %s", str(e))
raise
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""Add texts to the vector store.
Args:
texts: List of texts to add
metadatas: Optional list of metadata dicts
ids: Optional list of IDs to use
Returns:
List[str]: List of IDs of the added texts
Raises:
ValueError: If texts is empty or lengths don't match
"""
if not texts:
raise ValueError("No texts provided")
if metadatas and len(metadatas) != len(texts):
raise ValueError(
f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
)
if ids and len(ids) != len(texts):
raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")
logger.debug("Starting add_texts operation")
logger.debug("Number of texts: %d", len(texts))
logger.debug("Has metadatas: %s", metadatas is not None)
logger.debug("Has ids: %s", ids is not None)
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
try:
# Get embeddings
embeddings = self.embedder.embed_documents(texts)
# Add to store
self.insert(vectors=embeddings, payloads=metadatas if metadatas else [{}] * len(texts), ids=ids)
return ids
except Exception as e:
logger.error("Error in add_texts: %s", str(e))
logger.error("Stack trace: %s", traceback.format_exc())
raise
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Any,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> "GoogleMatchingEngine":
"""Create an instance from texts."""
logger.debug("Creating instance from texts")
store = cls(**kwargs)
store.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return store
def similarity_search_with_score(
self,
query: str,
k: int = 5,
filter: Optional[Dict] = None,
) -> List[Tuple[Document, float]]:
"""Return documents most similar to query with scores."""
logger.debug("Starting similarity search with score")
logger.debug("Query: %s", query)
logger.debug("k: %d", k)
logger.debug("Filter: %s", filter)
embedding = self.embedder.embed_query(query)
results = self.search(query=query, vectors=embedding, limit=k, filters=filter)
docs_and_scores = [
(Document(page_content=result.payload.get("text", ""), metadata=result.payload), result.score)
for result in results
]
logger.debug("Found %d results", len(docs_and_scores))
return docs_and_scores
def similarity_search(
self,
query: str,
k: int = 5,
filter: Optional[Dict] = None,
) -> List[Document]:
"""Return documents most similar to query."""
logger.debug("Starting similarity search")
docs_and_scores = self.similarity_search_with_score(query, k, filter)
return [doc for doc, _ in docs_and_scores]
def reset(self):
"""
Reset the Google Matching Engine index.
"""
logger.warning("Reset operation is not supported for Google Matching Engine")
pass
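# Usage sketch (illustrative only; every identifier below is a placeholder for values from your
# own Google Cloud project, and the index/endpoint must already exist and be deployed):
#
#   store = GoogleMatchingEngine(
#       project_id="my-project",
#       project_number="123456789",
#       region="us-central1",
#       endpoint_id="1234567890",
#       index_id="9876543210",
#       deployment_index_id="mem0_deployment",
#       vector_search_api_endpoint="1234.us-central1-123456789.vdb.vertexai.goog",
#   )
#   store.insert(vectors=[[0.1] * 768], payloads=[{"data": "hello", "user_id": "alice"}], ids=["mem-1"])
#   hits = store.search(query="hello", vectors=[0.1] * 768, limit=1, filters={"user_id": "alice"})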

View File

@@ -0,0 +1,343 @@
import logging
import uuid
from typing import Dict, List, Mapping, Optional
from urllib.parse import urlparse
from pydantic import BaseModel
try:
import weaviate
except ImportError:
raise ImportError(
"The 'weaviate' library is required. Please install it using 'pip install weaviate-client weaviate'."
)
import weaviate.classes.config as wvcc
from weaviate.classes.init import AdditionalConfig, Auth, Timeout
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.util import get_valid_uuid
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: str
score: float
payload: Dict
class Weaviate(VectorStoreBase):
def __init__(
self,
collection_name: str,
embedding_model_dims: int,
cluster_url: str = None,
auth_client_secret: str = None,
additional_headers: dict = None,
):
"""
Initialize the Weaviate vector store.
Args:
collection_name (str): Name of the collection/class in Weaviate.
embedding_model_dims (int): Dimensions of the embedding model.
client (WeaviateClient, optional): Existing Weaviate client instance. Defaults to None.
cluster_url (str, optional): URL for Weaviate server. Defaults to None.
auth_config (dict, optional): Authentication configuration for Weaviate. Defaults to None.
additional_headers (dict, optional): Additional headers for requests. Defaults to None.
"""
if "localhost" in cluster_url:
self.client = weaviate.connect_to_local(headers=additional_headers)
elif auth_client_secret:
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url=cluster_url,
auth_credentials=Auth.api_key(auth_client_secret),
headers=additional_headers,
)
else:
parsed = urlparse(cluster_url) # e.g., http://mem0_store:8080
http_host = parsed.hostname or "localhost"
http_port = parsed.port or (443 if parsed.scheme == "https" else 8080)
http_secure = parsed.scheme == "https"
# Weaviate gRPC defaults (inside Docker network)
grpc_host = http_host
grpc_port = 50051
grpc_secure = False
self.client = weaviate.connect_to_custom(
http_host,
http_port,
http_secure,
grpc_host,
grpc_port,
grpc_secure,
headers=additional_headers,
skip_init_checks=True,
additional_config=AdditionalConfig(timeout=Timeout(init=2.0)),
)
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.create_col(embedding_model_dims)
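# Connection selection above (illustrative examples): a cluster_url containing "localhost"
# (e.g. "http://localhost:8080") uses connect_to_local; a cloud URL plus auth_client_secret
# uses connect_to_weaviate_cloud; anything else (e.g. "http://mem0_store:8080") is parsed and
# opened via connect_to_custom with gRPC assumed on port 50051.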
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data.
Args:
data (Dict): Output data.
Returns:
List[OutputData]: Parsed output data.
"""
keys = ["ids", "distances", "metadatas"]
values = []
for key in keys:
value = data.get(key, [])
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
values.append(value)
ids, distances, metadatas = values
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
)
result.append(entry)
return result
def create_col(self, vector_size, distance="cosine"):
"""
Create a new collection with the specified schema.
Args:
vector_size (int): Size of the vectors to be stored.
distance (str, optional): Distance metric for vector similarity. Defaults to "cosine".
"""
if self.client.collections.exists(self.collection_name):
logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
return
properties = [
wvcc.Property(name="ids", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="hash", data_type=wvcc.DataType.TEXT),
wvcc.Property(
name="metadata",
data_type=wvcc.DataType.TEXT,
description="Additional metadata",
),
wvcc.Property(name="data", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="created_at", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="category", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="updated_at", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="user_id", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="agent_id", data_type=wvcc.DataType.TEXT),
wvcc.Property(name="run_id", data_type=wvcc.DataType.TEXT),
]
        vectorizer_config = wvcc.Configure.Vectorizer.none()  # vectors are supplied by mem0's embedder
        vector_index_config = wvcc.Configure.VectorIndex.hnsw()  # HNSW with Weaviate's default (cosine) distance
self.client.collections.create(
self.collection_name,
vectorizer_config=vectorizer_config,
vector_index_config=vector_index_config,
properties=properties,
)
def insert(self, vectors, payloads=None, ids=None):
"""
Insert vectors into a collection.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
with self.client.batch.fixed_size(batch_size=100) as batch:
for idx, vector in enumerate(vectors):
object_id = ids[idx] if ids and idx < len(ids) else str(uuid.uuid4())
object_id = get_valid_uuid(object_id)
data_object = payloads[idx] if payloads and idx < len(payloads) else {}
                # The object ID is passed separately as `uuid`, so drop any duplicate
                # "ids" key from the stored properties.
                if "ids" in data_object:
                    del data_object["ids"]
batch.add_object(collection=self.collection_name, properties=data_object, uuid=object_id, vector=vector)
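        # Illustrative payload (values are placeholders): an entry such as
        #   {"hash": "9a8b...", "data": "User prefers green tea",
        #    "created_at": "2025-01-01T00:00:00Z", "user_id": "alice"}
        # is stored as object properties matching the schema declared in create_col().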
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.
"""
collection = self.client.collections.get(str(self.collection_name))
filter_conditions = []
if filters:
for key, value in filters.items():
if value and key in ["user_id", "agent_id", "run_id"]:
filter_conditions.append(Filter.by_property(key).equal(value))
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
response = collection.query.hybrid(
            query=query,  # the text query drives the keyword (BM25) side of the hybrid search
vector=vectors,
limit=limit,
filters=combined_filter,
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
return_metadata=MetadataQuery(score=True),
)
results = []
for obj in response.objects:
payload = obj.properties.copy()
for id_field in ["run_id", "agent_id", "user_id"]:
if id_field in payload and payload[id_field] is None:
del payload[id_field]
            payload["id"] = str(obj.uuid)  # include the object id in the payload
if obj.metadata.distance is not None:
score = 1 - obj.metadata.distance # Convert distance to similarity score
elif obj.metadata.score is not None:
score = obj.metadata.score
else:
score = 1.0 # Default score if none provided
results.append(
OutputData(
id=str(obj.uuid),
score=score,
payload=payload,
)
)
return results
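        # Illustrative call (not in the original; names are placeholders):
        #   store.search(query="favorite drinks", vectors=query_embedding, limit=3,
        #                filters={"user_id": "alice"})
        # restricts the hybrid query to objects whose user_id property equals "alice".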
def delete(self, vector_id):
"""
Delete a vector by ID.
Args:
vector_id: ID of the vector to delete.
"""
collection = self.client.collections.get(str(self.collection_name))
collection.data.delete_by_id(vector_id)
def update(self, vector_id, vector=None, payload=None):
"""
Update a vector and its payload.
Args:
vector_id: ID of the vector to update.
vector (list, optional): Updated vector. Defaults to None.
payload (dict, optional): Updated payload. Defaults to None.
"""
collection = self.client.collections.get(str(self.collection_name))
if payload:
collection.data.update(uuid=vector_id, properties=payload)
        if vector:
            existing_data = self.get(vector_id)
            if existing_data:
                # Re-send only the stored properties (not the OutputData wrapper fields),
                # since Weaviate rejects properties that are not part of the schema.
                existing_payload: Mapping[str, str] = {
                    k: v for k, v in (existing_data.payload or {}).items() if k != "id"
                }
                collection.data.update(uuid=vector_id, properties=existing_payload, vector=vector)
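        # Illustrative calls (placeholders):
        #   store.update(vector_id, payload={"data": "User now prefers oolong tea"})
        #   store.update(vector_id, vector=new_embedding)  # re-sends stored properties with the new vector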
def get(self, vector_id):
"""
Retrieve a vector by ID.
Args:
vector_id: ID of the vector to retrieve.
Returns:
            OutputData: Retrieved object and its payload, or None if the id does not exist.
"""
vector_id = get_valid_uuid(vector_id)
collection = self.client.collections.get(str(self.collection_name))
response = collection.query.fetch_object_by_id(
uuid=vector_id,
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
)
        if response is None:
            return None
        payload = response.properties.copy()
        payload["id"] = str(response.uuid)
        return OutputData(
            id=str(response.uuid),
            score=1.0,
            payload=payload,
        )
def list_cols(self):
"""
List all collections.
Returns:
            dict: Mapping with a "collections" key listing the collection names.
"""
        # list_all() returns a dict keyed by collection name in the v4 client.
        collections = self.client.collections.list_all()
        logger.debug(f"collections: {collections}")
        return {"collections": [{"name": name} for name in collections]}
def delete_col(self):
"""Delete a collection."""
self.client.collections.delete(self.collection_name)
def col_info(self):
"""
Get information about a collection.
Returns:
dict: Collection information.
"""
schema = self.client.collections.get(self.collection_name)
if schema:
return schema
return None
def list(self, filters=None, limit=100) -> List[OutputData]:
"""
List all vectors in a collection.
"""
collection = self.client.collections.get(self.collection_name)
filter_conditions = []
if filters:
for key, value in filters.items():
if value and key in ["user_id", "agent_id", "run_id"]:
filter_conditions.append(Filter.by_property(key).equal(value))
combined_filter = Filter.all_of(filter_conditions) if filter_conditions else None
response = collection.query.fetch_objects(
limit=limit,
filters=combined_filter,
return_properties=["hash", "created_at", "updated_at", "user_id", "agent_id", "run_id", "data", "category"],
)
results = []
for obj in response.objects:
payload = obj.properties.copy()
payload["id"] = str(obj.uuid).split("'")[0]
results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
        # Wrapped in an outer list so callers can index [0], matching the return shape of the other vector stores.
        return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
        self.create_col(self.embedding_model_dims)
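
# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a Weaviate instance reachable at http://localhost:8080 and a
# 1536-dimensional embedding produced elsewhere; every name and value below
# is a placeholder, not something defined by mem0.
if __name__ == "__main__":
    store = Weaviate(
        collection_name="mem0_demo",
        embedding_model_dims=1536,
        cluster_url="http://localhost:8080",
    )
    dummy_vector = [0.0] * 1536  # stand-in for a real embedding
    store.insert(
        vectors=[dummy_vector],
        payloads=[{"data": "User prefers green tea", "user_id": "alice", "hash": "demo-hash"}],
    )
    hits = store.search(query="tea", vectors=dummy_vector, limit=3, filters={"user_id": "alice"})
    for hit in hits:
        print(hit.id, hit.score, hit.payload)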