first commit

This commit is contained in:
2026-03-06 21:11:10 +08:00
commit 927b8a6cac
144 changed files with 26301 additions and 0 deletions

0
memory/__init__.py Normal file
View File

63
memory/base.py Normal file
View File

@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod
class MemoryBase(ABC):
@abstractmethod
def get(self, memory_id):
"""
Retrieve a memory by ID.
Args:
memory_id (str): ID of the memory to retrieve.
Returns:
dict: Retrieved memory.
"""
pass
@abstractmethod
def get_all(self):
"""
List all memories.
Returns:
list: List of all memories.
"""
pass
@abstractmethod
def update(self, memory_id, data):
"""
Update a memory by ID.
Args:
memory_id (str): ID of the memory to update.
data (str): New content to update the memory with.
Returns:
dict: Success message indicating the memory was updated.
"""
pass
@abstractmethod
def delete(self, memory_id):
"""
Delete a memory by ID.
Args:
memory_id (str): ID of the memory to delete.
"""
pass
@abstractmethod
def history(self, memory_id):
"""
Get the history of changes for a memory by ID.
Args:
memory_id (str): ID of the memory to get history for.
Returns:
list: List of changes for the memory.
"""
pass

699
memory/graph_memory.py Normal file
View File

@@ -0,0 +1,699 @@
import logging
from mem0.memory.utils import format_entities, sanitize_relationship_for_cypher
try:
from langchain_neo4j import Neo4jGraph
except ImportError:
raise ImportError("langchain_neo4j is not installed. Please install it using pip install langchain-neo4j")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
self.config = config
self.graph = Neo4jGraph(
self.config.graph_store.config.url,
self.config.graph_store.config.username,
self.config.graph_store.config.password,
self.config.graph_store.config.database,
refresh_schema=False,
driver_config={"notifications_min_severity": "OFF"},
)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
)
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
if self.config.graph_store.config.base_label:
# Safely add user_id index
try:
self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
except Exception:
pass
try: # Safely try to add composite index (Enterprise only)
self.graph.query(
f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
)
except Exception:
pass
# Default to openai if no specific provider is configured
self.llm_provider = "openai"
if self.config.llm and self.config.llm.provider:
self.llm_provider = self.config.llm.provider
if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
self.llm_provider = self.config.graph_store.llm.provider
# Get LLM config with proper null checks
llm_config = None
if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
llm_config = self.config.graph_store.llm.config
elif hasattr(self.config.llm, "config"):
llm_config = self.config.llm.config
self.llm = LlmFactory.create(self.llm_provider, llm_config)
self.user_id = None
# Use threshold from graph_store config, default to 0.7 for backward compatibility
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
def add(self, data, filters):
"""
Adds data to the graph.
Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters)
added_entities = self._add_entities(to_be_added, filters, entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def search(self, query, filters, limit=100):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
dict: A dictionary containing:
- "contexts": List of search results from the base data store.
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in reranked_results:
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
return search_results
def delete_all(self, filters):
# Build node properties for filtering
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
if filters.get("run_id"):
node_props.append("run_id: $run_id")
node_props_str = ", ".join(node_props)
cypher = f"""
MATCH (n {self.node_label} {{{node_props_str}}})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
if filters.get("agent_id"):
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
params["run_id"] = filters["run_id"]
self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory.
- 'entities': A list of strings representing the nodes and relationships
"""
params = {"user_id": filters["user_id"], "limit": limit}
# Build node properties based on filters
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
node_props.append("run_id: $run_id")
params["run_id"] = filters["run_id"]
node_props_str = ", ".join(node_props)
query = f"""
MATCH (n {self.node_label} {{{node_props_str}}})-[r]->(m {self.node_label} {{{node_props_str}}})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(query, params=params)
final_results = []
for result in results:
final_results.append(
{
"source": result["source"],
"relationship": result["relationship"],
"target": result["target"],
}
)
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
entity_type_map = {}
try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Establish relations among the extracted nodes."""
# Compose user identification string for prompt
user_identity = f"user_id: {filters['user_id']}"
if filters.get("agent_id"):
user_identity += f", agent_id: {filters['agent_id']}"
if filters.get("run_id"):
user_identity += f", run_id: {filters['run_id']}"
if self.config.graph_store.custom_prompt:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
# Add the custom prompt line if configured
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": data},
]
else:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
]
_tools = [RELATIONS_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [RELATIONS_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=_tools,
)
entities = []
if extracted_entities.get("tool_calls"):
entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])
entities = self._remove_spaces_from_entities(entities)
logger.debug(f"Extracted entities: {entities}")
return entities
def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
# Build node properties for filtering
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
if filters.get("run_id"):
node_props.append("run_id: $run_id")
node_props_str = ", ".join(node_props)
for node in node_list:
n_embedding = self.embedding_model.embed(node)
cypher_query = f"""
MATCH (n {self.node_label} {{{node_props_str}}})
WHERE n.embedding IS NOT NULL
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
WHERE similarity >= $threshold
CALL {{
WITH n
MATCH (n)-[r]->(m {self.node_label} {{{node_props_str}}})
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
UNION
WITH n
MATCH (n)<-[r]-(m {self.node_label} {{{node_props_str}}})
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
}}
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
if filters.get("agent_id"):
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
params["run_id"] = filters["run_id"]
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
# Compose user identification string for prompt
user_identity = f"user_id: {filters['user_id']}"
if filters.get("agent_id"):
user_identity += f", agent_id: {filters['agent_id']}"
if filters.get("run_id"):
user_identity += f", run_id: {filters['run_id']}"
system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates.get("tool_calls", []):
if item.get("name") == "delete_graph_memory":
to_be_deleted.append(item.get("arguments"))
# Clean entities formatting
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, filters):
"""Delete the entities from the graph."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
run_id = filters.get("run_id", None)
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Build the agent filter for the query
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
if run_id:
params["run_id"] = run_id
# Build node properties for filtering
source_props = ["name: $source_name", "user_id: $user_id"]
dest_props = ["name: $dest_name", "user_id: $user_id"]
if agent_id:
source_props.append("agent_id: $agent_id")
dest_props.append("agent_id: $agent_id")
if run_id:
source_props.append("run_id: $run_id")
dest_props.append("run_id: $run_id")
source_props_str = ", ".join(source_props)
dest_props_str = ", ".join(dest_props)
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n {self.node_label} {{{source_props_str}}})
-[r:{relationship}]->
(m {self.node_label} {{{dest_props_str}}})
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
result = self.graph.query(cypher, params=params)
results.append(result)
return results
def _add_entities(self, to_be_added, filters, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
run_id = filters.get("run_id", None)
results = []
for item in to_be_added:
# entities
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# types
source_type = entity_type_map.get(source, "__User__")
source_label = self.node_label if self.node_label else f":`{source_type}`"
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
destination_type = entity_type_map.get(destination, "__User__")
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
# embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=self.threshold)
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=self.threshold)
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
# Build destination MERGE properties
merge_props = ["name: $destination_name", "user_id: $user_id"]
if agent_id:
merge_props.append("agent_id: $agent_id")
if run_id:
merge_props.append("run_id: $run_id")
merge_props_str = ", ".join(merge_props)
cypher = f"""
MATCH (source)
WHERE elementId(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MERGE (destination {destination_label} {{{merge_props_str}}})
ON CREATE SET
destination.created = timestamp(),
destination.mentions = 1
{destination_extra_set}
ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
"destination_name": destination,
"destination_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
if run_id:
params["run_id"] = run_id
elif destination_node_search_result and not source_node_search_result:
# Build source MERGE properties
merge_props = ["name: $source_name", "user_id: $user_id"]
if agent_id:
merge_props.append("agent_id: $agent_id")
if run_id:
merge_props.append("run_id: $run_id")
merge_props_str = ", ".join(merge_props)
cypher = f"""
MATCH (destination)
WHERE elementId(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH destination
MERGE (source {source_label} {{{merge_props_str}}})
ON CREATE SET
source.created = timestamp(),
source.mentions = 1
{source_extra_set}
ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1
WITH source, destination
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
if run_id:
params["run_id"] = run_id
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source)
WHERE elementId(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MATCH (destination)
WHERE elementId(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created_at = timestamp(),
r.updated_at = timestamp(),
r.mentions = 1
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
if run_id:
params["run_id"] = run_id
else:
# Build dynamic MERGE props for both source and destination
source_props = ["name: $source_name", "user_id: $user_id"]
dest_props = ["name: $dest_name", "user_id: $user_id"]
if agent_id:
source_props.append("agent_id: $agent_id")
dest_props.append("agent_id: $agent_id")
if run_id:
source_props.append("run_id: $run_id")
dest_props.append("run_id: $run_id")
source_props_str = ", ".join(source_props)
dest_props_str = ", ".join(dest_props)
cypher = f"""
MERGE (source {source_label} {{{source_props_str}}})
ON CREATE SET source.created = timestamp(),
source.mentions = 1
{source_extra_set}
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
WITH source
MERGE (destination {destination_label} {{{dest_props_str}}})
ON CREATE SET destination.created = timestamp(),
destination.mentions = 1
{destination_extra_set}
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination
CALL db.create.setNodeVectorProperty(destination, 'embedding', $dest_embedding)
WITH source, destination
MERGE (source)-[rel:{relationship}]->(destination)
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
"""
params = {
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
if run_id:
params["run_id"] = run_id
result = self.graph.query(cypher, params=params)
results.append(result)
return results
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
# Use the sanitization function for relationships to handle special characters
item["relationship"] = sanitize_relationship_for_cypher(item["relationship"].lower().replace(" ", "_"))
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, filters, threshold=0.9):
# Build WHERE conditions
where_conditions = ["source_candidate.embedding IS NOT NULL", "source_candidate.user_id = $user_id"]
if filters.get("agent_id"):
where_conditions.append("source_candidate.agent_id = $agent_id")
if filters.get("run_id"):
where_conditions.append("source_candidate.run_id = $run_id")
where_clause = " AND ".join(where_conditions)
cypher = f"""
MATCH (source_candidate {self.node_label})
WHERE {where_clause}
WITH source_candidate,
round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
WHERE source_similarity >= $threshold
WITH source_candidate, source_similarity
ORDER BY source_similarity DESC
LIMIT 1
RETURN elementId(source_candidate)
"""
params = {
"source_embedding": source_embedding,
"user_id": filters["user_id"],
"threshold": threshold,
}
if filters.get("agent_id"):
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
params["run_id"] = filters["run_id"]
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
# Build WHERE conditions
where_conditions = ["destination_candidate.embedding IS NOT NULL", "destination_candidate.user_id = $user_id"]
if filters.get("agent_id"):
where_conditions.append("destination_candidate.agent_id = $agent_id")
if filters.get("run_id"):
where_conditions.append("destination_candidate.run_id = $run_id")
where_clause = " AND ".join(where_conditions)
cypher = f"""
MATCH (destination_candidate {self.node_label})
WHERE {where_clause}
WITH destination_candidate,
round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
WHERE destination_similarity >= $threshold
WITH destination_candidate, destination_similarity
ORDER BY destination_similarity DESC
LIMIT 1
RETURN elementId(destination_candidate)
"""
params = {
"destination_embedding": destination_embedding,
"user_id": filters["user_id"],
"threshold": threshold,
}
if filters.get("agent_id"):
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
params["run_id"] = filters["run_id"]
result = self.graph.query(cypher, params=params)
return result
# Reset is not defined in base.py
def reset(self):
"""Reset the graph by clearing all nodes and relationships."""
logger.warning("Clearing graph...")
cypher_query = """
MATCH (n) DETACH DELETE n
"""
return self.graph.query(cypher_query)

714
memory/kuzu_memory.py Normal file
View File

@@ -0,0 +1,714 @@
import logging
from mem0.memory.utils import format_entities
try:
import kuzu
except ImportError:
raise ImportError("kuzu is not installed. Please install it using pip install kuzu")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
self.config = config
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider,
self.config.embedder.config,
self.config.vector_store.config,
)
self.embedding_dims = self.embedding_model.config.embedding_dims
if self.embedding_dims is None or self.embedding_dims <= 0:
raise ValueError(f"embedding_dims must be a positive integer. Given: {self.embedding_dims}")
self.db = kuzu.Database(self.config.graph_store.config.db)
self.graph = kuzu.Connection(self.db)
self.node_label = ":Entity"
self.rel_label = ":CONNECTED_TO"
self.kuzu_create_schema()
# Default to openai if no specific provider is configured
self.llm_provider = "openai"
if self.config.llm and self.config.llm.provider:
self.llm_provider = self.config.llm.provider
if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
self.llm_provider = self.config.graph_store.llm.provider
# Get LLM config with proper null checks
llm_config = None
if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
llm_config = self.config.graph_store.llm.config
elif hasattr(self.config.llm, "config"):
llm_config = self.config.llm.config
self.llm = LlmFactory.create(self.llm_provider, llm_config)
self.user_id = None
# Use threshold from graph_store config, default to 0.7 for backward compatibility
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
def kuzu_create_schema(self):
self.kuzu_execute(
"""
CREATE NODE TABLE IF NOT EXISTS Entity(
id SERIAL PRIMARY KEY,
user_id STRING,
agent_id STRING,
run_id STRING,
name STRING,
mentions INT64,
created TIMESTAMP,
embedding FLOAT[]);
"""
)
self.kuzu_execute(
"""
CREATE REL TABLE IF NOT EXISTS CONNECTED_TO(
FROM Entity TO Entity,
name STRING,
mentions INT64,
created TIMESTAMP,
updated TIMESTAMP
);
"""
)
def kuzu_execute(self, query, parameters=None):
results = self.graph.execute(query, parameters)
return list(results.rows_as_dict())
def add(self, data, filters):
"""
Adds data to the graph.
Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
deleted_entities = self._delete_entities(to_be_deleted, filters)
added_entities = self._add_entities(to_be_added, filters, entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def search(self, query, filters, limit=5):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
dict: A dictionary containing:
- "contexts": List of search results from the base data store.
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=limit)
search_results = []
for item in reranked_results:
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
return search_results
def delete_all(self, filters):
# Build node properties for filtering
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
if filters.get("run_id"):
node_props.append("run_id: $run_id")
node_props_str = ", ".join(node_props)
cypher = f"""
MATCH (n {self.node_label} {{{node_props_str}}})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
if filters.get("agent_id"):
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
params["run_id"] = filters["run_id"]
self.kuzu_execute(cypher, parameters=params)
def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory.
- 'entities': A list of strings representing the nodes and relationships
"""
params = {
"user_id": filters["user_id"],
"limit": limit,
}
# Build node properties based on filters
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
node_props.append("run_id: $run_id")
params["run_id"] = filters["run_id"]
node_props_str = ", ".join(node_props)
query = f"""
MATCH (n {self.node_label} {{{node_props_str}}})-[r]->(m {self.node_label} {{{node_props_str}}})
RETURN
n.name AS source,
r.name AS relationship,
m.name AS target
LIMIT $limit
"""
results = self.kuzu_execute(query, parameters=params)
final_results = []
for result in results:
final_results.append(
{
"source": result["source"],
"relationship": result["relationship"],
"target": result["target"],
}
)
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
entity_type_map = {}
try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Establish relations among the extracted nodes."""
# Compose user identification string for prompt
user_identity = f"user_id: {filters['user_id']}"
if filters.get("agent_id"):
user_identity += f", agent_id: {filters['agent_id']}"
if filters.get("run_id"):
user_identity += f", run_id: {filters['run_id']}"
if self.config.graph_store.custom_prompt:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
# Add the custom prompt line if configured
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": data},
]
else:
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
]
_tools = [RELATIONS_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [RELATIONS_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=_tools,
)
entities = []
if extracted_entities.get("tool_calls"):
entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])
entities = self._remove_spaces_from_entities(entities)
logger.debug(f"Extracted entities: {entities}")
return entities
def _search_graph_db(self, node_list, filters, limit=100, threshold=None):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
params = {
"threshold": threshold if threshold else self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
# Build node properties for filtering
node_props = ["user_id: $user_id"]
if filters.get("agent_id"):
node_props.append("agent_id: $agent_id")
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
node_props.append("run_id: $run_id")
params["run_id"] = filters["run_id"]
node_props_str = ", ".join(node_props)
for node in node_list:
n_embedding = self.embedding_model.embed(node)
params["n_embedding"] = n_embedding
results = []
for match_fragment in [
f"(n)-[r]->(m {self.node_label} {{{node_props_str}}}) WITH n as src, r, m as dst, similarity",
f"(m {self.node_label} {{{node_props_str}}})-[r]->(n) WITH m as src, r, n as dst, similarity"
]:
results.extend(self.kuzu_execute(
f"""
MATCH (n {self.node_label} {{{node_props_str}}})
WHERE n.embedding IS NOT NULL
WITH n, array_cosine_similarity(n.embedding, CAST($n_embedding,'FLOAT[{self.embedding_dims}]')) AS similarity
WHERE similarity >= CAST($threshold, 'DOUBLE')
MATCH {match_fragment}
RETURN
src.name AS source,
id(src) AS source_id,
r.name AS relationship,
id(r) AS relation_id,
dst.name AS destination,
id(dst) AS destination_id,
similarity
LIMIT $limit
""",
parameters=params))
# Kuzu does not support sort/limit over unions. Do it manually for now.
result_relations.extend(sorted(results, key=lambda x: x["similarity"], reverse=True)[:limit])
return result_relations
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
# Compose user identification string for prompt
user_identity = f"user_id: {filters['user_id']}"
if filters.get("agent_id"):
user_identity += f", agent_id: {filters['agent_id']}"
if filters.get("run_id"):
user_identity += f", run_id: {filters['run_id']}"
system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates.get("tool_calls", []):
if item.get("name") == "delete_graph_memory":
to_be_deleted.append(item.get("arguments"))
# Clean entities formatting
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, filters):
"""Delete the entities from the graph."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
run_id = filters.get("run_id", None)
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
"relationship_name": relationship,
}
# Build node properties for filtering
source_props = ["name: $source_name", "user_id: $user_id"]
dest_props = ["name: $dest_name", "user_id: $user_id"]
if agent_id:
source_props.append("agent_id: $agent_id")
dest_props.append("agent_id: $agent_id")
params["agent_id"] = agent_id
if run_id:
source_props.append("run_id: $run_id")
dest_props.append("run_id: $run_id")
params["run_id"] = run_id
source_props_str = ", ".join(source_props)
dest_props_str = ", ".join(dest_props)
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n {self.node_label} {{{source_props_str}}})
-[r {self.rel_label} {{name: $relationship_name}}]->
(m {self.node_label} {{{dest_props_str}}})
DELETE r
RETURN
n.name AS source,
r.name AS relationship,
m.name AS target
"""
result = self.kuzu_execute(cypher, parameters=params)
results.append(result)
return results
def _add_entities(self, to_be_added, filters, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
run_id = filters.get("run_id", None)
results = []
for item in to_be_added:
# entities
source = item["source"]
source_label = self.node_label
destination = item["destination"]
destination_label = self.node_label
relationship = item["relationship"]
relationship_label = self.rel_label
# embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=self.threshold)
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=self.threshold)
if not destination_node_search_result and source_node_search_result:
params = {
"table_id": source_node_search_result[0]["id"]["table"],
"offset_id": source_node_search_result[0]["id"]["offset"],
"destination_name": destination,
"destination_embedding": dest_embedding,
"relationship_name": relationship,
"user_id": user_id,
}
# Build source MERGE properties
merge_props = ["name: $destination_name", "user_id: $user_id"]
if agent_id:
merge_props.append("agent_id: $agent_id")
params["agent_id"] = agent_id
if run_id:
merge_props.append("run_id: $run_id")
params["run_id"] = run_id
merge_props_str = ", ".join(merge_props)
cypher = f"""
MATCH (source)
WHERE id(source) = internal_id($table_id, $offset_id)
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MERGE (destination {destination_label} {{{merge_props_str}}})
ON CREATE SET
destination.created = current_timestamp(),
destination.mentions = 1,
destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1,
destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
WITH source, destination
MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
ON CREATE SET
r.created = current_timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN
source.name AS source,
r.name AS relationship,
destination.name AS target
"""
elif destination_node_search_result and not source_node_search_result:
params = {
"table_id": destination_node_search_result[0]["id"]["table"],
"offset_id": destination_node_search_result[0]["id"]["offset"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
"relationship_name": relationship,
}
# Build source MERGE properties
merge_props = ["name: $source_name", "user_id: $user_id"]
if agent_id:
merge_props.append("agent_id: $agent_id")
params["agent_id"] = agent_id
if run_id:
merge_props.append("run_id: $run_id")
params["run_id"] = run_id
merge_props_str = ", ".join(merge_props)
cypher = f"""
MATCH (destination)
WHERE id(destination) = internal_id($table_id, $offset_id)
SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH destination
MERGE (source {source_label} {{{merge_props_str}}})
ON CREATE SET
source.created = current_timestamp(),
source.mentions = 1,
source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1,
source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
WITH source, destination
MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
ON CREATE SET
r.created = current_timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN
source.name AS source,
r.name AS relationship,
destination.name AS target
"""
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source)
WHERE id(source) = internal_id($src_table, $src_offset)
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MATCH (destination)
WHERE id(destination) = internal_id($dst_table, $dst_offset)
SET destination.mentions = coalesce(destination.mentions, 0) + 1
MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
ON CREATE SET
r.created = current_timestamp(),
r.updated = current_timestamp(),
r.mentions = 1
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
RETURN
source.name AS source,
r.name AS relationship,
destination.name AS target
"""
params = {
"src_table": source_node_search_result[0]["id"]["table"],
"src_offset": source_node_search_result[0]["id"]["offset"],
"dst_table": destination_node_search_result[0]["id"]["table"],
"dst_offset": destination_node_search_result[0]["id"]["offset"],
"relationship_name": relationship,
}
else:
params = {
"source_name": source,
"dest_name": destination,
"relationship_name": relationship,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
# Build dynamic MERGE props for both source and destination
source_props = ["name: $source_name", "user_id: $user_id"]
dest_props = ["name: $dest_name", "user_id: $user_id"]
if agent_id:
source_props.append("agent_id: $agent_id")
dest_props.append("agent_id: $agent_id")
params["agent_id"] = agent_id
if run_id:
source_props.append("run_id: $run_id")
dest_props.append("run_id: $run_id")
params["run_id"] = run_id
source_props_str = ", ".join(source_props)
dest_props_str = ", ".join(dest_props)
cypher = f"""
MERGE (source {source_label} {{{source_props_str}}})
ON CREATE SET
source.created = current_timestamp(),
source.mentions = 1,
source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1,
source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
WITH source
MERGE (destination {destination_label} {{{dest_props_str}}})
ON CREATE SET
destination.created = current_timestamp(),
destination.mentions = 1,
destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1,
destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
WITH source, destination
MERGE (source)-[rel {relationship_label} {{name: $relationship_name}}]->(destination)
ON CREATE SET
rel.created = current_timestamp(),
rel.mentions = 1
ON MATCH SET
rel.mentions = coalesce(rel.mentions, 0) + 1
RETURN
source.name AS source,
rel.name AS relationship,
destination.name AS target
"""
result = self.kuzu_execute(cypher, parameters=params)
results.append(result)
return results
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
item["relationship"] = item["relationship"].lower().replace(" ", "_")
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, filters, threshold=0.9):
params = {
"source_embedding": source_embedding,
"user_id": filters["user_id"],
"threshold": threshold,
}
where_conditions = ["source_candidate.embedding IS NOT NULL", "source_candidate.user_id = $user_id"]
if filters.get("agent_id"):
where_conditions.append("source_candidate.agent_id = $agent_id")
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
where_conditions.append("source_candidate.run_id = $run_id")
params["run_id"] = filters["run_id"]
where_clause = " AND ".join(where_conditions)
cypher = f"""
MATCH (source_candidate {self.node_label})
WHERE {where_clause}
WITH source_candidate,
array_cosine_similarity(source_candidate.embedding, CAST($source_embedding,'FLOAT[{self.embedding_dims}]')) AS source_similarity
WHERE source_similarity >= $threshold
WITH source_candidate, source_similarity
ORDER BY source_similarity DESC
LIMIT 2
RETURN id(source_candidate) as id, source_similarity
"""
return self.kuzu_execute(cypher, parameters=params)
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
params = {
"destination_embedding": destination_embedding,
"user_id": filters["user_id"],
"threshold": threshold,
}
where_conditions = ["destination_candidate.embedding IS NOT NULL", "destination_candidate.user_id = $user_id"]
if filters.get("agent_id"):
where_conditions.append("destination_candidate.agent_id = $agent_id")
params["agent_id"] = filters["agent_id"]
if filters.get("run_id"):
where_conditions.append("destination_candidate.run_id = $run_id")
params["run_id"] = filters["run_id"]
where_clause = " AND ".join(where_conditions)
cypher = f"""
MATCH (destination_candidate {self.node_label})
WHERE {where_clause}
WITH destination_candidate,
array_cosine_similarity(destination_candidate.embedding, CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')) AS destination_similarity
WHERE destination_similarity >= $threshold
WITH destination_candidate, destination_similarity
ORDER BY destination_similarity DESC
LIMIT 2
RETURN id(destination_candidate) as id, destination_similarity
"""
return self.kuzu_execute(cypher, parameters=params)
# Reset is not defined in base.py
def reset(self):
"""Reset the graph by clearing all nodes and relationships."""
logger.warning("Clearing graph...")
cypher_query = """
MATCH (n) DETACH DELETE n
"""
return self.kuzu_execute(cypher_query)

