first commit
This commit is contained in:
396
vector_stores/azure_ai_search.py
Normal file
396
vector_stores/azure_ai_search.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.search.documents import SearchClient
|
||||
from azure.search.documents.indexes import SearchIndexClient
|
||||
from azure.search.documents.indexes.models import (
|
||||
BinaryQuantizationCompression,
|
||||
HnswAlgorithmConfiguration,
|
||||
ScalarQuantizationCompression,
|
||||
SearchField,
|
||||
SearchFieldDataType,
|
||||
SearchIndex,
|
||||
SimpleField,
|
||||
VectorSearch,
|
||||
VectorSearchProfile,
|
||||
)
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
score: Optional[float]
|
||||
payload: Optional[dict]
|
||||
|
||||
|
||||
class AzureAISearch(VectorStoreBase):
|
||||
def __init__(
|
||||
self,
|
||||
service_name,
|
||||
collection_name,
|
||||
api_key,
|
||||
embedding_model_dims,
|
||||
compression_type: Optional[str] = None,
|
||||
use_float16: bool = False,
|
||||
hybrid_search: bool = False,
|
||||
vector_filter_mode: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure AI Search vector store.
|
||||
|
||||
Args:
|
||||
service_name (str): Azure AI Search service name.
|
||||
collection_name (str): Index name.
|
||||
api_key (str): API key for the Azure AI Search service.
|
||||
embedding_model_dims (int): Dimension of the embedding vector.
|
||||
compression_type (Optional[str]): Specifies the type of quantization to use.
|
||||
Allowed values are None (no quantization), "scalar", or "binary".
|
||||
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
|
||||
(Note: This flag is preserved from the initial implementation per feedback.)
|
||||
hybrid_search (bool): Whether to use hybrid search. Default is False.
|
||||
vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.api_key = api_key
|
||||
self.index_name = collection_name
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
# If compression_type is None, treat it as "none".
|
||||
self.compression_type = (compression_type or "none").lower()
|
||||
self.use_float16 = use_float16
|
||||
self.hybrid_search = hybrid_search
|
||||
self.vector_filter_mode = vector_filter_mode
|
||||
|
||||
# If the API key is not provided or is a placeholder, use DefaultAzureCredential.
|
||||
if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
|
||||
credential = DefaultAzureCredential()
|
||||
self.api_key = None
|
||||
else:
|
||||
credential = AzureKeyCredential(self.api_key)
|
||||
|
||||
self.search_client = SearchClient(
|
||||
endpoint=f"https://{service_name}.search.windows.net",
|
||||
index_name=self.index_name,
|
||||
credential=credential,
|
||||
)
|
||||
self.index_client = SearchIndexClient(
|
||||
endpoint=f"https://{service_name}.search.windows.net",
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
|
||||
collections = self.list_cols()
|
||||
if collection_name not in collections:
|
||||
self.create_col()
|
||||
|
||||
def create_col(self):
|
||||
"""Create a new index in Azure AI Search."""
|
||||
# Determine vector type based on use_float16 setting.
|
||||
if self.use_float16:
|
||||
vector_type = "Collection(Edm.Half)"
|
||||
else:
|
||||
vector_type = "Collection(Edm.Single)"
|
||||
|
||||
# Configure compression settings based on the specified compression_type.
|
||||
compression_configurations = []
|
||||
compression_name = None
|
||||
if self.compression_type == "scalar":
|
||||
compression_name = "myCompression"
|
||||
# For SQ, rescoring defaults to True and oversampling defaults to 4.
|
||||
compression_configurations = [
|
||||
ScalarQuantizationCompression(
|
||||
compression_name=compression_name
|
||||
# rescoring defaults to True and oversampling defaults to 4
|
||||
)
|
||||
]
|
||||
elif self.compression_type == "binary":
|
||||
compression_name = "myCompression"
|
||||
# For BQ, rescoring defaults to True and oversampling defaults to 10.
|
||||
compression_configurations = [
|
||||
BinaryQuantizationCompression(
|
||||
compression_name=compression_name
|
||||
# rescoring defaults to True and oversampling defaults to 10
|
||||
)
|
||||
]
|
||||
# If no compression is desired, compression_configurations remains empty.
|
||||
fields = [
|
||||
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
||||
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
||||
SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
|
||||
SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
|
||||
SearchField(
|
||||
name="vector",
|
||||
type=vector_type,
|
||||
searchable=True,
|
||||
vector_search_dimensions=self.embedding_model_dims,
|
||||
vector_search_profile_name="my-vector-config",
|
||||
),
|
||||
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
||||
]
|
||||
|
||||
vector_search = VectorSearch(
|
||||
profiles=[
|
||||
VectorSearchProfile(
|
||||
name="my-vector-config",
|
||||
algorithm_configuration_name="my-algorithms-config",
|
||||
compression_name=compression_name if self.compression_type != "none" else None,
|
||||
)
|
||||
],
|
||||
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
|
||||
compressions=compression_configurations,
|
||||
)
|
||||
index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
|
||||
self.index_client.create_or_update_index(index)
|
||||
|
||||
def _generate_document(self, vector, payload, id):
|
||||
document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
|
||||
# Extract additional fields if they exist.
|
||||
for field in ["user_id", "run_id", "agent_id"]:
|
||||
if field in payload:
|
||||
document[field] = payload[field]
|
||||
return document
|
||||
|
||||
# Note: Explicit "insert" calls may later be decoupled from memory management decisions.
|
||||
def insert(self, vectors, payloads=None, ids=None):
|
||||
"""
|
||||
Insert vectors into the index.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert.
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors.
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
||||
documents = [
|
||||
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
|
||||
]
|
||||
response = self.search_client.upload_documents(documents)
|
||||
for doc in response:
|
||||
if not hasattr(doc, "status_code") and doc.get("status_code") != 201:
|
||||
raise Exception(f"Insert failed for document {doc.get('id')}: {doc}")
|
||||
return response
|
||||
|
||||
def _sanitize_key(self, key: str) -> str:
|
||||
return re.sub(r"[^\w]", "", key)
|
||||
|
||||
def _build_filter_expression(self, filters):
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
safe_key = self._sanitize_key(key)
|
||||
if isinstance(value, str):
|
||||
safe_value = value.replace("'", "''")
|
||||
condition = f"{safe_key} eq '{safe_value}'"
|
||||
else:
|
||||
condition = f"{safe_key} eq {value}"
|
||||
filter_conditions.append(condition)
|
||||
filter_expression = " and ".join(filter_conditions)
|
||||
return filter_expression
|
||||
|
||||
def search(self, query, vectors, limit=5, filters=None):
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (str): Query.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
filter_expression = None
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
|
||||
if self.hybrid_search:
|
||||
search_results = self.search_client.search(
|
||||
search_text=query,
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit,
|
||||
vector_filter_mode=self.vector_filter_mode,
|
||||
search_fields=["payload"],
|
||||
)
|
||||
else:
|
||||
search_results = self.search_client.search(
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit,
|
||||
vector_filter_mode=self.vector_filter_mode,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return results
|
||||
|
||||
def delete(self, vector_id):
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
response = self.search_client.delete_documents(documents=[{"id": vector_id}])
|
||||
for doc in response:
|
||||
if not hasattr(doc, "status_code") and doc.get("status_code") != 200:
|
||||
raise Exception(f"Delete failed for document {vector_id}: {doc}")
|
||||
logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
|
||||
return response
|
||||
|
||||
def update(self, vector_id, vector=None, payload=None):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
vector (List[float], optional): Updated vector.
|
||||
payload (Dict, optional): Updated payload.
|
||||
"""
|
||||
document = {"id": vector_id}
|
||||
if vector:
|
||||
document["vector"] = vector
|
||||
if payload:
|
||||
json_payload = json.dumps(payload)
|
||||
document["payload"] = json_payload
|
||||
for field in ["user_id", "run_id", "agent_id"]:
|
||||
document[field] = payload.get(field)
|
||||
response = self.search_client.merge_or_upload_documents(documents=[document])
|
||||
for doc in response:
|
||||
if not hasattr(doc, "status_code") and doc.get("status_code") != 200:
|
||||
raise Exception(f"Update failed for document {vector_id}: {doc}")
|
||||
return response
|
||||
|
||||
def get(self, vector_id) -> OutputData:
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
|
||||
Returns:
|
||||
OutputData: Retrieved vector.
|
||||
"""
|
||||
try:
|
||||
result = self.search_client.get_document(key=vector_id)
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
return OutputData(id=result["id"], score=None, payload=payload)
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
List all collections (indexes).
|
||||
|
||||
Returns:
|
||||
List[str]: List of index names.
|
||||
"""
|
||||
try:
|
||||
names = self.index_client.list_index_names()
|
||||
except AttributeError:
|
||||
names = [index.name for index in self.index_client.list_indexes()]
|
||||
return names
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete the index."""
|
||||
self.index_client.delete_index(self.index_name)
|
||||
|
||||
def col_info(self):
|
||||
"""
|
||||
Get information about the index.
|
||||
|
||||
Returns:
|
||||
dict: Index information.
|
||||
"""
|
||||
index = self.index_client.get_index(self.index_name)
|
||||
return {"name": index.name, "fields": index.fields}
|
||||
|
||||
def list(self, filters=None, limit=100):
|
||||
"""
|
||||
List all vectors in the index.
|
||||
|
||||
Args:
|
||||
filters (dict, optional): Filters to apply to the list.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
List[OutputData]: List of vectors.
|
||||
"""
|
||||
filter_expression = None
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return [results]
|
||||
|
||||
def __del__(self):
|
||||
"""Close the search client when the object is deleted."""
|
||||
self.search_client.close()
|
||||
self.index_client.close()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the index by deleting and recreating it."""
|
||||
logger.warning(f"Resetting index {self.index_name}...")
|
||||
|
||||
try:
|
||||
# Close the existing clients
|
||||
self.search_client.close()
|
||||
self.index_client.close()
|
||||
|
||||
# Delete the collection
|
||||
self.delete_col()
|
||||
|
||||
# If the API key is not provided or is a placeholder, use DefaultAzureCredential.
|
||||
if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
|
||||
credential = DefaultAzureCredential()
|
||||
self.api_key = None
|
||||
else:
|
||||
credential = AzureKeyCredential(self.api_key)
|
||||
|
||||
# Reinitialize the clients
|
||||
service_endpoint = f"https://{self.service_name}.search.windows.net"
|
||||
self.search_client = SearchClient(
|
||||
endpoint=service_endpoint,
|
||||
index_name=self.index_name,
|
||||
credential=credential,
|
||||
)
|
||||
self.index_client = SearchIndexClient(
|
||||
endpoint=service_endpoint,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
# Add user agent
|
||||
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
|
||||
# Create the collection
|
||||
self.create_col()
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting index {self.index_name}: {e}")
|
||||
raise
|
||||
Reference in New Issue
Block a user