autonomy phase 2.5 - tightening up some stuff in the pipeline

This commit is contained in:
serversdwn
2025-12-15 01:56:57 -05:00
parent 193bf814ec
commit d4fd393f52
3 changed files with 47 additions and 12 deletions

View File

@@ -1,7 +1,9 @@
import logging
import os
import shutil
from typing import Optional
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
@@ -19,6 +21,13 @@ from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
"""Standard output format for vector search results."""
id: Optional[str]
score: Optional[float]
payload: Optional[dict]
class Qdrant(VectorStoreBase):
def __init__(
self,
@@ -170,7 +179,7 @@ class Qdrant(VectorStoreBase):
filters (dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: Search results.
list: Search results wrapped in OutputData format.
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.query_points(
@@ -179,7 +188,16 @@ class Qdrant(VectorStoreBase):
query_filter=query_filter,
limit=limit,
)
return hits.points
# Wrap results in OutputData format to match other vector stores
return [
OutputData(
id=str(hit.id),
score=hit.score,
payload=hit.payload
)
for hit in hits.points
]
def delete(self, vector_id: int):
"""
@@ -207,7 +225,7 @@ class Qdrant(VectorStoreBase):
point = PointStruct(id=vector_id, vector=vector, payload=payload)
self.client.upsert(collection_name=self.collection_name, points=[point])
def get(self, vector_id: int) -> dict:
def get(self, vector_id: int) -> OutputData:
"""
Retrieve a vector by ID.
@@ -215,10 +233,17 @@ class Qdrant(VectorStoreBase):
vector_id (int): ID of the vector to retrieve.
Returns:
dict: Retrieved vector.
OutputData: Retrieved vector wrapped in OutputData format.
"""
result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
return result[0] if result else None
if result:
hit = result[0]
return OutputData(
id=str(hit.id),
score=None, # No score for direct retrieval
payload=hit.payload
)
return None
def list_cols(self) -> list:
"""
@@ -251,7 +276,7 @@ class Qdrant(VectorStoreBase):
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
list: List of vectors.
list: List of vectors wrapped in OutputData format.
"""
query_filter = self._create_filter(filters) if filters else None
result = self.client.scroll(
@@ -261,7 +286,18 @@ class Qdrant(VectorStoreBase):
with_payload=True,
with_vectors=False,
)
return result
# Wrap results in OutputData format
# scroll() returns tuple: (points, next_page_offset)
points = result[0] if isinstance(result, tuple) else result
return [
OutputData(
id=str(point.id),
score=None, # No score for list operation
payload=point.payload
)
for point in points
]
def reset(self):
"""Reset the index by deleting and recreating it."""