2325
memory/main.py Normal file

File diff suppressed because it is too large Load Diff

690
memory/memgraph_memory.py Normal file
View File

@@ -0,0 +1,690 @@
import logging
from mem0.memory.utils import format_entities, sanitize_relationship_for_cypher
try:
from langchain_memgraph.graphs.memgraph import Memgraph
except ImportError:
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
self.config = config
self.graph = Memgraph(
self.config.graph_store.config.url,
self.config.graph_store.config.username,
self.config.graph_store.config.password,
)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider,
self.config.embedder.config,
{"enable_embeddings": True},
)
# Default to openai if no specific provider is configured
self.llm_provider = "openai"
if self.config.llm and self.config.llm.provider:
self.llm_provider = self.config.llm.provider
if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
self.llm_provider = self.config.graph_store.llm.provider
# Get LLM config with proper null checks
llm_config = None
if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
llm_config = self.config.graph_store.llm.config
elif hasattr(self.config.llm, "config"):
llm_config = self.config.llm.config
self.llm = LlmFactory.create(self.llm_provider, llm_config)
self.user_id = None
# Use threshold from graph_store config, default to 0.7 for backward compatibility
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
# Setup Memgraph:
# 1. Create vector index (created Entity label on all nodes)
# 2. Create label property index for performance optimizations
embedding_dims = self.config.embedder.config["embedding_dims"]
index_info = self._fetch_existing_indexes()
# Create vector index if not exists
if not self._vector_index_exists(index_info, "memzero"):
self.graph.query(
f"CREATE VECTOR INDEX memzero ON :Entity(embedding) WITH CONFIG {{'dimension': {embedding_dims}, 'capacity': 1000, 'metric': 'cos'}};"
)
# Create label+property index if not exists
if not self._label_property_index_exists(index_info, "Entity", "user_id"):
self.graph.query("CREATE INDEX ON :Entity(user_id);")
# Create label index if not exists
if not self._label_index_exists(index_info, "Entity"):
self.graph.query("CREATE INDEX ON :Entity;")
def add(self, data, filters):
"""
Adds data to the graph.
Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters)
added_entities = self._add_entities(to_be_added, filters, entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def search(self, query, filters, limit=100):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
dict: A dictionary containing:
- "contexts": List of search results from the base data store.
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in reranked_results:
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
return search_results
def delete_all(self, filters):
"""Delete all nodes and relationships for a user or specific agent."""
if filters.get("agent_id"):
cypher = """
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
else:
cypher = """
MATCH (n:Entity {user_id: $user_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
Supports 'user_id' (required) and 'agent_id' (optional).
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of dictionaries, each containing:
- 'source': The source node name.
- 'relationship': The relationship type.
- 'target': The target node name.
"""
# Build query based on whether agent_id is provided
if filters.get("agent_id"):
query = """
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
else:
query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
params = {"user_id": filters["user_id"], "limit": limit}
results = self.graph.query(query, params=params)
final_results = []
for result in results:
final_results.append(
{
"source": result["source"],
"relationship": result["relationship"],
"target": result["target"],
}
)
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
entity_type_map = {}
try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Eshtablish relations among the extracted nodes."""
if self.config.graph_store.custom_prompt:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
{"role": "user", "content": data},
]
else:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
},
{
"role": "user",
"content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
},
]
_tools = [RELATIONS_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [RELATIONS_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=_tools,
)
entities = []
if extracted_entities["tool_calls"]:
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
entities = self._remove_spaces_from_entities(entities)
logger.debug(f"Extracted entities: {entities}")
return entities
def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
for node in node_list:
n_embedding = self.embedding_model.embed(node)
# Build query based on whether agent_id is provided
if filters.get("agent_id"):
cypher_query = """
CALL vector_search.search("memzero", $limit, $n_embedding)
YIELD distance, node, similarity
WITH node AS n, similarity
WHERE n:Entity AND n.user_id = $user_id AND n.agent_id = $agent_id AND n.embedding IS NOT NULL AND similarity >= $threshold
MATCH (n)-[r]->(m:Entity)
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id, similarity
UNION
CALL vector_search.search("memzero", $limit, $n_embedding)
YIELD distance, node, similarity
WITH node AS n, similarity
WHERE n:Entity AND n.user_id = $user_id AND n.agent_id = $agent_id AND n.embedding IS NOT NULL AND similarity >= $threshold
MATCH (m:Entity)-[r]->(n)
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"agent_id": filters["agent_id"],
"limit": limit,
}
else:
cypher_query = """
CALL vector_search.search("memzero", $limit, $n_embedding)
YIELD distance, node, similarity
WITH node AS n, similarity
WHERE n:Entity AND n.user_id = $user_id AND n.embedding IS NOT NULL AND similarity >= $threshold
MATCH (n)-[r]->(m:Entity)
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id, similarity
UNION
CALL vector_search.search("memzero", $limit, $n_embedding)
YIELD distance, node, similarity
WITH node AS n, similarity
WHERE n:Entity AND n.user_id = $user_id AND n.embedding IS NOT NULL AND similarity >= $threshold
MATCH (m:Entity)-[r]->(n)
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates["tool_calls"]:
if item["name"] == "delete_graph_memory":
to_be_deleted.append(item["arguments"])
# in case if it is not in the correct format
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, filters):
"""Delete the entities from the graph."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Build the agent filter for the query
agent_filter = ""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
if agent_id:
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
params["agent_id"] = agent_id
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]->
(m:Entity {{name: $dest_name, user_id: $user_id}})
WHERE 1=1 {agent_filter}
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
result = self.graph.query(cypher, params=params)
results.append(result)
return results
# added Entity label to all nodes for vector search to work
def _add_entities(self, to_be_added, filters, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
results = []
for item in to_be_added:
# entities
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# types
source_type = entity_type_map.get(source, "__User__")
destination_type = entity_type_map.get(destination, "__User__")
# embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=self.threshold)
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=self.threshold)
# Prepare agent_id for node creation
agent_id_clause = ""
if agent_id:
agent_id_clause = ", agent_id: $agent_id"
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
cypher = f"""
MATCH (source:Entity)
WHERE id(source) = $source_id
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET
destination.created = timestamp(),
destination.embedding = $destination_embedding,
destination:Entity
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_name": destination,
"destination_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
elif destination_node_search_result and not source_node_search_result:
cypher = f"""
MATCH (destination:Entity)
WHERE id(destination) = $destination_id
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET
source.created = timestamp(),
source.embedding = $source_embedding,
source:Entity
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source:Entity)
WHERE id(source) = $source_id
MATCH (destination:Entity)
WHERE id(destination) = $destination_id
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created_at = timestamp(),
r.updated_at = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
else:
cypher = f"""
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relationship}]->(m)
ON CREATE SET rel.created = timestamp()
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
"""
params = {
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
result = self.graph.query(cypher, params=params)
results.append(result)
return results
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
# Use the sanitization function for relationships to handle special characters
item["relationship"] = sanitize_relationship_for_cypher(item["relationship"].lower().replace(" ", "_"))
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, filters, threshold=0.9):
"""Search for source nodes with similar embeddings."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
if agent_id:
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id
AND source_candidate.agent_id = $agent_id
AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"agent_id": agent_id,
"threshold": threshold,
}
else:
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id
AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
"""Search for destination nodes with similar embeddings."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
if agent_id:
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
WHERE node.user_id = $user_id
AND node.agent_id = $agent_id
AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"agent_id": agent_id,
"threshold": threshold,
}
else:
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
WHERE node.user_id = $user_id
AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
def _vector_index_exists(self, index_info, index_name):
"""
Check if a vector index exists, compatible with both Memgraph versions.
Args:
index_info (dict): Index information from _fetch_existing_indexes
index_name (str): Name of the index to check
Returns:
bool: True if index exists, False otherwise
"""
vector_indexes = index_info.get("vector_index_exists", [])
# Check for index by name regardless of version-specific format differences
return any(
idx.get("index_name") == index_name or
idx.get("index name") == index_name or
idx.get("name") == index_name
for idx in vector_indexes
)
def _label_property_index_exists(self, index_info, label, property_name):
"""
Check if a label+property index exists, compatible with both versions.
Args:
index_info (dict): Index information from _fetch_existing_indexes
label (str): Label name
property_name (str): Property name
Returns:
bool: True if index exists, False otherwise
"""
indexes = index_info.get("index_exists", [])
return any(
(idx.get("index type") == "label+property" or idx.get("index_type") == "label+property") and
(idx.get("label") == label) and
(idx.get("property") == property_name or property_name in str(idx.get("properties", "")))
for idx in indexes
)
def _label_index_exists(self, index_info, label):
"""
Check if a label index exists, compatible with both versions.
Args:
index_info (dict): Index information from _fetch_existing_indexes
label (str): Label name
Returns:
bool: True if index exists, False otherwise
"""
indexes = index_info.get("index_exists", [])
return any(
(idx.get("index type") == "label" or idx.get("index_type") == "label") and
(idx.get("label") == label)
for idx in indexes
)
def _fetch_existing_indexes(self):
"""
Retrieves information about existing indexes and vector indexes in the Memgraph database.
Returns:
dict: A dictionary containing lists of existing indexes and vector indexes.
"""
try:
index_exists = list(self.graph.query("SHOW INDEX INFO;"))
vector_index_exists = list(self.graph.query("SHOW VECTOR INDEX INFO;"))
return {"index_exists": index_exists, "vector_index_exists": vector_index_exists}
except Exception as e:
logger.warning(f"Error fetching indexes: {e}. Returning empty index info.")
return {"index_exists": [], "vector_index_exists": []}

56
memory/setup.py Normal file
View File

@@ -0,0 +1,56 @@
import json
import os
import uuid
# Set up the directory path
VECTOR_ID = str(uuid.uuid4())
home_dir = os.path.expanduser("~")
mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")
os.makedirs(mem0_dir, exist_ok=True)
def setup_config():
config_path = os.path.join(mem0_dir, "config.json")
if not os.path.exists(config_path):
user_id = str(uuid.uuid4())
config = {"user_id": user_id}
with open(config_path, "w") as config_file:
json.dump(config, config_file, indent=4)
def get_user_id():
config_path = os.path.join(mem0_dir, "config.json")
if not os.path.exists(config_path):
return "anonymous_user"
try:
with open(config_path, "r") as config_file:
config = json.load(config_file)
user_id = config.get("user_id")
return user_id
except Exception:
return "anonymous_user"
def get_or_create_user_id(vector_store):
"""Store user_id in vector store and return it."""
user_id = get_user_id()
# Try to get existing user_id from vector store
try:
existing = vector_store.get(vector_id=user_id)
if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
return existing.payload["user_id"]
except Exception:
pass
# If we get here, we need to insert the user_id
try:
dims = getattr(vector_store, "embedding_model_dims", 1536)
vector_store.insert(
vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
)
except Exception:
pass
return user_id

218
memory/storage.py Normal file
View File

@@ -0,0 +1,218 @@
import logging
import sqlite3
import threading
import uuid
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class SQLiteManager:
def __init__(self, db_path: str = ":memory:"):
self.db_path = db_path
self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
self._lock = threading.Lock()
self._migrate_history_table()
self._create_history_table()
def _migrate_history_table(self) -> None:
"""
If a pre-existing history table had the old group-chat columns,
rename it, create the new schema, copy the intersecting data, then
drop the old table.
"""
with self._lock:
try:
# Start a transaction
self.connection.execute("BEGIN")
cur = self.connection.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
if cur.fetchone() is None:
self.connection.execute("COMMIT")
return # nothing to migrate
cur.execute("PRAGMA table_info(history)")
old_cols = {row[1] for row in cur.fetchall()}
expected_cols = {
"id",
"memory_id",
"old_memory",
"new_memory",
"event",
"created_at",
"updated_at",
"is_deleted",
"actor_id",
"role",
}
if old_cols == expected_cols:
self.connection.execute("COMMIT")
return
logger.info("Migrating history table to new schema (no convo columns).")
# Clean up any existing history_old table from previous failed migration
cur.execute("DROP TABLE IF EXISTS history_old")
# Rename the current history table
cur.execute("ALTER TABLE history RENAME TO history_old")
# Create the new history table with updated schema
cur.execute(
"""
CREATE TABLE history (
id TEXT PRIMARY KEY,
memory_id TEXT,
old_memory TEXT,
new_memory TEXT,
event TEXT,
created_at DATETIME,
updated_at DATETIME,
is_deleted INTEGER,
actor_id TEXT,
role TEXT
)
"""
)
# Copy data from old table to new table
intersecting = list(expected_cols & old_cols)
if intersecting:
cols_csv = ", ".join(intersecting)
cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
# Drop the old table
cur.execute("DROP TABLE history_old")
# Commit the transaction
self.connection.execute("COMMIT")
logger.info("History table migration completed successfully.")
except Exception as e:
# Rollback the transaction on any error
self.connection.execute("ROLLBACK")
logger.error(f"History table migration failed: {e}")
raise
def _create_history_table(self) -> None:
with self._lock:
try:
self.connection.execute("BEGIN")
self.connection.execute(
"""
CREATE TABLE IF NOT EXISTS history (
id TEXT PRIMARY KEY,
memory_id TEXT,
old_memory TEXT,
new_memory TEXT,
event TEXT,
created_at DATETIME,
updated_at DATETIME,
is_deleted INTEGER,
actor_id TEXT,
role TEXT
)
"""
)
self.connection.execute("COMMIT")
except Exception as e:
self.connection.execute("ROLLBACK")
logger.error(f"Failed to create history table: {e}")
raise
def add_history(
self,
memory_id: str,
old_memory: Optional[str],
new_memory: Optional[str],
event: str,
*,
created_at: Optional[str] = None,
updated_at: Optional[str] = None,
is_deleted: int = 0,
actor_id: Optional[str] = None,
role: Optional[str] = None,
) -> None:
with self._lock:
try:
self.connection.execute("BEGIN")
self.connection.execute(
"""
INSERT INTO history (
id, memory_id, old_memory, new_memory, event,
created_at, updated_at, is_deleted, actor_id, role
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
str(uuid.uuid4()),
memory_id,
old_memory,
new_memory,
event,
created_at,
updated_at,
is_deleted,
actor_id,
role,
),
)
self.connection.execute("COMMIT")
except Exception as e:
self.connection.execute("ROLLBACK")
logger.error(f"Failed to add history record: {e}")
raise
def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
with self._lock:
cur = self.connection.execute(
"""
SELECT id, memory_id, old_memory, new_memory, event,
created_at, updated_at, is_deleted, actor_id, role
FROM history
WHERE memory_id = ?
ORDER BY created_at ASC, DATETIME(updated_at) ASC
""",
(memory_id,),
)
rows = cur.fetchall()
return [
{
"id": r[0],
"memory_id": r[1],
"old_memory": r[2],
"new_memory": r[3],
"event": r[4],
"created_at": r[5],
"updated_at": r[6],
"is_deleted": bool(r[7]),
"actor_id": r[8],
"role": r[9],
}
for r in rows
]
def reset(self) -> None:
"""Drop and recreate the history table."""
with self._lock:
try:
self.connection.execute("BEGIN")
self.connection.execute("DROP TABLE IF EXISTS history")
self.connection.execute("COMMIT")
self._create_history_table()
except Exception as e:
self.connection.execute("ROLLBACK")
logger.error(f"Failed to reset history table: {e}")
raise
def close(self) -> None:
if self.connection:
self.connection.close()
self.connection = None
def __del__(self):
self.close()

