825 lines
30 KiB
Python
825 lines
30 KiB
Python
import json
import logging
from datetime import datetime
from typing import Dict, Optional

import numpy as np
import pytz
import valkey
from pydantic import BaseModel
from valkey.exceptions import ResponseError

from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
|
|
|
|
logger = logging.getLogger(__name__)

# Reference schema for the Valkey index: one dict per field with its index
# type. Text-like fields are indexed as TAG rather than TEXT for Valkey
# search-module compatibility; timestamps are NUMERIC; the embedding is a
# float32 vector compared with cosine distance.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "tag"},  # Using TAG instead of TEXT for Valkey compatibility
    {"name": "metadata", "type": "tag"},  # Using TAG instead of TEXT for Valkey compatibility
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

# Payload keys that are stored as dedicated hash fields (or handled
# specially) by insert()/update() and must therefore be excluded from the
# serialized "metadata" JSON blob.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
|
|
|
|
|
class OutputData(BaseModel):
    """Normalized result record returned by ValkeyDB read/query methods.

    Attributes:
        id: Memory identifier (the stored ``memory_id``).
        score: Vector similarity score. Optional because
            ``_process_search_results`` passes ``None`` when a document
            carries no ``vector_score`` attribute; a required ``float``
            would make pydantic reject those results.
        payload: Decoded document fields plus parsed metadata.
    """

    id: str
    score: Optional[float] = None
    payload: Dict
|
|
|
|
|
|
class ValkeyDB(VectorStoreBase):
    def __init__(
        self,
        valkey_url: str,
        collection_name: str,
        embedding_model_dims: int,
        timezone: str = "UTC",
        index_type: str = "hnsw",
        hnsw_m: int = 16,
        hnsw_ef_construction: int = 200,
        hnsw_ef_runtime: int = 10,
    ):
        """
        Initialize the Valkey vector store.

        Args:
            valkey_url (str): Valkey URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
            timezone (str, optional): Timezone for timestamps. Defaults to "UTC".
            index_type (str, optional): Index type ('hnsw' or 'flat'). Defaults to "hnsw".
            hnsw_m (int, optional): HNSW M parameter (connections per node). Defaults to 16.
            hnsw_ef_construction (int, optional): HNSW ef_construction parameter. Defaults to 200.
            hnsw_ef_runtime (int, optional): HNSW ef_runtime parameter. Defaults to 10.

        Raises:
            ValueError: If index_type is neither 'hnsw' nor 'flat'.
        """
        self.embedding_model_dims = embedding_model_dims
        self.collection_name = collection_name
        self.prefix = f"mem0:{collection_name}"
        self.timezone = timezone
        self.index_type = index_type.lower()
        self.hnsw_m = hnsw_m
        self.hnsw_ef_construction = hnsw_ef_construction
        self.hnsw_ef_runtime = hnsw_ef_runtime

        # Only the two supported vector index layouts are accepted.
        if self.index_type not in ("hnsw", "flat"):
            raise ValueError(f"Invalid index_type: {index_type}. Must be 'hnsw' or 'flat'")

        # Establish the client connection; surface failures loudly.
        try:
            self.client = valkey.from_url(valkey_url)
            logger.debug(f"Successfully connected to Valkey at {valkey_url}")
        except Exception as e:
            logger.exception(f"Failed to connect to Valkey at {valkey_url}: {e}")
            raise

        # Make sure the search index exists before any reads/writes happen.
        self._create_index(embedding_model_dims)
