177 lines
6.5 KiB
Python
177 lines
6.5 KiB
Python
import json
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from mem0.vector_stores.base import VectorStoreBase
|
|
|
|
# boto3 is an optional dependency of mem0; fail fast with an actionable message
# when this vector store is imported without it installed.
try:
    import boto3
    from botocore.exceptions import ClientError
except ImportError as e:
    # Chain the original error so the real cause stays visible in tracebacks.
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") from e


logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OutputData(BaseModel):
    """Normalized record returned by the S3 Vectors store's read methods.

    Instances are built from raw ``get_vectors`` / ``query_vectors`` /
    ``list_vectors`` entries; every field may be ``None``.
    """

    # Key the vector is stored under in the index.
    id: Optional[str]
    # Distance reported by ``query_vectors``; None when the API returned none.
    score: Optional[float]
    # Vector metadata, decoded into a dict.
    payload: Optional[Dict]
|
|
|
|
|
|
class S3Vectors(VectorStoreBase):
    """Vector store backed by Amazon S3 Vectors.

    One instance targets a single index (``collection_name``) inside a single
    vector bucket (``vector_bucket_name``). Both are created on construction
    when they do not already exist.
    """

    def __init__(
        self,
        vector_bucket_name: str,
        collection_name: str,
        embedding_model_dims: int,
        distance_metric: str = "cosine",
        region_name: Optional[str] = None,
    ):
        """Create the ``s3vectors`` client and ensure the bucket and index exist.

        Args:
            vector_bucket_name: Name of the S3 vector bucket.
            collection_name: Name of the index inside the bucket.
            embedding_model_dims: Vector dimension used when creating the index.
            distance_metric: Distance metric used when creating the index.
            region_name: AWS region for the boto3 client (``None`` lets boto3
                resolve it).
        """
        self.client = boto3.client("s3vectors", region_name=region_name)
        self.vector_bucket_name = vector_bucket_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.distance_metric = distance_metric

        self._ensure_bucket_exists()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)

    def _ensure_bucket_exists(self):
        """Create the vector bucket if ``get_vector_bucket`` reports it missing.

        Raises:
            ClientError: For any error other than ``NotFoundException``.
        """
        try:
            self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
            logger.info(f"Vector bucket '{self.vector_bucket_name}' already exists.")
        except ClientError as e:
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Vector bucket '{self.vector_bucket_name}' not found. Creating it.")
                self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
                logger.info(f"Vector bucket '{self.vector_bucket_name}' created.")
            else:
                raise

    def create_col(self, name, vector_size, distance="cosine"):
        """Create index ``name`` (float32, ``vector_size`` dims) unless it exists.

        Raises:
            ClientError: For any error other than ``NotFoundException``.
        """
        try:
            self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
            logger.info(f"Index '{name}' already exists in bucket '{self.vector_bucket_name}'.")
        except ClientError as e:
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Index '{name}' not found in bucket '{self.vector_bucket_name}'. Creating it.")
                self.client.create_index(
                    vectorBucketName=self.vector_bucket_name,
                    indexName=name,
                    dataType="float32",
                    dimension=vector_size,
                    distanceMetric=distance,
                )
                logger.info(f"Index '{name}' created.")
            else:
                raise

    def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
        """Convert raw S3 Vectors entries into ``OutputData`` records.

        ``metadata`` is decoded from JSON when it arrives as a string; an
        undecodable string is logged and replaced with an empty dict.
        """
        results = []
        for v in vectors:
            payload = v.get("metadata", {})
            # Boto3 might return metadata as a JSON string
            if isinstance(payload, str):
                try:
                    payload = json.loads(payload)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse metadata for key {v.get('key')}")
                    payload = {}
            results.append(OutputData(id=v.get("key"), score=v.get("distance"), payload=payload))
        return results

    def insert(self, vectors, payloads=None, ids=None):
        """Write vectors to the index via ``put_vectors``.

        Existing vectors with the same key are overwritten.

        Args:
            vectors: Sequence of float vectors.
            payloads: Optional metadata dicts, parallel to ``vectors``.
            ids: Keys, parallel to ``vectors``. Required, because every vector
                is addressed by an explicit key.

        Raises:
            ValueError: If ``ids`` is missing, or if ``ids`` / ``payloads`` have
                fewer entries than ``vectors``.
        """
        vectors = list(vectors)
        if not vectors:
            return  # nothing to write; avoid an API call with an empty batch

        if ids is None or len(ids) < len(vectors):
            raise ValueError("`ids` must provide one key per vector.")
        if payloads and len(payloads) < len(vectors):
            raise ValueError("`payloads` must provide one entry per vector.")

        records = [
            {
                "key": ids[i],
                "data": {"float32": vec},
                "metadata": payloads[i] if payloads else {},
            }
            for i, vec in enumerate(vectors)
        ]
        self.client.put_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            vectors=records,
        )

    def search(self, query, vectors, limit=5, filters=None):
        """Return up to ``limit`` nearest neighbours of ``vectors``.

        Args:
            query: Original query text (unused by this backend).
            vectors: Query embedding.
            limit: ``topK`` for ``query_vectors``.
            filters: Passed unchanged as the ``filter`` parameter when truthy.
        """
        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "queryVector": {"float32": vectors},
            "topK": limit,
            "returnMetadata": True,
            "returnDistance": True,
        }
        if filters:
            params["filter"] = filters

        response = self.client.query_vectors(**params)
        return self._parse_output(response.get("vectors", []))

    def delete(self, vector_id):
        """Delete the vector stored under ``vector_id``."""
        self.client.delete_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
        )

    def _get_raw(self, vector_id, return_data=False):
        """Fetch the raw ``get_vectors`` entry for ``vector_id``, or ``None``.

        Metadata is always requested; vector data only when ``return_data``.
        """
        response = self.client.get_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
            returnData=return_data,
            returnMetadata=True,
        )
        found = response.get("vectors", [])
        return found[0] if found else None

    def update(self, vector_id, vector=None, payload=None):
        """Overwrite a vector's data and/or metadata.

        S3 Vectors uses ``put_vectors`` for updates, which replaces the whole
        record. When only one of ``vector`` / ``payload`` is given, the other is
        read back from the stored record so it is preserved rather than lost.

        Raises:
            ValueError: If ``vector`` is omitted and no stored vector data
                can be found for ``vector_id``.
        """
        if vector is None and payload is None:
            return  # nothing to change

        if vector is None or payload is None:
            existing = self._get_raw(vector_id, return_data=vector is None)
            if existing is not None:
                if vector is None:
                    vector = (existing.get("data") or {}).get("float32")
                if payload is None:
                    payload = self._parse_output([existing])[0].payload
            if vector is None:
                raise ValueError(f"Cannot update '{vector_id}': no vector given and none stored.")

        self.insert(
            vectors=[vector],
            payloads=[payload if payload is not None else {}],
            ids=[vector_id],
        )

    def get(self, vector_id) -> Optional[OutputData]:
        """Return the vector's key and metadata (no data), or ``None`` if absent."""
        raw = self._get_raw(vector_id)
        if raw is None:
            return None
        return self._parse_output([raw])[0]

    def list_cols(self):
        """Return the names of all indexes in the bucket, following pagination."""
        names = []
        request = {"vectorBucketName": self.vector_bucket_name}
        while True:
            response = self.client.list_indexes(**request)
            names.extend(idx["indexName"] for idx in response.get("indexes", []))
            token = response.get("nextToken")
            if not token:
                return names
            request["nextToken"] = token

    def delete_col(self):
        """Delete this store's index."""
        self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)

    def col_info(self):
        """Return the ``index`` description from ``get_index`` (empty dict if absent)."""
        response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
        return response.get("index", {})

    def list(self, filters=None, limit=None):
        """List vectors (key + metadata) in the index.

        Args:
            filters: Not supported by ``list_vectors``. A warning is logged and
                the value is ignored.
            limit: Maximum number of vectors to return; falsy means no limit.

        Returns:
            A one-element list wrapping the ``OutputData`` results (the outer
            list is part of this method's existing return shape).
        """
        # Note: list_vectors does not support metadata filtering.
        if filters:
            logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")

        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "returnData": False,
            "returnMetadata": True,
        }

        # Passing `maxResults` to a paginator only sets the page size; the
        # paginator would still walk every page. Enforce `limit` ourselves.
        paginator = self.client.get_paginator("list_vectors")
        collected = []
        for page in paginator.paginate(**params):
            collected.extend(page.get("vectors", []))
            if limit and len(collected) >= limit:
                break
        if limit:
            collected = collected[:limit]
        return [self._parse_output(collected)]

    def reset(self):
        """Delete and recreate this store's index with the configured settings."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)