101
memory/telemetry.py Normal file
View File

@@ -0,0 +1,101 @@
import logging
import os
import platform
import sys
from posthog import Posthog
import mem0
from mem0.memory.setup import get_or_create_user_id
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"
if isinstance(MEM0_TELEMETRY, str):
MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
if not isinstance(MEM0_TELEMETRY, bool):
raise ValueError("MEM0_TELEMETRY must be a boolean value.")
logging.getLogger("posthog").setLevel(logging.CRITICAL + 1)
logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)
class AnonymousTelemetry:
def __init__(self, vector_store=None):
if not MEM0_TELEMETRY:
self.posthog = None
self.user_id = None
return
self.posthog = Posthog(project_api_key=PROJECT_API_KEY, host=HOST)
self.user_id = get_or_create_user_id(vector_store)
def capture_event(self, event_name, properties=None, user_email=None):
if self.posthog is None:
return
if properties is None:
properties = {}
properties = {
"client_source": "python",
"client_version": mem0.__version__,
"python_version": sys.version,
"os": sys.platform,
"os_version": platform.version(),
"os_release": platform.release(),
"processor": platform.processor(),
"machine": platform.machine(),
**properties,
}
distinct_id = self.user_id if user_email is None else user_email
self.posthog.capture(distinct_id=distinct_id, event=event_name, properties=properties)
def close(self):
if self.posthog is not None:
self.posthog.shutdown()
client_telemetry = AnonymousTelemetry()
def capture_event(event_name, memory_instance, additional_data=None):
if not MEM0_TELEMETRY:
return
oss_telemetry = AnonymousTelemetry(
vector_store=memory_instance._telemetry_vector_store
if hasattr(memory_instance, "_telemetry_vector_store")
else None,
)
event_data = {
"collection": memory_instance.collection_name,
"vector_size": memory_instance.embedding_model.config.embedding_dims,
"history_store": "sqlite",
"graph_store": f"{memory_instance.graph.__class__.__module__}.{memory_instance.graph.__class__.__name__}"
if memory_instance.config.graph_store.config
else None,
"vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
"llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
"embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
"function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.api_version}",
}
if additional_data:
event_data.update(additional_data)
oss_telemetry.capture_event(event_name, event_data)
def capture_client_event(event_name, instance, additional_data=None):
if not MEM0_TELEMETRY:
return
event_data = {
"function": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
}
if additional_data:
event_data.update(additional_data)
client_telemetry.capture_event(event_name, event_data, instance.user_email)

