# Source: mem0/vector_stores/valkey.py (825 lines, 30 KiB, Python; snapshot 2026-03-06 21:11:10 +08:00)
import json
import logging
from datetime import datetime
from typing import Dict
import numpy as np
import pytz
import valkey
from pydantic import BaseModel
from valkey.exceptions import ResponseError
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
# Module-level logger shared by this vector-store backend.
logger = logging.getLogger(__name__)
# Default fields for the Valkey index.
# NOTE(review): this list mirrors the schema assembled in
# ValkeyDB._build_index_schema — keep the two in sync.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "tag"},  # Using TAG instead of TEXT for Valkey compatibility
    {"name": "metadata", "type": "tag"},  # Using TAG instead of TEXT for Valkey compatibility
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]
# Payload keys that are stored as dedicated hash fields (or handled specially)
# and are therefore excluded from the serialized "metadata" JSON blob.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
class OutputData(BaseModel):
    """Result record returned by search/get operations."""

    # Identifier of the memory entry (the stored "memory_id" hash field).
    id: str
    # Vector similarity score; get() uses 0.0 since no KNN search is involved.
    score: float
    # Decoded payload: hash, data, timestamps, optional ids and metadata keys.
    payload: Dict
class ValkeyDB(VectorStoreBase):
def __init__(
    self,
    valkey_url: str,
    collection_name: str,
    embedding_model_dims: int,
    timezone: str = "UTC",
    index_type: str = "hnsw",
    hnsw_m: int = 16,
    hnsw_ef_construction: int = 200,
    hnsw_ef_runtime: int = 10,
):
    """
    Initialize the Valkey vector store.

    Args:
        valkey_url (str): Connection URL for the Valkey server.
        collection_name (str): Name of the index / key-prefix namespace.
        embedding_model_dims (int): Dimensionality of stored embeddings.
        timezone (str, optional): Timezone used for timestamps. Defaults to "UTC".
        index_type (str, optional): Index type, 'hnsw' or 'flat'. Defaults to "hnsw".
        hnsw_m (int, optional): HNSW M parameter (connections per node). Defaults to 16.
        hnsw_ef_construction (int, optional): HNSW build-time ef. Defaults to 200.
        hnsw_ef_runtime (int, optional): HNSW query-time ef. Defaults to 10.

    Raises:
        ValueError: If index_type is not 'hnsw' or 'flat'.
    """
    self.collection_name = collection_name
    self.prefix = f"mem0:{collection_name}"
    self.embedding_model_dims = embedding_model_dims
    self.timezone = timezone
    # HNSW tuning knobs (unused when index_type is 'flat').
    self.hnsw_m = hnsw_m
    self.hnsw_ef_construction = hnsw_ef_construction
    self.hnsw_ef_runtime = hnsw_ef_runtime
    self.index_type = index_type.lower()
    if self.index_type not in ("hnsw", "flat"):
        raise ValueError(f"Invalid index_type: {index_type}. Must be 'hnsw' or 'flat'")
    # Establish the server connection before touching the index.
    try:
        self.client = valkey.from_url(valkey_url)
        logger.debug(f"Successfully connected to Valkey at {valkey_url}")
    except Exception as e:
        logger.exception(f"Failed to connect to Valkey at {valkey_url}: {e}")
        raise
    # Make sure the backing search index exists before first use.
    self._create_index(embedding_model_dims)
