Files
mem0/memory/kuzu_memory.py
2026-03-06 21:11:10 +08:00

715 lines
30 KiB
Python

import logging
from mem0.memory.utils import format_entities
try:
import kuzu
except ImportError:
raise ImportError("kuzu is not installed. Please install it using pip install kuzu")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
# Module-level logger used by MemoryGraph for its info/debug/warning/exception messages.
logger = logging.getLogger(__name__)
class MemoryGraph:
    """Graph memory backed by an embedded Kuzu database.

    Entities extracted by the configured LLM are stored as ``Entity`` nodes
    (each carrying an embedding vector) and linked by ``CONNECTED_TO``
    relationships. Every node is scoped by ``user_id`` and, when present in
    the ``filters`` dict passed to each method, ``agent_id`` and ``run_id``.
    """

    # Providers that require the "struct" variants of the tool definitions.
    _STRUCTURED_LLM_PROVIDERS = ("azure_openai_structured", "openai_structured")

    def __init__(self, config):
        """Create the embedder, open the Kuzu database, ensure the schema and build the LLM.

        Args:
            config: mem0 configuration object; its ``embedder``, ``vector_store``,
                ``graph_store`` and ``llm`` sections are read.

        Raises:
            ValueError: If the embedder reports a missing or non-positive ``embedding_dims``.
        """
        self.config = config
        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider,
            self.config.embedder.config,
            self.config.vector_store.config,
        )
        self.embedding_dims = self.embedding_model.config.embedding_dims
        # Queries cast embeddings to fixed-size FLOAT[dims] arrays, so the size must be known.
        if self.embedding_dims is None or self.embedding_dims <= 0:
            raise ValueError(f"embedding_dims must be a positive integer. Given: {self.embedding_dims}")

        self.db = kuzu.Database(self.config.graph_store.config.db)
        self.graph = kuzu.Connection(self.db)

        self.node_label = ":Entity"
        self.rel_label = ":CONNECTED_TO"
        self.kuzu_create_schema()

        # Default to openai if no specific provider is configured;
        # a graph_store-specific LLM overrides the global one.
        self.llm_provider = "openai"
        if self.config.llm and self.config.llm.provider:
            self.llm_provider = self.config.llm.provider
        if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
            self.llm_provider = self.config.graph_store.llm.provider

        # Get LLM config with proper null checks (graph_store.llm takes precedence).
        llm_config = None
        if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
            llm_config = self.config.graph_store.llm.config
        elif hasattr(self.config.llm, "config"):
            llm_config = self.config.llm.config
        self.llm = LlmFactory.create(self.llm_provider, llm_config)

        self.user_id = None
        # Use threshold from graph_store config, default to 0.7 for backward compatibility
        self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, "threshold") else 0.7

    def kuzu_create_schema(self):
        """Create the ``Entity`` node table and ``CONNECTED_TO`` rel table if they do not exist."""
        self.kuzu_execute(
            """
            CREATE NODE TABLE IF NOT EXISTS Entity(
                id SERIAL PRIMARY KEY,
                user_id STRING,
                agent_id STRING,
                run_id STRING,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                embedding FLOAT[]);
            """
        )
        self.kuzu_execute(
            """
            CREATE REL TABLE IF NOT EXISTS CONNECTED_TO(
                FROM Entity TO Entity,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                updated TIMESTAMP
            );
            """
        )

    def kuzu_execute(self, query, parameters=None):
        """Execute ``query`` with optional ``parameters`` and return every row as a dict."""
        results = self.graph.execute(query, parameters)
        return list(results.rows_as_dict())

    @staticmethod
    def _scope_props(filters):
        """Return ``(props, params)`` scoping a node pattern to the user/agent/run in ``filters``.

        ``props`` are Cypher map fragments such as ``"user_id: $user_id"`` and
        ``params`` holds the matching query parameters. ``agent_id`` and
        ``run_id`` are included only when present and truthy in ``filters``.
        """
        props = ["user_id: $user_id"]
        params = {"user_id": filters["user_id"]}
        for key in ("agent_id", "run_id"):
            if filters.get(key):
                props.append(f"{key}: ${key}")
                params[key] = filters[key]
        return props, params

    @staticmethod
    def _user_identity(filters):
        """Describe the scope in ``filters`` (e.g. ``"user_id: u, agent_id: a"``) for LLM prompts."""
        identity = f"user_id: {filters['user_id']}"
        for key in ("agent_id", "run_id"):
            if filters.get(key):
                identity += f", {key}: {filters[key]}"
        return identity

    @staticmethod
    def _tokenize_query(query):
        """Split ``query`` into lower-case, whitespace-separated tokens for BM25.

        Stored entity and relationship names are lower-cased when extracted,
        so the query is lower-cased as well; otherwise a capitalised word such
        as "Alice" could never match the stored "alice".
        """
        return query.lower().split()

    def add(self, data, filters):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.
            filters (dict): Scope filters; ``user_id`` is required, ``agent_id``/``run_id`` optional.

        Returns:
            dict: ``{"deleted_entities": [...], "added_entities": [...]}`` with the
            per-relation query results of the deletions and additions.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        deleted_entities = self._delete_entities(to_be_deleted, filters)
        added_entities = self._add_entities(to_be_added, filters, entity_type_map)
        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

    def search(self, query, filters, limit=5):
        """
        Search the graph for relations around the entities mentioned in ``query``.

        Args:
            query (str): Query to search for.
            filters (dict): Scope filters; ``user_id`` is required, ``agent_id``/``run_id`` optional.
            limit (int): Maximum number of relations returned after BM25 re-ranking. Defaults to 5.

        Returns:
            list: Dicts with ``source``, ``relationship`` and ``destination`` keys,
            best BM25 match first; empty when no related relations are found.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)
        reranked_results = bm25.get_top_n(self._tokenize_query(query), search_outputs_sequence, n=limit)

        search_results = [
            {"source": source, "relationship": relationship, "destination": destination}
            for source, relationship, destination in reranked_results
        ]
        logger.info(f"Returned {len(search_results)} search results")
        return search_results

    def delete_all(self, filters):
        """Delete every node (and its relationships) within the scope given by ``filters``."""
        node_props, params = self._scope_props(filters)
        node_props_str = ", ".join(node_props)
        cypher = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})
        DETACH DELETE n
        """
        self.kuzu_execute(cypher, parameters=params)

    def get_all(self, filters, limit=100):
        """
        Retrieve relationships between nodes in the scope given by ``filters``.

        Args:
            filters (dict): Scope filters; ``user_id`` is required, ``agent_id``/``run_id`` optional.
            limit (int): The maximum number of relationships to retrieve. Defaults to 100.

        Returns:
            list: Dicts with ``source``, ``relationship`` and ``target`` keys.
        """
        node_props, params = self._scope_props(filters)
        params["limit"] = limit
        node_props_str = ", ".join(node_props)

        query = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})-[r]->(m {self.node_label} {{{node_props_str}}})
        RETURN
            n.name AS source,
            r.name AS relationship,
            m.name AS target
        LIMIT $limit
        """
        rows = self.kuzu_execute(query, parameters=params)

        final_results = [
            {"source": row["source"], "relationship": row["relationship"], "target": row["target"]}
            for row in rows
        ]
        logger.info(f"Retrieved {len(final_results)} relationships")
        return final_results

    def _retrieve_nodes_from_data(self, data, filters):
        """Extract the entities mentioned in ``data`` via an LLM tool call.

        Returns:
            dict: Entity name -> entity type, both lower-cased with spaces replaced
            by underscores. Empty (after logging) if the LLM response cannot be parsed.
        """
        _tools = [EXTRACT_ENTITIES_TOOL]
        if self.llm_provider in self._STRUCTURED_LLM_PROVIDERS:
            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        entity_type_map = {}
        # Best effort: a malformed LLM response is logged and yields no entities.
        try:
            for tool_call in search_results["tool_calls"]:
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
        return entity_type_map

    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """Ask the LLM for relations among the extracted entities.

        Returns:
            list: Dicts with normalized ``source``, ``relationship`` and ``destination`` keys
            taken from the first tool call of the LLM response (empty if none).
        """
        system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", self._user_identity(filters))
        custom_prompt = self.config.graph_store.custom_prompt
        if custom_prompt:
            # Add the custom prompt line if configured
            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {custom_prompt}")
            user_content = data
        else:
            # Remove the unused placeholder so the literal token does not leak into the prompt.
            system_content = system_content.replace("CUSTOM_PROMPT", "")
            user_content = f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"
        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

        _tools = [RELATIONS_TOOL]
        if self.llm_provider in self._STRUCTURED_LLM_PROVIDERS:
            _tools = [RELATIONS_STRUCT_TOOL]
        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        entities = []
        if extracted_entities.get("tool_calls"):
            entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])

        entities = self._remove_spaces_from_entities(entities)
        logger.debug(f"Extracted entities: {entities}")
        return entities

    def _search_graph_db(self, node_list, filters, limit=100, threshold=None):
        """Find nodes similar to each name in ``node_list`` and return their relations.

        For every name, scoped nodes whose embedding has cosine similarity
        >= ``threshold`` are matched, and both their outgoing and incoming
        relations are collected. Per name, the combined results are sorted by
        similarity (descending) and truncated to ``limit``.

        Args:
            node_list (list[str]): Entity names to embed and look up.
            filters (dict): Scope filters; ``user_id`` is required, ``agent_id``/``run_id`` optional.
            limit (int): Maximum relations kept per name. Defaults to 100.
            threshold (float | None): Similarity cut-off; ``None`` means ``self.threshold``.
                An explicit ``0.0`` is honoured.

        Returns:
            list: Row dicts with ``source``, ``source_id``, ``relationship``, ``relation_id``,
            ``destination``, ``destination_id`` and ``similarity``.
        """
        node_props, params = self._scope_props(filters)
        params["threshold"] = self.threshold if threshold is None else threshold
        params["limit"] = limit
        node_props_str = ", ".join(node_props)

        # The query text does not depend on the node, so build it once.
        match_fragments = [
            f"(n)-[r]->(m {self.node_label} {{{node_props_str}}}) WITH n as src, r, m as dst, similarity",
            f"(m {self.node_label} {{{node_props_str}}})-[r]->(n) WITH m as src, r, n as dst, similarity",
        ]
        queries = [
            f"""
            MATCH (n {self.node_label} {{{node_props_str}}})
            WHERE n.embedding IS NOT NULL
            WITH n, array_cosine_similarity(n.embedding, CAST($n_embedding,'FLOAT[{self.embedding_dims}]')) AS similarity
            WHERE similarity >= CAST($threshold, 'DOUBLE')
            MATCH {match_fragment}
            RETURN
                src.name AS source,
                id(src) AS source_id,
                r.name AS relationship,
                id(r) AS relation_id,
                dst.name AS destination,
                id(dst) AS destination_id,
                similarity
            LIMIT $limit
            """
            for match_fragment in match_fragments
        ]

        result_relations = []
        for node in node_list:
            params["n_embedding"] = self.embedding_model.embed(node)
            results = []
            for query in queries:
                results.extend(self.kuzu_execute(query, parameters=params))
            # Kuzu does not support sort/limit over unions. Do it manually for now.
            result_relations.extend(sorted(results, key=lambda x: x["similarity"], reverse=True)[:limit])
        return result_relations

    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """Ask the LLM which existing relations in ``search_output`` should be deleted.

        Returns:
            list: Normalized relation dicts taken from ``delete_graph_memory`` tool calls;
            calls without arguments are ignored.
        """
        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, self._user_identity(filters))

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in self._STRUCTURED_LLM_PROVIDERS:
            _tools = [DELETE_MEMORY_STRUCT_TOOL_GRAPH]

        memory_updates = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tools=_tools,
        )

        # Guard against a None tool_calls value and argument-less calls,
        # either of which would otherwise crash the normalization below.
        to_be_deleted = [
            item["arguments"]
            for item in (memory_updates.get("tool_calls") or [])
            if item.get("name") == "delete_graph_memory" and item.get("arguments")
        ]
        # Clean entities formatting
        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
        logger.debug(f"Deleted relationships: {to_be_deleted}")
        return to_be_deleted

    def _delete_entities(self, to_be_deleted, filters):
        """Delete the named relationship between each scoped source/destination pair.

        Returns:
            list: One list of result rows (``source``, ``relationship``, ``target``) per item.
        """
        scope_props, scope_params = self._scope_props(filters)
        source_props_str = ", ".join(["name: $source_name", *scope_props])
        dest_props_str = ", ".join(["name: $dest_name", *scope_props])
        cypher = f"""
        MATCH (n {self.node_label} {{{source_props_str}}})
        -[r {self.rel_label} {{name: $relationship_name}}]->
        (m {self.node_label} {{{dest_props_str}}})
        DELETE r
        RETURN
            n.name AS source,
            r.name AS relationship,
            m.name AS target
        """

        results = []
        for item in to_be_deleted:
            params = {
                **scope_params,
                "source_name": item["source"],
                "dest_name": item["destination"],
                "relationship_name": item["relationship"],
            }
            results.append(self.kuzu_execute(cypher, parameters=params))
        return results

    def _add_entities(self, to_be_added, filters, entity_type_map):
        """Add relations to the graph, reusing existing nodes that are similar enough.

        For each relation the source and destination names are embedded and
        matched against existing scoped nodes (cosine similarity >= ``self.threshold``).
        A matched node is reused and its ``mentions`` incremented; an unmatched
        node is MERGEd by name and scope. The relationship is MERGEd by name,
        incrementing ``mentions`` when it already exists.

        Args:
            to_be_added (list[dict]): Relations with ``source``, ``relationship`` and ``destination``.
            filters (dict): Scope filters; ``user_id`` is required, ``agent_id``/``run_id`` optional.
            entity_type_map (dict): Extracted entity types (not used by this method).

        Returns:
            list: One list of result rows (``source``, ``relationship``, ``target``) per relation.
        """
        scope_props, scope_params = self._scope_props(filters)
        results = []
        for item in to_be_added:
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            source_embedding = self.embedding_model.embed(source)
            dest_embedding = self.embedding_model.embed(destination)

            # search for the nodes with the closest embeddings
            source_matches = self._search_source_node(source_embedding, filters, threshold=self.threshold)
            destination_matches = self._search_destination_node(dest_embedding, filters, threshold=self.threshold)

            if source_matches and not destination_matches:
                # Reuse the matched source node; MERGE the destination by name and scope.
                merge_props_str = ", ".join(["name: $destination_name", *scope_props])
                params = {
                    **scope_params,
                    "table_id": source_matches[0]["id"]["table"],
                    "offset_id": source_matches[0]["id"]["offset"],
                    "destination_name": destination,
                    "destination_embedding": dest_embedding,
                    "relationship_name": relationship,
                }
                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($table_id, $offset_id)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MERGE (destination {self.node_label} {{{merge_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {self.rel_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif destination_matches and not source_matches:
                # Reuse the matched destination node; MERGE the source by name and scope.
                merge_props_str = ", ".join(["name: $source_name", *scope_props])
                params = {
                    **scope_params,
                    "table_id": destination_matches[0]["id"]["table"],
                    "offset_id": destination_matches[0]["id"]["offset"],
                    "source_name": source,
                    "source_embedding": source_embedding,
                    "relationship_name": relationship,
                }
                cypher = f"""
                MATCH (destination)
                WHERE id(destination) = internal_id($table_id, $offset_id)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                WITH destination
                MERGE (source {self.node_label} {{{merge_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {self.rel_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif source_matches and destination_matches:
                # Both endpoints already exist: only the relationship is MERGEd.
                params = {
                    "src_table": source_matches[0]["id"]["table"],
                    "src_offset": source_matches[0]["id"]["offset"],
                    "dst_table": destination_matches[0]["id"]["table"],
                    "dst_offset": destination_matches[0]["id"]["offset"],
                    "relationship_name": relationship,
                }
                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($src_table, $src_offset)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MATCH (destination)
                WHERE id(destination) = internal_id($dst_table, $dst_offset)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                MERGE (source)-[r {self.rel_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.updated = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            else:
                # Neither endpoint matched: MERGE both nodes by name and scope.
                source_props_str = ", ".join(["name: $source_name", *scope_props])
                dest_props_str = ", ".join(["name: $dest_name", *scope_props])
                params = {
                    **scope_params,
                    "source_name": source,
                    "dest_name": destination,
                    "relationship_name": relationship,
                    "source_embedding": source_embedding,
                    "dest_embedding": dest_embedding,
                }
                cypher = f"""
                MERGE (source {self.node_label} {{{source_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source
                MERGE (destination {self.node_label} {{{dest_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[rel {self.rel_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    rel.created = current_timestamp(),
                    rel.mentions = 1
                ON MATCH SET
                    rel.mentions = coalesce(rel.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    rel.name AS relationship,
                    destination.name AS target
                """

            results.append(self.kuzu_execute(cypher, parameters=params))
        return results

    def _remove_spaces_from_entities(self, entity_list):
        """Lower-case ``source``/``relationship``/``destination`` and replace spaces with ``_``.

        Mutates the dicts in ``entity_list`` in place and returns the same list.
        """
        for item in entity_list:
            for key in ("source", "relationship", "destination"):
                item[key] = item[key].lower().replace(" ", "_")
        return entity_list

    def _search_similar_node(self, role, embedding, filters, threshold):
        """Return up to two scoped nodes whose embedding similarity to ``embedding`` is >= ``threshold``.

        ``role`` ("source" or "destination") only names the Cypher variable
        (``{role}_candidate``), the embedding parameter (``{role}_embedding``)
        and the similarity column (``{role}_similarity``). Rows hold the node's
        internal ``id`` and its similarity, most similar first.
        """
        candidate = f"{role}_candidate"
        similarity = f"{role}_similarity"
        embedding_param = f"{role}_embedding"

        params = {
            embedding_param: embedding,
            "user_id": filters["user_id"],
            "threshold": threshold,
        }
        where_conditions = [f"{candidate}.embedding IS NOT NULL", f"{candidate}.user_id = $user_id"]
        for key in ("agent_id", "run_id"):
            if filters.get(key):
                where_conditions.append(f"{candidate}.{key} = ${key}")
                params[key] = filters[key]
        where_clause = " AND ".join(where_conditions)

        cypher = f"""
        MATCH ({candidate} {self.node_label})
        WHERE {where_clause}
        WITH {candidate},
            array_cosine_similarity({candidate}.embedding, CAST(${embedding_param},'FLOAT[{self.embedding_dims}]')) AS {similarity}
        WHERE {similarity} >= $threshold
        WITH {candidate}, {similarity}
        ORDER BY {similarity} DESC
        LIMIT 2
        RETURN id({candidate}) as id, {similarity}
        """
        return self.kuzu_execute(cypher, parameters=params)

    def _search_source_node(self, source_embedding, filters, threshold=0.9):
        """Find up to two scoped nodes similar to ``source_embedding`` (rows: ``id``, ``source_similarity``)."""
        return self._search_similar_node("source", source_embedding, filters, threshold)

    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
        """Find up to two scoped nodes similar to ``destination_embedding`` (rows: ``id``, ``destination_similarity``)."""
        return self._search_similar_node("destination", destination_embedding, filters, threshold)

    # Reset is not defined in base.py
    def reset(self):
        """Reset the graph by clearing all nodes and relationships."""
        logger.warning("Clearing graph...")
        cypher_query = """
        MATCH (n) DETACH DELETE n
        """
        return self.kuzu_execute(cypher_query)