|
|
|
|
def _build_index_schema(self, collection_name, embedding_dims, distance_metric, prefix):
|
|
"""
|
|
Build the FT.CREATE command for index creation.
|
|
|
|
Args:
|
|
collection_name (str): Name of the collection/index
|
|
embedding_dims (int): Vector embedding dimensions
|
|
distance_metric (str): Distance metric (e.g., "COSINE", "L2", "IP")
|
|
prefix (str): Key prefix for the index
|
|
|
|
Returns:
|
|
list: Complete FT.CREATE command as list of arguments
|
|
"""
|
|
# Build the vector field configuration based on index type
|
|
if self.index_type == "hnsw":
|
|
vector_config = [
|
|
"embedding",
|
|
"VECTOR",
|
|
"HNSW",
|
|
"12", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric, M, m, EF_CONSTRUCTION, ef_construction, EF_RUNTIME, ef_runtime
|
|
"TYPE",
|
|
"FLOAT32",
|
|
"DIM",
|
|
str(embedding_dims),
|
|
"DISTANCE_METRIC",
|
|
distance_metric,
|
|
"M",
|
|
str(self.hnsw_m),
|
|
"EF_CONSTRUCTION",
|
|
str(self.hnsw_ef_construction),
|
|
"EF_RUNTIME",
|
|
str(self.hnsw_ef_runtime),
|
|
]
|
|
elif self.index_type == "flat":
|
|
vector_config = [
|
|
"embedding",
|
|
"VECTOR",
|
|
"FLAT",
|
|
"6", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric
|
|
"TYPE",
|
|
"FLOAT32",
|
|
"DIM",
|
|
str(embedding_dims),
|
|
"DISTANCE_METRIC",
|
|
distance_metric,
|
|
]
|
|
else:
|
|
# This should never happen due to constructor validation, but be defensive
|
|
raise ValueError(f"Unsupported index_type: {self.index_type}. Must be 'hnsw' or 'flat'")
|
|
|
|
# Build the complete command (comma is default separator for TAG fields)
|
|
cmd = [
|
|
"FT.CREATE",
|
|
collection_name,
|
|
"ON",
|
|
"HASH",
|
|
"PREFIX",
|
|
"1",
|
|
prefix,
|
|
"SCHEMA",
|
|
"memory_id",
|
|
"TAG",
|
|
"hash",
|
|
"TAG",
|
|
"agent_id",
|
|
"TAG",
|
|
"run_id",
|
|
"TAG",
|
|
"user_id",
|
|
"TAG",
|
|
"memory",
|
|
"TAG",
|
|
"metadata",
|
|
"TAG",
|
|
"created_at",
|
|
"NUMERIC",
|
|
"updated_at",
|
|
"NUMERIC",
|
|
] + vector_config
|
|
|
|
return cmd
|
|
|
|
def _create_index(self, embedding_model_dims):
|
|
"""
|
|
Create the search index with the specified schema.
|
|
|
|
Args:
|
|
embedding_model_dims (int): Dimensions for the vector embeddings.
|
|
|
|
Raises:
|
|
ValueError: If the search module is not available.
|
|
Exception: For other errors during index creation.
|
|
"""
|
|
# Check if the search module is available
|
|
try:
|
|
# Try to execute a search command
|
|
self.client.execute_command("FT._LIST")
|
|
except ResponseError as e:
|
|
if "unknown command" in str(e).lower():
|
|
raise ValueError(
|
|
"Valkey search module is not available. Please ensure Valkey is running with the search module enabled. "
|
|
"The search module can be loaded using the --loadmodule option with the valkey-search library. "
|
|
"For installation and setup instructions, refer to the Valkey Search documentation."
|
|
)
|
|
else:
|
|
logger.exception(f"Error checking search module: {e}")
|
|
raise
|
|
|
|
# Check if the index already exists
|
|
try:
|
|
self.client.ft(self.collection_name).info()
|
|
return
|
|
except ResponseError as e:
|
|
if "not found" not in str(e).lower():
|
|
logger.exception(f"Error checking index existence: {e}")
|
|
raise
|
|
|
|
# Build and execute the index creation command
|
|
cmd = self._build_index_schema(
|
|
self.collection_name,
|
|
embedding_model_dims,
|
|
"COSINE", # Fixed distance metric for initialization
|
|
self.prefix,
|
|
)
|
|
|
|
try:
|
|
self.client.execute_command(*cmd)
|
|
logger.info(f"Successfully created {self.index_type.upper()} index {self.collection_name}")
|
|
except Exception as e:
|
|
logger.exception(f"Error creating index {self.collection_name}: {e}")
|
|
raise
|
|
|
|
def create_col(self, name=None, vector_size=None, distance=None):
|
|
"""
|
|
Create a new collection (index) in Valkey.
|
|
|
|
Args:
|
|
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
|
|
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
|
|
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
|
|
|
|
Returns:
|
|
The created index object.
|
|
"""
|
|
# Use provided parameters or fall back to instance attributes
|
|
collection_name = name or self.collection_name
|
|
embedding_dims = vector_size or self.embedding_model_dims
|
|
distance_metric = distance or "COSINE"
|
|
prefix = f"mem0:{collection_name}"
|
|
|
|
# Try to drop the index if it exists (cleanup before creation)
|
|
self._drop_index(collection_name, log_level="silent")
|
|
|
|
# Build and execute the index creation command
|
|
cmd = self._build_index_schema(
|
|
collection_name,
|
|
embedding_dims,
|
|
distance_metric, # Configurable distance metric
|
|
prefix,
|
|
)
|
|
|
|
try:
|
|
self.client.execute_command(*cmd)
|
|
logger.info(f"Successfully created {self.index_type.upper()} index {collection_name}")
|
|
|
|
# Update instance attributes if creating a new collection
|
|
if name:
|
|
self.collection_name = collection_name
|
|
self.prefix = prefix
|
|
|
|
return self.client.ft(collection_name)
|
|
except Exception as e:
|
|
logger.exception(f"Error creating collection {collection_name}: {e}")
|
|
raise
|
|
|
|
def insert(self, vectors: list, payloads: list = None, ids: list = None):
|
|
"""
|
|
Insert vectors and their payloads into the index.
|
|
|
|
Args:
|
|
vectors (list): List of vectors to insert.
|
|
payloads (list, optional): List of payloads corresponding to the vectors.
|
|
ids (list, optional): List of IDs for the vectors.
|
|
"""
|
|
for vector, payload, id in zip(vectors, payloads, ids):
|
|
try:
|
|
# Create the key for the hash
|
|
key = f"{self.prefix}:{id}"
|
|
|
|
# Check for required fields and provide defaults if missing
|
|
if "data" not in payload:
|
|
# Silently use default value for missing 'data' field
|
|
pass
|
|
|
|
# Ensure created_at is present
|
|
if "created_at" not in payload:
|
|
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
|
|
|
|
# Prepare the hash data
|
|
hash_data = {
|
|
"memory_id": id,
|
|
"hash": payload.get("hash", f"hash_{id}"), # Use a default hash if not provided
|
|
"memory": payload.get("data", f"data_{id}"), # Use a default data if not provided
|
|
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
|
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
|
}
|
|
|
|
# Add optional fields
|
|
for field in ["agent_id", "run_id", "user_id"]:
|
|
if field in payload:
|
|
hash_data[field] = payload[field]
|
|
|
|
# Add metadata
|
|
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
|
|
|
# Store in Valkey
|
|
self.client.hset(key, mapping=hash_data)
|
|
logger.debug(f"Successfully inserted vector with ID {id}")
|
|
except KeyError as e:
|
|
logger.error(f"Error inserting vector with ID {id}: Missing required field {e}")
|
|
except Exception as e:
|
|
logger.exception(f"Error inserting vector with ID {id}: {e}")
|
|
raise
|
|
|
|
def _build_search_query(self, knn_part, filters=None):
|
|
"""
|
|
Build a search query string with filters.
|
|
|
|
Args:
|
|
knn_part (str): The KNN part of the query.
|
|
filters (dict, optional): Filters to apply to the search. Each key-value pair
|
|
becomes a tag filter (@key:{value}). None values are ignored.
|
|
Values are used as-is (no validation) - wildcards, lists, etc. are
|
|
passed through literally to Valkey search. Multiple filters are
|
|
combined with AND logic (space-separated).
|
|
|
|
Returns:
|
|
str: The complete search query string in format "filter_expr =>[KNN...]"
|
|
or "*=>[KNN...]" if no valid filters.
|
|
"""
|
|
# No filters, just use the KNN search
|
|
if not filters or not any(value is not None for key, value in filters.items()):
|
|
return f"*=>{knn_part}"
|
|
|
|
# Build filter expression
|
|
filter_parts = []
|
|
for key, value in filters.items():
|
|
if value is not None:
|
|
# Use the correct filter syntax for Valkey
|
|
filter_parts.append(f"@{key}:{{{value}}}")
|
|
|
|
# No valid filter parts
|
|
if not filter_parts:
|
|
return f"*=>{knn_part}"
|
|
|
|
# Combine filter parts with proper syntax
|
|
filter_expr = " ".join(filter_parts)
|
|
return f"{filter_expr} =>{knn_part}"
|
|
|
|
def _execute_search(self, query, params):
|
|
"""
|
|
Execute a search query.
|
|
|
|
Args:
|
|
query (str): The search query to execute.
|
|
params (dict): The query parameters.
|
|
|
|
Returns:
|
|
The search results.
|
|
"""
|
|
try:
|
|
return self.client.ft(self.collection_name).search(query, query_params=params)
|
|
except ResponseError as e:
|
|
logger.error(f"Search failed with query '{query}': {e}")
|
|
raise
|
|
|
|
def _process_search_results(self, results):
|
|
"""
|
|
Process search results into OutputData objects.
|
|
|
|
Args:
|
|
results: The search results from Valkey.
|
|
|
|
Returns:
|
|
list: List of OutputData objects.
|
|
"""
|
|
memory_results = []
|
|
for doc in results.docs:
|
|
# Extract the score
|
|
score = float(doc.vector_score) if hasattr(doc, "vector_score") else None
|
|
|
|
# Create the payload
|
|
payload = {
|
|
"hash": doc.hash,
|
|
"data": doc.memory,
|
|
"created_at": self._format_timestamp(int(doc.created_at), self.timezone),
|
|
}
|
|
|
|
# Add updated_at if available
|
|
if hasattr(doc, "updated_at"):
|
|
payload["updated_at"] = self._format_timestamp(int(doc.updated_at), self.timezone)
|
|
|
|
# Add optional fields
|
|
for field in ["agent_id", "run_id", "user_id"]:
|
|
if hasattr(doc, field):
|
|
payload[field] = getattr(doc, field)
|
|
|
|
# Add metadata
|
|
if hasattr(doc, "metadata"):
|
|
try:
|
|
metadata = json.loads(extract_json(doc.metadata))
|
|
payload.update(metadata)
|
|
except (json.JSONDecodeError, TypeError) as e:
|
|
logger.warning(f"Failed to parse metadata: {e}")
|
|
|
|
# Create the result
|
|
memory_results.append(OutputData(id=doc.memory_id, score=score, payload=payload))
|
|
|
|
return memory_results
|
|
|
|
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None, ef_runtime: int = None):
|
|
"""
|
|
Search for similar vectors in the index.
|
|
|
|
Args:
|
|
query (str): The search query.
|
|
vectors (list): The vector to search for.
|
|
limit (int, optional): Maximum number of results to return. Defaults to 5.
|
|
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
|
ef_runtime (int, optional): HNSW ef_runtime parameter for this query. Only used with HNSW index. Defaults to None.
|
|
|
|
Returns:
|
|
list: List of OutputData objects.
|
|
"""
|
|
# Convert the vector to bytes
|
|
vector_bytes = np.array(vectors, dtype=np.float32).tobytes()
|
|
|
|
# Build the KNN part with optional EF_RUNTIME for HNSW
|
|
if self.index_type == "hnsw" and ef_runtime is not None:
|
|
knn_part = f"[KNN {limit} @embedding $vec_param EF_RUNTIME {ef_runtime} AS vector_score]"
|
|
else:
|
|
# For FLAT indexes or when ef_runtime is None, use basic KNN
|
|
knn_part = f"[KNN {limit} @embedding $vec_param AS vector_score]"
|
|
|
|
# Build the complete query
|
|
q = self._build_search_query(knn_part, filters)
|
|
|
|
# Log the query for debugging (only in debug mode)
|
|
logger.debug(f"Valkey search query: {q}")
|
|
|
|
# Set up the query parameters
|
|
params = {"vec_param": vector_bytes}
|
|
|
|
# Execute the search
|
|
results = self._execute_search(q, params)
|
|
|
|
# Process the results
|
|
return self._process_search_results(results)
|
|
|
|
def delete(self, vector_id):
|
|
"""
|
|
Delete a vector from the index.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to delete.
|
|
"""
|
|
try:
|
|
key = f"{self.prefix}:{vector_id}"
|
|
self.client.delete(key)
|
|
logger.debug(f"Successfully deleted vector with ID {vector_id}")
|
|
except Exception as e:
|
|
logger.exception(f"Error deleting vector with ID {vector_id}: {e}")
|
|
raise
|
|
|
|
def update(self, vector_id=None, vector=None, payload=None):
|
|
"""
|
|
Update a vector in the index.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to update.
|
|
vector (list, optional): New vector data.
|
|
payload (dict, optional): New payload data.
|
|
"""
|
|
try:
|
|
key = f"{self.prefix}:{vector_id}"
|
|
|
|
# Check for required fields and provide defaults if missing
|
|
if "data" not in payload:
|
|
# Silently use default value for missing 'data' field
|
|
pass
|
|
|
|
# Ensure created_at is present
|
|
if "created_at" not in payload:
|
|
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
|
|
|
|
# Prepare the hash data
|
|
hash_data = {
|
|
"memory_id": vector_id,
|
|
"hash": payload.get("hash", f"hash_{vector_id}"), # Use a default hash if not provided
|
|
"memory": payload.get("data", f"data_{vector_id}"), # Use a default data if not provided
|
|
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
|
|
"embedding": np.array(vector, dtype=np.float32).tobytes(),
|
|
}
|
|
|
|
# Add updated_at if available
|
|
if "updated_at" in payload:
|
|
hash_data["updated_at"] = int(datetime.fromisoformat(payload["updated_at"]).timestamp())
|
|
|
|
# Add optional fields
|
|
for field in ["agent_id", "run_id", "user_id"]:
|
|
if field in payload:
|
|
hash_data[field] = payload[field]
|
|
|
|
# Add metadata
|
|
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
|
|
|
|
# Update in Valkey
|
|
self.client.hset(key, mapping=hash_data)
|
|
logger.debug(f"Successfully updated vector with ID {vector_id}")
|
|
except KeyError as e:
|
|
logger.error(f"Error updating vector with ID {vector_id}: Missing required field {e}")
|
|
except Exception as e:
|
|
logger.exception(f"Error updating vector with ID {vector_id}: {e}")
|
|
raise
|
|
|
|
def _format_timestamp(self, timestamp, timezone=None):
|
|
"""
|
|
Format a timestamp with the specified timezone.
|
|
|
|
Args:
|
|
timestamp (int): The timestamp to format.
|
|
timezone (str, optional): The timezone to use. Defaults to UTC.
|
|
|
|
Returns:
|
|
str: The formatted timestamp.
|
|
"""
|
|
# Use UTC as default timezone if not specified
|
|
tz = pytz.timezone(timezone or "UTC")
|
|
return datetime.fromtimestamp(timestamp, tz=tz).isoformat(timespec="microseconds")
|
|
|
|
    def _process_document_fields(self, result, vector_id):
        """
        Process document fields from a Valkey hash result into a payload dict.

        Args:
            result (dict): The hash result from Valkey. NOTE: mutated in
                place — bytes values (except "embedding") are decoded to str.
            vector_id (str): Fallback memory ID when the hash has no
                "memory_id" field.

        Returns:
            dict: The processed payload.
            str: The memory ID.
        """
        # Create the payload with error handling
        payload = {}

        # Decode bytes values to UTF-8 text, leaving the raw embedding alone.
        for k in result:
            if k not in ["embedding"]:
                if isinstance(result[k], bytes):
                    try:
                        result[k] = result[k].decode("utf-8")
                    except UnicodeDecodeError:
                        # If decoding fails, keep the bytes
                        pass

        # Required fields: copy when present, substitute defaults when missing.
        for field in ["hash", "memory", "created_at"]:
            if field in result:
                if field == "created_at":
                    try:
                        payload[field] = self._format_timestamp(int(result[field]), self.timezone)
                    except (ValueError, TypeError):
                        # Unparseable timestamp: pass the raw value through.
                        payload[field] = result[field]
                else:
                    payload[field] = result[field]
            else:
                # Use default values for missing fields
                if field == "hash":
                    payload[field] = "unknown"
                elif field == "memory":
                    payload[field] = "unknown"
                elif field == "created_at":
                    # Fall back to "now" in the store's configured timezone.
                    payload[field] = self._format_timestamp(
                        int(datetime.now(tz=pytz.timezone(self.timezone)).timestamp()), self.timezone
                    )

        # Rename memory to data for consistency
        if "memory" in payload:
            payload["data"] = payload.pop("memory")

        # Add updated_at if available
        if "updated_at" in result:
            try:
                payload["updated_at"] = self._format_timestamp(int(result["updated_at"]), self.timezone)
            except (ValueError, TypeError):
                payload["updated_at"] = result["updated_at"]

        # Add optional fields
        for field in ["agent_id", "run_id", "user_id"]:
            if field in result:
                payload[field] = result[field]

        # Merge the JSON metadata blob into the payload; malformed metadata
        # is logged and skipped rather than failing the read.
        if "metadata" in result:
            try:
                metadata = json.loads(extract_json(result["metadata"]))
                payload.update(metadata)
            except (json.JSONDecodeError, TypeError):
                logger.warning(f"Failed to parse metadata: {result.get('metadata')}")

        # Use memory_id from result if available, otherwise use vector_id
        memory_id = result.get("memory_id", vector_id)

        return payload, memory_id