def _build_index_schema(self, collection_name, embedding_dims, distance_metric, prefix):
"""
Build the FT.CREATE command for index creation.
Args:
collection_name (str): Name of the collection/index
embedding_dims (int): Vector embedding dimensions
distance_metric (str): Distance metric (e.g., "COSINE", "L2", "IP")
prefix (str): Key prefix for the index
Returns:
list: Complete FT.CREATE command as list of arguments
"""
# Build the vector field configuration based on index type
if self.index_type == "hnsw":
vector_config = [
"embedding",
"VECTOR",
"HNSW",
"12", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric, M, m, EF_CONSTRUCTION, ef_construction, EF_RUNTIME, ef_runtime
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
"M",
str(self.hnsw_m),
"EF_CONSTRUCTION",
str(self.hnsw_ef_construction),
"EF_RUNTIME",
str(self.hnsw_ef_runtime),
]
elif self.index_type == "flat":
vector_config = [
"embedding",
"VECTOR",
"FLAT",
"6", # Attribute count: TYPE, FLOAT32, DIM, dims, DISTANCE_METRIC, metric
"TYPE",
"FLOAT32",
"DIM",
str(embedding_dims),
"DISTANCE_METRIC",
distance_metric,
]
else:
# This should never happen due to constructor validation, but be defensive
raise ValueError(f"Unsupported index_type: {self.index_type}. Must be 'hnsw' or 'flat'")
# Build the complete command (comma is default separator for TAG fields)
cmd = [
"FT.CREATE",
collection_name,
"ON",
"HASH",
"PREFIX",
"1",
prefix,
"SCHEMA",
"memory_id",
"TAG",
"hash",
"TAG",
"agent_id",
"TAG",
"run_id",
"TAG",
"user_id",
"TAG",
"memory",
"TAG",
"metadata",
"TAG",
"created_at",
"NUMERIC",
"updated_at",
"NUMERIC",
] + vector_config
return cmd
def _create_index(self, embedding_model_dims):
    """
    Create the search index with the specified schema, if it does not exist.

    Args:
        embedding_model_dims (int): Dimensions for the vector embeddings.

    Raises:
        ValueError: If the search module is not available on the server.
        Exception: For other errors during index creation.
    """
    # Check if the search module is available
    try:
        # Try to execute a search command; an "unknown command" reply means
        # the server is running without the search module loaded.
        self.client.execute_command("FT._LIST")
    except ResponseError as e:
        if "unknown command" in str(e).lower():
            raise ValueError(
                "Valkey search module is not available. Please ensure Valkey is running with the search module enabled. "
                "The search module can be loaded using the --loadmodule option with the valkey-search library. "
                "For installation and setup instructions, refer to the Valkey Search documentation."
            )
        else:
            logger.exception(f"Error checking search module: {e}")
            raise
    # Check if the index already exists; FT.INFO raises a "not found"
    # ResponseError when it does not.
    try:
        self.client.ft(self.collection_name).info()
        # Index already exists — nothing to create.
        return
    except ResponseError as e:
        if "not found" not in str(e).lower():
            logger.exception(f"Error checking index existence: {e}")
            raise
    # Build and execute the index creation command
    cmd = self._build_index_schema(
        self.collection_name,
        embedding_model_dims,
        "COSINE",  # Fixed distance metric for initialization
        self.prefix,
    )
    try:
        self.client.execute_command(*cmd)
        logger.info(f"Successfully created {self.index_type.upper()} index {self.collection_name}")
    except Exception as e:
        logger.exception(f"Error creating index {self.collection_name}: {e}")
        raise
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index) in Valkey.
Args:
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
Returns:
The created index object.
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.collection_name
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "COSINE"
prefix = f"mem0:{collection_name}"
# Try to drop the index if it exists (cleanup before creation)
self._drop_index(collection_name, log_level="silent")
# Build and execute the index creation command
cmd = self._build_index_schema(
collection_name,
embedding_dims,
distance_metric, # Configurable distance metric
prefix,
)
try:
self.client.execute_command(*cmd)
logger.info(f"Successfully created {self.index_type.upper()} index {collection_name}")
# Update instance attributes if creating a new collection
if name:
self.collection_name = collection_name
self.prefix = prefix
return self.client.ft(collection_name)
except Exception as e:
logger.exception(f"Error creating collection {collection_name}: {e}")
raise
def insert(self, vectors: list, payloads: list = None, ids: list = None):
"""
Insert vectors and their payloads into the index.
Args:
vectors (list): List of vectors to insert.
payloads (list, optional): List of payloads corresponding to the vectors.
ids (list, optional): List of IDs for the vectors.
"""
for vector, payload, id in zip(vectors, payloads, ids):
try:
# Create the key for the hash
key = f"{self.prefix}:{id}"
# Check for required fields and provide defaults if missing
if "data" not in payload:
# Silently use default value for missing 'data' field
pass
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": id,
"hash": payload.get("hash", f"hash_{id}"), # Use a default hash if not provided
"memory": payload.get("data", f"data_{id}"), # Use a default data if not provided
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
# Store in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully inserted vector with ID {id}")
except KeyError as e:
logger.error(f"Error inserting vector with ID {id}: Missing required field {e}")
except Exception as e:
logger.exception(f"Error inserting vector with ID {id}: {e}")
raise
def _build_search_query(self, knn_part, filters=None):
"""
Build a search query string with filters.
Args:
knn_part (str): The KNN part of the query.
filters (dict, optional): Filters to apply to the search. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search. Multiple filters are
combined with AND logic (space-separated).
Returns:
str: The complete search query string in format "filter_expr =>[KNN...]"
or "*=>[KNN...]" if no valid filters.
"""
# No filters, just use the KNN search
if not filters or not any(value is not None for key, value in filters.items()):
return f"*=>{knn_part}"
# Build filter expression
filter_parts = []
for key, value in filters.items():
if value is not None:
# Use the correct filter syntax for Valkey
filter_parts.append(f"@{key}:{{{value}}}")
# No valid filter parts
if not filter_parts:
return f"*=>{knn_part}"
# Combine filter parts with proper syntax
filter_expr = " ".join(filter_parts)
return f"{filter_expr} =>{knn_part}"
def _execute_search(self, query, params):
"""
Execute a search query.
Args:
query (str): The search query to execute.
params (dict): The query parameters.
Returns:
The search results.
"""
try:
return self.client.ft(self.collection_name).search(query, query_params=params)
except ResponseError as e:
logger.error(f"Search failed with query '{query}': {e}")
raise
def _process_search_results(self, results):
    """
    Convert raw Valkey search docs into OutputData objects.

    Args:
        results: The raw search results from Valkey (with a .docs attribute).

    Returns:
        list: List of OutputData objects.
    """
    output = []
    for doc in results.docs:
        # KNN queries alias the distance as "vector_score".
        score = float(doc.vector_score) if hasattr(doc, "vector_score") else None
        payload = {
            "hash": doc.hash,
            "data": doc.memory,
            "created_at": self._format_timestamp(int(doc.created_at), self.timezone),
        }
        if hasattr(doc, "updated_at"):
            payload["updated_at"] = self._format_timestamp(int(doc.updated_at), self.timezone)
        for optional in ("agent_id", "run_id", "user_id"):
            if hasattr(doc, optional):
                payload[optional] = getattr(doc, optional)
        # Metadata is stored as a JSON blob; merge its keys into the payload.
        if hasattr(doc, "metadata"):
            try:
                payload.update(json.loads(extract_json(doc.metadata)))
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(f"Failed to parse metadata: {e}")
        output.append(OutputData(id=doc.memory_id, score=score, payload=payload))
    return output
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None, ef_runtime: int = None):
"""
Search for similar vectors in the index.
Args:
query (str): The search query.
vectors (list): The vector to search for.
limit (int, optional): Maximum number of results to return. Defaults to 5.
filters (dict, optional): Filters to apply to the search. Defaults to None.
ef_runtime (int, optional): HNSW ef_runtime parameter for this query. Only used with HNSW index. Defaults to None.
Returns:
list: List of OutputData objects.
"""
# Convert the vector to bytes
vector_bytes = np.array(vectors, dtype=np.float32).tobytes()
# Build the KNN part with optional EF_RUNTIME for HNSW
if self.index_type == "hnsw" and ef_runtime is not None:
knn_part = f"[KNN {limit} @embedding $vec_param EF_RUNTIME {ef_runtime} AS vector_score]"
else:
# For FLAT indexes or when ef_runtime is None, use basic KNN
knn_part = f"[KNN {limit} @embedding $vec_param AS vector_score]"
# Build the complete query
q = self._build_search_query(knn_part, filters)
# Log the query for debugging (only in debug mode)
logger.debug(f"Valkey search query: {q}")
# Set up the query parameters
params = {"vec_param": vector_bytes}
# Execute the search
results = self._execute_search(q, params)
# Process the results
return self._process_search_results(results)
def delete(self, vector_id):
"""
Delete a vector from the index.
Args:
vector_id (str): ID of the vector to delete.
"""
try:
key = f"{self.prefix}:{vector_id}"
self.client.delete(key)
logger.debug(f"Successfully deleted vector with ID {vector_id}")
except Exception as e:
logger.exception(f"Error deleting vector with ID {vector_id}: {e}")
raise
def update(self, vector_id=None, vector=None, payload=None):
"""
Update a vector in the index.
Args:
vector_id (str): ID of the vector to update.
vector (list, optional): New vector data.
payload (dict, optional): New payload data.
"""
try:
key = f"{self.prefix}:{vector_id}"
# Check for required fields and provide defaults if missing
if "data" not in payload:
# Silently use default value for missing 'data' field
pass
# Ensure created_at is present
if "created_at" not in payload:
payload["created_at"] = datetime.now(pytz.timezone(self.timezone)).isoformat()
# Prepare the hash data
hash_data = {
"memory_id": vector_id,
"hash": payload.get("hash", f"hash_{vector_id}"), # Use a default hash if not provided
"memory": payload.get("data", f"data_{vector_id}"), # Use a default data if not provided
"created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
"embedding": np.array(vector, dtype=np.float32).tobytes(),
}
# Add updated_at if available
if "updated_at" in payload:
hash_data["updated_at"] = int(datetime.fromisoformat(payload["updated_at"]).timestamp())
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in payload:
hash_data[field] = payload[field]
# Add metadata
hash_data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
# Update in Valkey
self.client.hset(key, mapping=hash_data)
logger.debug(f"Successfully updated vector with ID {vector_id}")
except KeyError as e:
logger.error(f"Error updating vector with ID {vector_id}: Missing required field {e}")
except Exception as e:
logger.exception(f"Error updating vector with ID {vector_id}: {e}")
raise
def _format_timestamp(self, timestamp, timezone=None):
    """
    Render a Unix timestamp as an ISO-8601 string in the given timezone.

    Args:
        timestamp (int): Seconds since the epoch.
        timezone (str, optional): Timezone name. Defaults to UTC when omitted.

    Returns:
        str: ISO-8601 timestamp with microsecond precision.
    """
    zone = pytz.timezone(timezone if timezone else "UTC")
    moment = datetime.fromtimestamp(timestamp, tz=zone)
    return moment.isoformat(timespec="microseconds")
