Files
mem0/vector_stores/s3_vectors.py
2026-03-06 21:11:10 +08:00

177 lines
6.5 KiB
Python

import json
import logging
import uuid
from typing import Dict, List, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import boto3
    from botocore.exceptions import ClientError
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """Single record returned by the vector store's read operations.

    Every field is nullable and defaults to ``None`` so the model validates
    the same way whether or not a given API response carries that piece of
    information (for example, a plain key lookup has no distance to report).
    """

    # Key of the vector inside the index.
    id: Optional[str] = None
    # Distance reported for a query hit; None when the record was not
    # produced by a similarity query.
    score: Optional[float] = None
    # Metadata stored alongside the vector.
    payload: Optional[Dict] = None
class S3Vectors(VectorStoreBase):
    """Vector store adapter that talks to the ``s3vectors`` boto3 client.

    An instance is bound to one vector bucket and one index inside it
    (``collection_name``). Both are looked up on construction and created
    when the service answers with ``NotFoundException``.
    """

    def __init__(
        self,
        vector_bucket_name: str,
        collection_name: str,
        embedding_model_dims: int,
        distance_metric: str = "cosine",
        region_name: Optional[str] = None,
    ):
        """Create the client and make sure the bucket and index exist.

        Args:
            vector_bucket_name: Name of the vector bucket to use.
            collection_name: Name of the index inside the bucket.
            embedding_model_dims: Dimension used when the index is created.
            distance_metric: Distance metric used when the index is created.
            region_name: Region passed to ``boto3.client``; ``None`` lets
                boto3 resolve it from its own configuration.
        """
        self.client = boto3.client("s3vectors", region_name=region_name)
        self.vector_bucket_name = vector_bucket_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.distance_metric = distance_metric

        self._ensure_bucket_exists()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)

    def _ensure_bucket_exists(self):
        """Create the configured vector bucket if the service cannot find it.

        Raises:
            ClientError: For any service error other than ``NotFoundException``.
        """
        try:
            self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
        except ClientError as e:
            if e.response["Error"]["Code"] != "NotFoundException":
                raise
            logger.info("Vector bucket '%s' not found. Creating it.", self.vector_bucket_name)
            self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
            logger.info("Vector bucket '%s' created.", self.vector_bucket_name)
        else:
            logger.info("Vector bucket '%s' already exists.", self.vector_bucket_name)

    def create_col(self, name, vector_size, distance="cosine"):
        """Create index ``name`` in the bucket unless it already exists.

        Args:
            name: Index name.
            vector_size: Dimension of the vectors the index will hold.
            distance: Distance metric for the new index.

        Raises:
            ClientError: For any service error other than ``NotFoundException``
                on the existence check, or any error from index creation.
        """
        try:
            self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
        except ClientError as e:
            if e.response["Error"]["Code"] != "NotFoundException":
                raise
            logger.info("Index '%s' not found in bucket '%s'. Creating it.", name, self.vector_bucket_name)
            self.client.create_index(
                vectorBucketName=self.vector_bucket_name,
                indexName=name,
                dataType="float32",
                dimension=vector_size,
                distanceMetric=distance,
            )
            logger.info("Index '%s' created.", name)
        else:
            logger.info("Index '%s' already exists in bucket '%s'.", name, self.vector_bucket_name)

    @staticmethod
    def _decode_metadata(raw, key=None):
        """Return ``raw`` metadata, decoding it first when it is a JSON string.

        Args:
            raw: Metadata value as found in a response record.
            key: Vector key, used only in the warning message.

        Returns:
            The decoded value for string input, ``{}`` when the string is not
            valid JSON, and ``raw`` unchanged for every other type.
        """
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Failed to parse metadata for key %s", key)
            return {}

    def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
        """Convert raw response records into ``OutputData`` objects.

        The record's ``key`` becomes ``id``, ``distance`` becomes ``score``
        and ``metadata`` (JSON-decoded when it arrives as a string) becomes
        ``payload``.
        """
        return [
            OutputData(
                id=record.get("key"),
                score=record.get("distance"),
                payload=self._decode_metadata(record.get("metadata", {}), record.get("key")),
            )
            for record in vectors
        ]

    def insert(self, vectors, payloads=None, ids=None):
        """Write vectors to the index, overwriting records with the same key.

        Args:
            vectors: Embeddings to store.
            payloads: Optional metadata dicts, positionally matched to
                ``vectors``. Missing or ``None`` entries are stored as ``{}``.
            ids: Optional keys, positionally matched to ``vectors``. When
                omitted, a random UUID string is generated for each vector.
        """
        vectors = list(vectors)
        if not vectors:
            # Nothing to write; skip the request instead of sending an empty batch.
            return

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        records = [
            {
                "key": ids[position],
                "data": {"float32": embedding},
                # `or {}` also covers a None entry inside a non-empty payload list.
                "metadata": (payloads[position] or {}) if payloads else {},
            }
            for position, embedding in enumerate(vectors)
        ]
        self.client.put_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            vectors=records,
        )

    def search(self, query, vectors, limit=5, filters=None):
        """Run a similarity query and return the hits.

        Args:
            query: Unused by this implementation.
            vectors: Query embedding.
            limit: Value sent as ``topK``.
            filters: Optional filter forwarded verbatim as the request's
                ``filter`` parameter.

        Returns:
            List of ``OutputData`` with ``score`` set to the returned distance.
        """
        request = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "queryVector": {"float32": vectors},
            "topK": limit,
            "returnMetadata": True,
            "returnDistance": True,
        }
        if filters:
            request["filter"] = filters

        response = self.client.query_vectors(**request)
        return self._parse_output(response.get("vectors", []))

    def delete(self, vector_id):
        """Delete the vector stored under ``vector_id``."""
        self.client.delete_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
        )

    def update(self, vector_id, vector=None, payload=None):
        """Overwrite a stored vector and/or its metadata.

        The write goes through ``put_vectors``, which replaces the whole
        record, so whichever of ``vector`` / ``payload`` is omitted is first
        read back from the index and written again unchanged.

        Args:
            vector_id: Key of the record to update.
            vector: New embedding, or ``None`` to keep the stored one.
            payload: New metadata, or ``None`` to keep the stored metadata.

        Raises:
            ValueError: If ``vector`` is omitted and no stored embedding can
                be found for ``vector_id``.
        """
        if vector is None and payload is None:
            # Nothing was asked to change.
            return

        if vector is None or payload is None:
            response = self.client.get_vectors(
                vectorBucketName=self.vector_bucket_name,
                indexName=self.collection_name,
                keys=[vector_id],
                # Only pull back the part that has to be preserved.
                returnData=vector is None,
                returnMetadata=payload is None,
            )
            stored = response.get("vectors", [])
            current = stored[0] if stored else {}

            if vector is None:
                vector = (current.get("data") or {}).get("float32")
                if vector is None:
                    raise ValueError(
                        f"Cannot update vector '{vector_id}' without a new embedding: "
                        "no stored embedding was found for that key."
                    )
            if payload is None:
                payload = self._decode_metadata(current.get("metadata"), vector_id) or {}

        self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])

    def get(self, vector_id) -> Optional[OutputData]:
        """Fetch one record by key, without its embedding.

        Returns:
            The record as ``OutputData`` (``score`` is ``None``), or ``None``
            when the response contains no vectors.
        """
        response = self.client.get_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
            returnData=False,
            returnMetadata=True,
        )
        found = response.get("vectors", [])
        if not found:
            return None
        return self._parse_output(found)[0]

    def list_cols(self):
        """Return the index names from a ``list_indexes`` call on the bucket."""
        response = self.client.list_indexes(vectorBucketName=self.vector_bucket_name)
        return [index["indexName"] for index in response.get("indexes", [])]

    def delete_col(self):
        """Delete the index this instance is bound to."""
        self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)

    def col_info(self):
        """Return the ``index`` section of the ``get_index`` response (``{}`` if absent)."""
        response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
        return response.get("index", {})

    def list(self, filters=None, limit=None):
        """List stored records (metadata only), optionally capped at ``limit``.

        Args:
            filters: Not supported by ``list_vectors``; a warning is logged
                and the value is ignored.
            limit: Maximum number of records to return in total. ``None``,
                ``0`` or a negative value means no cap.

        Returns:
            A one-element list whose single item is the list of
            ``OutputData`` records.
        """
        # Note: list_vectors does not support metadata filtering.
        if filters:
            logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")

        cap = limit if limit and limit > 0 else None

        request = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "returnData": False,
            "returnMetadata": True,
        }
        if cap is not None:
            # MaxItems bounds the total across pages; a per-page size alone
            # would not stop the paginator from walking the whole index.
            request["PaginationConfig"] = {"MaxItems": cap}

        paginator = self.client.get_paginator("list_vectors")
        collected = []
        for page in paginator.paginate(**request):
            collected.extend(page.get("vectors", []))
            if cap is not None and len(collected) >= cap:
                break

        if cap is not None:
            # Defensive truncation in case a page overshoots the cap.
            collected = collected[:cap]
        return [self._parse_output(collected)]

    def reset(self):
        """Drop the bound index and create it again with the stored settings."""
        logger.warning("Resetting index %s...", self.collection_name)
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)