|
|
|
|
def _convert_bytes(self, data):
|
|
"""Convert bytes data back to string"""
|
|
if isinstance(data, bytes):
|
|
try:
|
|
return data.decode("utf-8")
|
|
except UnicodeDecodeError:
|
|
return data
|
|
if isinstance(data, dict):
|
|
return {self._convert_bytes(key): self._convert_bytes(value) for key, value in data.items()}
|
|
if isinstance(data, list):
|
|
return [self._convert_bytes(item) for item in data]
|
|
if isinstance(data, tuple):
|
|
return tuple(self._convert_bytes(item) for item in data)
|
|
return data
|
|
|
|
def get(self, vector_id):
|
|
"""
|
|
Get a vector by ID.
|
|
|
|
Args:
|
|
vector_id (str): ID of the vector to get.
|
|
|
|
Returns:
|
|
OutputData: The retrieved vector.
|
|
"""
|
|
try:
|
|
key = f"{self.prefix}:{vector_id}"
|
|
result = self.client.hgetall(key)
|
|
|
|
if not result:
|
|
raise KeyError(f"Vector with ID {vector_id} not found")
|
|
|
|
# Convert bytes keys/values to strings
|
|
result = self._convert_bytes(result)
|
|
|
|
logger.debug(f"Retrieved result keys: {result.keys()}")
|
|
|
|
# Process the document fields
|
|
payload, memory_id = self._process_document_fields(result, vector_id)
|
|
|
|
return OutputData(id=memory_id, payload=payload, score=0.0)
|
|
except KeyError:
|
|
raise
|
|
except Exception as e:
|
|
logger.exception(f"Error getting vector with ID {vector_id}: {e}")
|
|
raise
|
|
|
|
def list_cols(self):
|
|
"""
|
|
List all collections (indices) in Valkey.
|
|
|
|
Returns:
|
|
list: List of collection names.
|
|
"""
|
|
try:
|
|
# Use the FT._LIST command to list all indices
|
|
return self.client.execute_command("FT._LIST")
|
|
except Exception as e:
|
|
logger.exception(f"Error listing collections: {e}")
|
|
raise
|
|
|
|
def _drop_index(self, collection_name, log_level="error"):
|
|
"""
|
|
Drop an index by name using the documented FT.DROPINDEX command.
|
|
|
|
Args:
|
|
collection_name (str): Name of the index to drop.
|
|
log_level (str): Logging level for missing index ("silent", "info", "error").
|
|
"""
|
|
try:
|
|
self.client.execute_command("FT.DROPINDEX", collection_name)
|
|
logger.info(f"Successfully deleted index {collection_name}")
|
|
return True
|
|
except ResponseError as e:
|
|
if "Unknown index name" in str(e):
|
|
# Index doesn't exist - handle based on context
|
|
if log_level == "silent":
|
|
pass # No logging in situations where this is expected such as initial index creation
|
|
elif log_level == "info":
|
|
logger.info(f"Index {collection_name} doesn't exist, skipping deletion")
|
|
return False
|
|
else:
|
|
# Real error - always log and raise
|
|
logger.error(f"Error deleting index {collection_name}: {e}")
|
|
raise
|
|
except Exception as e:
|
|
# Non-ResponseError exceptions - always log and raise
|
|
logger.error(f"Error deleting index {collection_name}: {e}")
|
|
raise
|
|
|
|
def delete_col(self):
|
|
"""
|
|
Delete the current collection (index).
|
|
"""
|
|
return self._drop_index(self.collection_name, log_level="info")
|
|
|
|
def col_info(self, name=None):
|
|
"""
|
|
Get information about a collection (index).
|
|
|
|
Args:
|
|
name (str, optional): Name of the collection. Defaults to None, which uses the current collection_name.
|
|
|
|
Returns:
|
|
dict: Information about the collection.
|
|
"""
|
|
try:
|
|
collection_name = name or self.collection_name
|
|
return self.client.ft(collection_name).info()
|
|
except Exception as e:
|
|
logger.exception(f"Error getting collection info for {collection_name}: {e}")
|
|
raise
|
|
|
|
def reset(self):
|
|
"""
|
|
Reset the index by deleting and recreating it.
|
|
"""
|
|
try:
|
|
collection_name = self.collection_name
|
|
logger.warning(f"Resetting index {collection_name}...")
|
|
|
|
# Delete the index
|
|
self.delete_col()
|
|
|
|
# Recreate the index
|
|
self._create_index(self.embedding_model_dims)
|
|
|
|
return True
|
|
except Exception as e:
|
|
logger.exception(f"Error resetting index {self.collection_name}: {e}")
|
|
raise
|
|
|
|
def _build_list_query(self, filters=None):
|
|
"""
|
|
Build a query for listing vectors.
|
|
|
|
Args:
|
|
filters (dict, optional): Filters to apply to the list. Each key-value pair
|
|
becomes a tag filter (@key:{value}). None values are ignored.
|
|
Values are used as-is (no validation) - wildcards, lists, etc. are
|
|
passed through literally to Valkey search.
|
|
|
|
Returns:
|
|
str: The query string. Returns "*" if no valid filters provided.
|
|
"""
|
|
# Default query
|
|
q = "*"
|
|
|
|
# Add filters if provided
|
|
if filters and any(value is not None for key, value in filters.items()):
|
|
filter_conditions = []
|
|
for key, value in filters.items():
|
|
if value is not None:
|
|
filter_conditions.append(f"@{key}:{{{value}}}")
|
|
|
|
if filter_conditions:
|
|
q = " ".join(filter_conditions)
|
|
|
|
return q
|
|
|
|
def list(self, filters: dict = None, limit: int = None) -> list:
|
|
"""
|
|
List all recent created memories from the vector store.
|
|
|
|
Args:
|
|
filters (dict, optional): Filters to apply to the list. Each key-value pair
|
|
becomes a tag filter (@key:{value}). None values are ignored.
|
|
Values are used as-is without validation - wildcards, special characters,
|
|
lists, etc. are passed through literally to Valkey search.
|
|
Multiple filters are combined with AND logic.
|
|
limit (int, optional): Maximum number of results to return. Defaults to 1000
|
|
if not specified.
|
|
|
|
Returns:
|
|
list: Nested list format [[MemoryResult(), ...]] matching Redis implementation.
|
|
Each MemoryResult contains id and payload with hash, data, timestamps, etc.
|
|
"""
|
|
try:
|
|
# Since Valkey search requires vector format, use a dummy vector search
|
|
# that returns all documents by using a zero vector and large K
|
|
dummy_vector = [0.0] * self.embedding_model_dims
|
|
search_limit = limit if limit is not None else 1000 # Large default
|
|
|
|
# Use the existing search method which handles filters properly
|
|
search_results = self.search("", dummy_vector, limit=search_limit, filters=filters)
|
|
|
|
# Convert search results to list format (match Redis format)
|
|
class MemoryResult:
|
|
def __init__(self, id: str, payload: dict, score: float = None):
|
|
self.id = id
|
|
self.payload = payload
|
|
self.score = score
|
|
|
|
memory_results = []
|
|
for result in search_results:
|
|
# Create payload in the expected format
|
|
payload = {
|
|
"hash": result.payload.get("hash", ""),
|
|
"data": result.payload.get("data", ""),
|
|
"created_at": result.payload.get("created_at"),
|
|
"updated_at": result.payload.get("updated_at"),
|
|
}
|
|
|
|
# Add metadata (exclude system fields)
|
|
for key, value in result.payload.items():
|
|
if key not in ["data", "hash", "created_at", "updated_at"]:
|
|
payload[key] = value
|
|
|
|
# Create MemoryResult object (matching Redis format)
|
|
memory_results.append(MemoryResult(id=result.id, payload=payload))
|
|
|
|
# Return nested list format like Redis
|
|
return [memory_results]
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Error in list method: {e}")
|
|
return [[]] # Return empty result on error
|