def _process_document_fields(self, result, vector_id):
"""
Process document fields from a Valkey hash result.
Args:
result (dict): The hash result from Valkey.
vector_id (str): The vector ID.
Returns:
dict: The processed payload.
str: The memory ID.
"""
# Create the payload with error handling
payload = {}
# Convert bytes to string for text fields
for k in result:
if k not in ["embedding"]:
if isinstance(result[k], bytes):
try:
result[k] = result[k].decode("utf-8")
except UnicodeDecodeError:
# If decoding fails, keep the bytes
pass
# Add required fields with error handling
for field in ["hash", "memory", "created_at"]:
if field in result:
if field == "created_at":
try:
payload[field] = self._format_timestamp(int(result[field]), self.timezone)
except (ValueError, TypeError):
payload[field] = result[field]
else:
payload[field] = result[field]
else:
# Use default values for missing fields
if field == "hash":
payload[field] = "unknown"
elif field == "memory":
payload[field] = "unknown"
elif field == "created_at":
payload[field] = self._format_timestamp(
int(datetime.now(tz=pytz.timezone(self.timezone)).timestamp()), self.timezone
)
# Rename memory to data for consistency
if "memory" in payload:
payload["data"] = payload.pop("memory")
# Add updated_at if available
if "updated_at" in result:
try:
payload["updated_at"] = self._format_timestamp(int(result["updated_at"]), self.timezone)
except (ValueError, TypeError):
payload["updated_at"] = result["updated_at"]
# Add optional fields
for field in ["agent_id", "run_id", "user_id"]:
if field in result:
payload[field] = result[field]
# Add metadata
if "metadata" in result:
try:
metadata = json.loads(extract_json(result["metadata"]))
payload.update(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse metadata: {result.get('metadata')}")
# Use memory_id from result if available, otherwise use vector_id
memory_id = result.get("memory_id", vector_id)
return payload, memory_id
def _convert_bytes(self, data):
"""Convert bytes data back to string"""
if isinstance(data, bytes):
try:
return data.decode("utf-8")
except UnicodeDecodeError:
return data
if isinstance(data, dict):
return {self._convert_bytes(key): self._convert_bytes(value) for key, value in data.items()}
if isinstance(data, list):
return [self._convert_bytes(item) for item in data]
if isinstance(data, tuple):
return tuple(self._convert_bytes(item) for item in data)
return data
def get(self, vector_id):
"""
Get a vector by ID.
Args:
vector_id (str): ID of the vector to get.
Returns:
OutputData: The retrieved vector.
"""
try:
key = f"{self.prefix}:{vector_id}"
result = self.client.hgetall(key)
if not result:
raise KeyError(f"Vector with ID {vector_id} not found")
# Convert bytes keys/values to strings
result = self._convert_bytes(result)
logger.debug(f"Retrieved result keys: {result.keys()}")
# Process the document fields
payload, memory_id = self._process_document_fields(result, vector_id)
return OutputData(id=memory_id, payload=payload, score=0.0)
except KeyError:
raise
except Exception as e:
logger.exception(f"Error getting vector with ID {vector_id}: {e}")
raise
def list_cols(self):
"""
List all collections (indices) in Valkey.
Returns:
list: List of collection names.
"""
try:
# Use the FT._LIST command to list all indices
return self.client.execute_command("FT._LIST")
except Exception as e:
logger.exception(f"Error listing collections: {e}")
raise
def _drop_index(self, collection_name, log_level="error"):
"""
Drop an index by name using the documented FT.DROPINDEX command.
Args:
collection_name (str): Name of the index to drop.
log_level (str): Logging level for missing index ("silent", "info", "error").
"""
try:
self.client.execute_command("FT.DROPINDEX", collection_name)
logger.info(f"Successfully deleted index {collection_name}")
return True
except ResponseError as e:
if "Unknown index name" in str(e):
# Index doesn't exist - handle based on context
if log_level == "silent":
pass # No logging in situations where this is expected such as initial index creation
elif log_level == "info":
logger.info(f"Index {collection_name} doesn't exist, skipping deletion")
return False
else:
# Real error - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
except Exception as e:
# Non-ResponseError exceptions - always log and raise
logger.error(f"Error deleting index {collection_name}: {e}")
raise
def delete_col(self):
"""
Delete the current collection (index).
"""
return self._drop_index(self.collection_name, log_level="info")
def col_info(self, name=None):
"""
Get information about a collection (index).
Args:
name (str, optional): Name of the collection. Defaults to None, which uses the current collection_name.
Returns:
dict: Information about the collection.
"""
try:
collection_name = name or self.collection_name
return self.client.ft(collection_name).info()
except Exception as e:
logger.exception(f"Error getting collection info for {collection_name}: {e}")
raise
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
try:
collection_name = self.collection_name
logger.warning(f"Resetting index {collection_name}...")
# Delete the index
self.delete_col()
# Recreate the index
self._create_index(self.embedding_model_dims)
return True
except Exception as e:
logger.exception(f"Error resetting index {self.collection_name}: {e}")
raise
def _build_list_query(self, filters=None):
"""
Build a query for listing vectors.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is (no validation) - wildcards, lists, etc. are
passed through literally to Valkey search.
Returns:
str: The query string. Returns "*" if no valid filters provided.
"""
# Default query
q = "*"
# Add filters if provided
if filters and any(value is not None for key, value in filters.items()):
filter_conditions = []
for key, value in filters.items():
if value is not None:
filter_conditions.append(f"@{key}:{{{value}}}")
if filter_conditions:
q = " ".join(filter_conditions)
return q
def list(self, filters: dict = None, limit: int = None) -> list:
"""
List all recent created memories from the vector store.
Args:
filters (dict, optional): Filters to apply to the list. Each key-value pair
becomes a tag filter (@key:{value}). None values are ignored.
Values are used as-is without validation - wildcards, special characters,
lists, etc. are passed through literally to Valkey search.
Multiple filters are combined with AND logic.
limit (int, optional): Maximum number of results to return. Defaults to 1000
if not specified.
Returns:
list: Nested list format [[MemoryResult(), ...]] matching Redis implementation.
Each MemoryResult contains id and payload with hash, data, timestamps, etc.
"""
try:
# Since Valkey search requires vector format, use a dummy vector search
# that returns all documents by using a zero vector and large K
dummy_vector = [0.0] * self.embedding_model_dims
search_limit = limit if limit is not None else 1000 # Large default
# Use the existing search method which handles filters properly
search_results = self.search("", dummy_vector, limit=search_limit, filters=filters)
# Convert search results to list format (match Redis format)
class MemoryResult:
def __init__(self, id: str, payload: dict, score: float = None):
self.id = id
self.payload = payload
self.score = score
memory_results = []
for result in search_results:
# Create payload in the expected format
payload = {
"hash": result.payload.get("hash", ""),
"data": result.payload.get("data", ""),
"created_at": result.payload.get("created_at"),
"updated_at": result.payload.get("updated_at"),
}
# Add metadata (exclude system fields)
for key, value in result.payload.items():
if key not in ["data", "hash", "created_at", "updated_at"]:
payload[key] = value
# Create MemoryResult object (matching Redis format)
memory_results.append(MemoryResult(id=result.id, payload=payload))
# Return nested list format like Redis
return [memory_results]
except Exception as e:
logger.exception(f"Error in list method: {e}")
return [[]] # Return empty result on error