208
memory/utils.py Normal file
View File

@@ -0,0 +1,208 @@
import hashlib
import re
from mem0.configs.prompts import (
FACT_RETRIEVAL_PROMPT,
USER_MEMORY_EXTRACTION_PROMPT,
AGENT_MEMORY_EXTRACTION_PROMPT,
)
def get_fact_retrieval_messages(message, is_agent_memory=False):
"""Get fact retrieval messages based on the memory type.
Args:
message: The message content to extract facts from
is_agent_memory: If True, use agent memory extraction prompt, else use user memory extraction prompt
Returns:
tuple: (system_prompt, user_prompt)
"""
if is_agent_memory:
return AGENT_MEMORY_EXTRACTION_PROMPT, f"Input:\n{message}"
else:
return USER_MEMORY_EXTRACTION_PROMPT, f"Input:\n{message}"
def get_fact_retrieval_messages_legacy(message):
"""Legacy function for backward compatibility."""
return FACT_RETRIEVAL_PROMPT, f"Input:\n{message}"
def parse_messages(messages):
response = ""
for msg in messages:
if msg["role"] == "system":
response += f"system: {msg['content']}\n"
if msg["role"] == "user":
response += f"user: {msg['content']}\n"
if msg["role"] == "assistant":
response += f"assistant: {msg['content']}\n"
return response
def format_entities(entities):
if not entities:
return ""
formatted_lines = []
for entity in entities:
simplified = f"{entity['source']} -- {entity['relationship']} -- {entity['destination']}"
formatted_lines.append(simplified)
return "\n".join(formatted_lines)
def remove_code_blocks(content: str) -> str:
"""
Removes enclosing code block markers ```[language] and ``` from a given string.
Remarks:
- The function uses a regex pattern to match code blocks that may start with ``` followed by an optional language tag (letters or numbers) and end with ```.
- If a code block is detected, it returns only the inner content, stripping out the markers.
- If no code block markers are found, the original content is returned as-is.
"""
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
match = re.match(pattern, content.strip())
match_res=match.group(1).strip() if match else content.strip()
return re.sub(r"<think>.*?</think>", "", match_res, flags=re.DOTALL).strip()
def extract_json(text):
"""
Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present.
If no code block is found, returns the text as-is.
"""
text = text.strip()
match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
if match:
json_str = match.group(1)
else:
json_str = text # assume it's raw JSON
return json_str
def get_image_description(image_obj, llm, vision_details):
"""
Get the description of the image
"""
if isinstance(image_obj, str):
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "A user is providing an image. Provide a high level description of the image and do not include any additional text.",
},
{"type": "image_url", "image_url": {"url": image_obj, "detail": vision_details}},
],
},
]
else:
messages = [image_obj]
response = llm.generate_response(messages=messages)
return response
def parse_vision_messages(messages, llm=None, vision_details="auto"):
"""
Parse the vision messages from the messages
"""
returned_messages = []
for msg in messages:
if msg["role"] == "system":
returned_messages.append(msg)
continue
# Handle message content
if isinstance(msg["content"], list):
# Multiple image URLs in content
description = get_image_description(msg, llm, vision_details)
returned_messages.append({"role": msg["role"], "content": description})
elif isinstance(msg["content"], dict) and msg["content"].get("type") == "image_url":
# Single image content
image_url = msg["content"]["image_url"]["url"]
try:
description = get_image_description(image_url, llm, vision_details)
returned_messages.append({"role": msg["role"], "content": description})
except Exception:
raise Exception(f"Error while downloading {image_url}.")
else:
# Regular text content
returned_messages.append(msg)
return returned_messages
def process_telemetry_filters(filters):
"""
Process the telemetry filters
"""
if filters is None:
return {}
encoded_ids = {}
if "user_id" in filters:
encoded_ids["user_id"] = hashlib.md5(filters["user_id"].encode()).hexdigest()
if "agent_id" in filters:
encoded_ids["agent_id"] = hashlib.md5(filters["agent_id"].encode()).hexdigest()
if "run_id" in filters:
encoded_ids["run_id"] = hashlib.md5(filters["run_id"].encode()).hexdigest()
return list(filters.keys()), encoded_ids
def sanitize_relationship_for_cypher(relationship) -> str:
"""Sanitize relationship text for Cypher queries by replacing problematic characters."""
char_map = {
"...": "_ellipsis_",
"": "_ellipsis_",
"": "_period_",
"": "_comma_",
"": "_semicolon_",
"": "_colon_",
"": "_exclamation_",
"": "_question_",
"": "_lparen_",
"": "_rparen_",
"": "_lbracket_",
"": "_rbracket_",
"": "_langle_",
"": "_rangle_",
"'": "_apostrophe_",
'"': "_quote_",
"\\": "_backslash_",
"/": "_slash_",
"|": "_pipe_",
"&": "_ampersand_",
"=": "_equals_",
"+": "_plus_",
"*": "_asterisk_",
"^": "_caret_",
"%": "_percent_",
"$": "_dollar_",
"#": "_hash_",
"@": "_at_",
"!": "_bang_",
"?": "_question_",
"(": "_lparen_",
")": "_rparen_",
"[": "_lbracket_",
"]": "_rbracket_",
"{": "_lbrace_",
"}": "_rbrace_",
"<": "_langle_",
">": "_rangle_",
}
# Apply replacements and clean up
sanitized = relationship
for old, new in char_map.items():
sanitized = sanitized.replace(old, new)
return re.sub(r"_+", "_", sanitized).strip("_")