first commit
This commit is contained in:
176
vector_stores/s3_vectors.py
Normal file
176
vector_stores/s3_vectors.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import json
import logging
import uuid
from typing import Dict, List, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import boto3
    from botocore.exceptions import ClientError
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
    """Normalized result record returned by :class:`S3Vectors` queries.

    Attributes:
        id: Vector key in the S3 Vectors index.
        score: Distance returned by the service (only present for searches).
        payload: Metadata dictionary stored alongside the vector.
    """

    # Defaults are explicit: under pydantic v2 an ``Optional[...]`` annotation
    # without a default is a *required* field, which would break partial
    # construction (e.g. a record with no score). ``= None`` keeps the
    # pydantic-v1 behavior and is backward compatible.
    id: Optional[str] = None
    score: Optional[float] = None
    payload: Optional[Dict] = None
|
||||
|
||||
|
||||
class S3Vectors(VectorStoreBase):
    """mem0 vector store backed by Amazon S3 Vectors (boto3 ``s3vectors`` client).

    A mem0 "collection" maps to an S3 Vectors *index* inside a *vector bucket*.
    Both the bucket and the index are created on construction if they do not
    already exist.
    """

    def __init__(
        self,
        vector_bucket_name: str,
        collection_name: str,
        embedding_model_dims: int,
        distance_metric: str = "cosine",
        region_name: Optional[str] = None,
    ):
        """Initialize the client and ensure the bucket and index exist.

        Args:
            vector_bucket_name: S3 vector bucket that holds the index.
            collection_name: Name of the index inside the bucket.
            embedding_model_dims: Dimensionality of stored vectors.
            distance_metric: Distance metric passed to the service
                (e.g. ``"cosine"``).
            region_name: AWS region; ``None`` falls back to the default
                boto3 session/region resolution.
        """
        self.client = boto3.client("s3vectors", region_name=region_name)
        self.vector_bucket_name = vector_bucket_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.distance_metric = distance_metric

        self._ensure_bucket_exists()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)

    def _ensure_bucket_exists(self):
        """Create the vector bucket if it does not exist (idempotent)."""
        try:
            self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
            logger.info(f"Vector bucket '{self.vector_bucket_name}' already exists.")
        except ClientError as e:
            # Only a missing bucket is recoverable; re-raise anything else
            # (permissions, throttling, ...) so callers see the real error.
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Vector bucket '{self.vector_bucket_name}' not found. Creating it.")
                self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
                logger.info(f"Vector bucket '{self.vector_bucket_name}' created.")
            else:
                raise

    def create_col(self, name, vector_size, distance="cosine"):
        """Create an index (collection) if it does not exist (idempotent).

        Args:
            name: Index name.
            vector_size: Vector dimensionality for the index.
            distance: Distance metric for the index.
        """
        try:
            self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
            logger.info(f"Index '{name}' already exists in bucket '{self.vector_bucket_name}'.")
        except ClientError as e:
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Index '{name}' not found in bucket '{self.vector_bucket_name}'. Creating it.")
                self.client.create_index(
                    vectorBucketName=self.vector_bucket_name,
                    indexName=name,
                    dataType="float32",
                    dimension=vector_size,
                    distanceMetric=distance,
                )
                logger.info(f"Index '{name}' created.")
            else:
                raise

    def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
        """Convert raw service records into :class:`OutputData` instances.

        Args:
            vectors: Items from a ``query_vectors``/``get_vectors``/
                ``list_vectors`` response.

        Returns:
            One :class:`OutputData` per input record; unparseable metadata
            degrades to an empty payload rather than raising.
        """
        results = []
        for v in vectors:
            payload = v.get("metadata", {})
            # Boto3 might return metadata as a JSON string
            if isinstance(payload, str):
                try:
                    payload = json.loads(payload)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse metadata for key {v.get('key')}")
                    payload = {}
            results.append(OutputData(id=v.get("key"), score=v.get("distance"), payload=payload))
        return results

    def insert(self, vectors, payloads=None, ids=None):
        """Insert (or overwrite) vectors into the collection.

        Args:
            vectors: Iterable of float vectors.
            payloads: Optional metadata dicts, parallel to ``vectors``.
            ids: Optional keys, parallel to ``vectors``. When omitted, UUID4
                keys are generated — previously ``ids=None`` crashed with a
                ``TypeError`` on ``ids[i]``.
        """
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        vectors_to_put = []
        for i, vec in enumerate(vectors):
            vectors_to_put.append(
                {
                    "key": ids[i],
                    "data": {"float32": vec},
                    "metadata": payloads[i] if payloads else {},
                }
            )
        self.client.put_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            vectors=vectors_to_put,
        )

    def search(self, query, vectors, limit=5, filters=None):
        """Query the index for the nearest neighbors of ``vectors``.

        Args:
            query: Unused here (kept for mem0 interface compatibility).
            vectors: The query embedding (single vector of floats).
            limit: Maximum number of results (``topK``).
            filters: Optional service-side metadata filter.

        Returns:
            List of :class:`OutputData`, with ``score`` set to the distance.
        """
        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "queryVector": {"float32": vectors},
            "topK": limit,
            "returnMetadata": True,
            "returnDistance": True,
        }
        if filters:
            params["filter"] = filters

        response = self.client.query_vectors(**params)
        return self._parse_output(response.get("vectors", []))

    def delete(self, vector_id):
        """Delete a single vector by key."""
        self.client.delete_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
        )

    def update(self, vector_id, vector=None, payload=None):
        """Overwrite a vector and/or its payload.

        S3 Vectors has no partial update; ``put_vectors`` overwrites the key.
        NOTE(review): calling this with ``vector=None`` sends ``None`` as the
        vector data and will likely be rejected by the service — callers are
        expected to pass the full vector.
        """
        # S3 Vectors uses put_vectors for updates (overwrite)
        self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])

    def get(self, vector_id) -> Optional[OutputData]:
        """Fetch a single vector's metadata by key.

        Returns:
            :class:`OutputData` with ``payload`` set (vector data itself is
            not requested), or ``None`` if the key does not exist.
        """
        response = self.client.get_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
            returnData=False,
            returnMetadata=True,
        )
        vectors = response.get("vectors", [])
        if not vectors:
            return None
        return self._parse_output(vectors)[0]

    def list_cols(self):
        """Return the names of all indexes in the vector bucket."""
        response = self.client.list_indexes(vectorBucketName=self.vector_bucket_name)
        return [idx["indexName"] for idx in response.get("indexes", [])]

    def delete_col(self):
        """Delete this collection's index from the bucket."""
        self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)

    def col_info(self):
        """Return the service's description of this index."""
        response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
        return response.get("index", {})

    def list(self, filters=None, limit=None):
        """List vectors (metadata only) in the collection.

        Args:
            filters: Ignored — ``list_vectors`` does not support metadata
                filtering; a warning is logged.
            limit: Maximum total number of vectors to return.

        Returns:
            A single-element list wrapping the parsed results (mem0's
            list-of-lists convention, preserved from the original).
        """
        # Note: list_vectors does not support metadata filtering.
        if filters:
            logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")

        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "returnData": False,
            "returnMetadata": True,
        }

        paginator = self.client.get_paginator("list_vectors")
        # Cap the TOTAL item count via the paginator's PaginationConfig.
        # Passing ``maxResults`` directly (as before) only set the per-page
        # size, so the paginator still walked every page and ``limit`` was
        # never enforced.
        pagination_config = {"MaxItems": limit} if limit else {}
        pages = paginator.paginate(**params, PaginationConfig=pagination_config)
        all_vectors = []
        for page in pages:
            all_vectors.extend(page.get("vectors", []))
        if limit:
            # Defensive truncation in case the last page overshoots MaxItems.
            all_vectors = all_vectors[:limit]
        return [self._parse_output(all_vectors)]

    def reset(self):
        """Drop and recreate the collection's index, losing all vectors."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
|
||||
Reference in New Issue
Block a user