513 lines
19 KiB
Python
513 lines
19 KiB
Python
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
import pytz
|
|
|
|
from .base import NeptuneBase
|
|
|
|
try:
|
|
from langchain_aws import NeptuneGraph
|
|
except ImportError:
|
|
raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class MemoryGraph(NeptuneBase):
|
|
def __init__(self, config):
|
|
"""
|
|
Initialize the Neptune DB memory store.
|
|
"""
|
|
|
|
self.config = config
|
|
|
|
self.graph = None
|
|
endpoint = self.config.graph_store.config.endpoint
|
|
if endpoint and endpoint.startswith("neptune-db://"):
|
|
host = endpoint.replace("neptune-db://", "")
|
|
port = 8182
|
|
self.graph = NeptuneGraph(host, port)
|
|
|
|
if not self.graph:
|
|
raise ValueError("Unable to create a Neptune-DB client: missing 'endpoint' in config")
|
|
|
|
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
|
|
|
|
self.embedding_model = NeptuneBase._create_embedding_model(self.config)
|
|
|
|
# Default to openai if no specific provider is configured
|
|
self.llm_provider = "openai"
|
|
if self.config.graph_store.llm:
|
|
self.llm_provider = self.config.graph_store.llm.provider
|
|
elif self.config.llm.provider:
|
|
self.llm_provider = self.config.llm.provider
|
|
|
|
# fetch the vector store as a provider
|
|
self.vector_store_provider = self.config.vector_store.provider
|
|
if self.config.graph_store.config.collection_name:
|
|
vector_store_collection_name = self.config.graph_store.config.collection_name
|
|
else:
|
|
vector_store_config = self.config.vector_store.config
|
|
if vector_store_config.collection_name:
|
|
vector_store_collection_name = vector_store_config.collection_name + "_neptune_vector_store"
|
|
else:
|
|
vector_store_collection_name = "mem0_neptune_vector_store"
|
|
self.config.vector_store.config.collection_name = vector_store_collection_name
|
|
self.vector_store = NeptuneBase._create_vector_store(self.vector_store_provider, self.config)
|
|
|
|
self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
|
|
self.user_id = None
|
|
# Use threshold from graph_store config, default to 0.7 for backward compatibility
|
|
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
|
|
self.vector_store_limit=5
|
|
|
|
def _delete_entities_cypher(self, source, destination, relationship, user_id):
|
|
"""
|
|
Returns the OpenCypher query and parameters for deleting entities in the graph DB
|
|
|
|
:param source: source node
|
|
:param destination: destination node
|
|
:param relationship: relationship label
|
|
:param user_id: user_id to use
|
|
:return: str, dict
|
|
"""
|
|
|
|
cypher = f"""
|
|
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
|
-[r:{relationship}]->
|
|
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
|
DELETE r
|
|
RETURN
|
|
n.name AS source,
|
|
m.name AS target,
|
|
type(r) AS relationship
|
|
"""
|
|
params = {
|
|
"source_name": source,
|
|
"dest_name": destination,
|
|
"user_id": user_id,
|
|
}
|
|
logger.debug(f"_delete_entities\n query={cypher}")
|
|
return cypher, params
|
|
|
|
def _add_entities_by_source_cypher(
    self,
    source_node_list,
    destination,
    dest_embedding,
    destination_type,
    relationship,
    user_id,
):
    """
    Build the OpenCypher statement linking an already stored source node to a new destination node.

    Side effect: the destination embedding is inserted into the vector store
    under a freshly generated UUID, and that same UUID is bound as the
    destination node's `~id` in the query.

    :param source_node_list: source search results; only the first entry is used,
        read through its "id(source_candidate)" key
    :param destination: destination node name
    :param dest_embedding: destination embedding
    :param destination_type: destination node label
    :param relationship: relationship label, interpolated into the query text
    :param user_id: user id to use
    :return: str, dict
    """
    new_node_id = str(uuid.uuid4())
    self.vector_store.insert(
        vectors=[dest_embedding],
        payloads=[
            {
                "name": destination,
                "type": destination_type,
                "user_id": user_id,
                "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
            }
        ],
        ids=[new_node_id],
    )

    # With a base label configured, the node is merged on that label and the
    # type label is attached in ON CREATE; otherwise the type is the merge label.
    if self.node_label:
        merge_label = self.node_label
        type_label_clause = f", destination:`{destination_type}`"
    else:
        merge_label = f":`{destination_type}`"
        type_label_clause = ""

    query = f"""
        MATCH (source {{user_id: $user_id}})
        WHERE id(source) = $source_id
        SET source.mentions = coalesce(source.mentions, 0) + 1
        WITH source
        MERGE (destination {merge_label} {{`~id`: $destination_id, name: $destination_name, user_id: $user_id}})
        ON CREATE SET
            destination.created = timestamp(),
            destination.updated = timestamp(),
            destination.mentions = 1
            {type_label_clause}
        ON MATCH SET
            destination.mentions = coalesce(destination.mentions, 0) + 1,
            destination.updated = timestamp()
        WITH source, destination
        MERGE (source)-[r:{relationship}]->(destination)
        ON CREATE SET
            r.created = timestamp(),
            r.updated = timestamp(),
            r.mentions = 1
        ON MATCH SET
            r.mentions = coalesce(r.mentions, 0) + 1,
            r.updated = timestamp()
        RETURN source.name AS source, type(r) AS relationship, destination.name AS target, id(destination) AS destination_id
        """

    bindings = {
        "source_id": source_node_list[0]["id(source_candidate)"],
        "destination_id": new_node_id,
        "destination_name": destination,
        "dest_embedding": dest_embedding,
        "user_id": user_id,
    }

    logger.debug(
        f"_add_entities:\n source_node_search_result={source_node_list[0]}\n query={query}"
    )
    return query, bindings
|
|
|
|
def _add_entities_by_destination_cypher(
    self,
    source,
    source_embedding,
    source_type,
    destination_node_list,
    relationship,
    user_id,
):
    """
    Build the OpenCypher statement linking a new source node to an already stored destination node.

    Side effect: the source embedding is inserted into the vector store under
    a freshly generated UUID, and that same UUID is bound as the source
    node's `~id` in the query.

    :param source: source node name
    :param source_embedding: source node embedding
    :param source_type: source node label
    :param destination_node_list: destination search results; only the first entry
        is used, read through its "id(destination_candidate)" key
    :param relationship: relationship label, interpolated into the query text
    :param user_id: user id to use
    :return: str, dict
    """
    new_node_id = str(uuid.uuid4())
    self.vector_store.insert(
        vectors=[source_embedding],
        payloads=[
            {
                "name": source,
                "type": source_type,
                "user_id": user_id,
                "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
            }
        ],
        ids=[new_node_id],
    )

    # With a base label configured, the node is merged on that label and the
    # type label is attached in ON CREATE; otherwise the type is the merge label.
    if self.node_label:
        merge_label = self.node_label
        type_label_clause = f", source:`{source_type}`"
    else:
        merge_label = f":`{source_type}`"
        type_label_clause = ""

    query = f"""
        MATCH (destination {{user_id: $user_id}})
        WHERE id(destination) = $destination_id
        SET
            destination.mentions = coalesce(destination.mentions, 0) + 1,
            destination.updated = timestamp()
        WITH destination
        MERGE (source {merge_label} {{`~id`: $source_id, name: $source_name, user_id: $user_id}})
        ON CREATE SET
            source.created = timestamp(),
            source.updated = timestamp(),
            source.mentions = 1
            {type_label_clause}
        ON MATCH SET
            source.mentions = coalesce(source.mentions, 0) + 1,
            source.updated = timestamp()
        WITH source, destination
        MERGE (source)-[r:{relationship}]->(destination)
        ON CREATE SET
            r.created = timestamp(),
            r.updated = timestamp(),
            r.mentions = 1
        ON MATCH SET
            r.mentions = coalesce(r.mentions, 0) + 1,
            r.updated = timestamp()
        RETURN source.name AS source, type(r) AS relationship, destination.name AS target
        """

    bindings = {
        "destination_id": destination_node_list[0]["id(destination_candidate)"],
        "source_id": new_node_id,
        "source_name": source,
        "source_embedding": source_embedding,
        "user_id": user_id,
    }
    logger.debug(
        f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n query={query}"
    )
    return query, bindings
|
|
|
|
def _add_relationship_entities_cypher(
|
|
self,
|
|
source_node_list,
|
|
destination_node_list,
|
|
relationship,
|
|
user_id,
|
|
):
|
|
"""
|
|
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
|
|
|
:param source_node_list: list of source node ids
|
|
:param destination_node_list: list of dest node ids
|
|
:param relationship: relationship label
|
|
:param user_id: user id to use
|
|
:return: str, dict
|
|
"""
|
|
|
|
cypher = f"""
|
|
MATCH (source {{user_id: $user_id}})
|
|
WHERE id(source) = $source_id
|
|
SET
|
|
source.mentions = coalesce(source.mentions, 0) + 1,
|
|
source.updated = timestamp()
|
|
WITH source
|
|
MATCH (destination {{user_id: $user_id}})
|
|
WHERE id(destination) = $destination_id
|
|
SET
|
|
destination.mentions = coalesce(destination.mentions) + 1,
|
|
destination.updated = timestamp()
|
|
MERGE (source)-[r:{relationship}]->(destination)
|
|
ON CREATE SET
|
|
r.created_at = timestamp(),
|
|
r.updated_at = timestamp(),
|
|
r.mentions = 1
|
|
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
|
"""
|
|
params = {
|
|
"source_id": source_node_list[0]["id(source_candidate)"],
|
|
"destination_id": destination_node_list[0]["id(destination_candidate)"],
|
|
"user_id": user_id,
|
|
}
|
|
logger.debug(
|
|
f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n source_node_search_result={source_node_list[0]}\n query={cypher}"
|
|
)
|
|
return cypher, params
|
|
|
|
def _add_new_entities_cypher(
    self,
    source,
    source_embedding,
    source_type,
    destination,
    dest_embedding,
    destination_type,
    relationship,
    user_id,
):
    """
    Build the OpenCypher statement that creates both endpoint nodes and the relationship between them.

    Side effect: both embeddings are inserted into the vector store in one
    call, each under a freshly generated UUID; those UUIDs are bound as the
    nodes' `~id` values in the query.

    :param source: source node name
    :param source_embedding: source node embedding
    :param source_type: source node label
    :param destination: destination name
    :param dest_embedding: destination embedding
    :param destination_type: destination node label
    :param relationship: relationship label, interpolated into the query text
    :param user_id: user id to use
    :return: str, dict
    """
    src_id = str(uuid.uuid4())
    src_payload = dict(
        name=source,
        type=source_type,
        user_id=user_id,
        created_at=datetime.now(pytz.timezone("US/Pacific")).isoformat(),
    )
    dst_id = str(uuid.uuid4())
    dst_payload = dict(
        name=destination,
        type=destination_type,
        user_id=user_id,
        created_at=datetime.now(pytz.timezone("US/Pacific")).isoformat(),
    )
    self.vector_store.insert(
        vectors=[source_embedding, dest_embedding],
        payloads=[src_payload, dst_payload],
        ids=[src_id, dst_id],
    )

    # With a base label configured, nodes are merged on that label and the
    # type label is attached in ON CREATE; otherwise the type is the merge label.
    if self.node_label:
        src_label = dst_label = self.node_label
        src_type_clause = f", source:`{source_type}`"
        dst_type_clause = f", destination:`{destination_type}`"
    else:
        src_label = f":`{source_type}`"
        dst_label = f":`{destination_type}`"
        src_type_clause = dst_type_clause = ""

    query = f"""
        MERGE (n {src_label} {{name: $source_name, user_id: $user_id, `~id`: $source_id}})
        ON CREATE SET n.created = timestamp(),
                      n.mentions = 1
                      {src_type_clause}
        ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
        WITH n
        MERGE (m {dst_label} {{name: $dest_name, user_id: $user_id, `~id`: $dest_id}})
        ON CREATE SET m.created = timestamp(),
                      m.mentions = 1
                      {dst_type_clause}
        ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
        WITH n, m
        MERGE (n)-[rel:{relationship}]->(m)
        ON CREATE SET rel.created = timestamp(), rel.mentions = 1
        ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
        RETURN n.name AS source, type(rel) AS relationship, m.name AS target
        """
    bindings = {
        "source_id": src_id,
        "dest_id": dst_id,
        "source_name": source,
        "dest_name": destination,
        "source_embedding": source_embedding,
        "dest_embedding": dest_embedding,
        "user_id": user_id,
    }
    logger.debug(
        f"_add_new_entities_cypher:\n query={query}"
    )
    return query, bindings
|
|
|
|
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
|
|
"""
|
|
Returns the OpenCypher query and parameters to search for source nodes
|
|
|
|
:param source_embedding: source vector
|
|
:param user_id: user_id to use
|
|
:param threshold: the threshold for similarity
|
|
:return: str, dict
|
|
"""
|
|
|
|
source_nodes = self.vector_store.search(
|
|
query="",
|
|
vectors=source_embedding,
|
|
limit=self.vector_store_limit,
|
|
filters={"user_id": user_id},
|
|
)
|
|
|
|
ids = [n.id for n in filter(lambda s: s.score > threshold, source_nodes)]
|
|
|
|
cypher = f"""
|
|
MATCH (source_candidate {self.node_label})
|
|
WHERE source_candidate.user_id = $user_id AND id(source_candidate) IN $ids
|
|
RETURN id(source_candidate)
|
|
"""
|
|
|
|
params = {
|
|
"ids": ids,
|
|
"source_embedding": source_embedding,
|
|
"user_id": user_id,
|
|
"threshold": threshold,
|
|
}
|
|
logger.debug(f"_search_source_node\n query={cypher}")
|
|
return cypher, params
|
|
|
|
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
|
|
"""
|
|
Returns the OpenCypher query and parameters to search for destination nodes
|
|
|
|
:param source_embedding: source vector
|
|
:param user_id: user_id to use
|
|
:param threshold: the threshold for similarity
|
|
:return: str, dict
|
|
"""
|
|
destination_nodes = self.vector_store.search(
|
|
query="",
|
|
vectors=destination_embedding,
|
|
limit=self.vector_store_limit,
|
|
filters={"user_id": user_id},
|
|
)
|
|
|
|
ids = [n.id for n in filter(lambda d: d.score > threshold, destination_nodes)]
|
|
|
|
cypher = f"""
|
|
MATCH (destination_candidate {self.node_label})
|
|
WHERE destination_candidate.user_id = $user_id AND id(destination_candidate) IN $ids
|
|
RETURN id(destination_candidate)
|
|
"""
|
|
|
|
params = {
|
|
"ids": ids,
|
|
"destination_embedding": destination_embedding,
|
|
"user_id": user_id,
|
|
}
|
|
|
|
logger.debug(f"_search_destination_node\n query={cypher}")
|
|
return cypher, params
|
|
|
|
def _delete_all_cypher(self, filters):
|
|
"""
|
|
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
|
|
|
|
:param filters: search filters
|
|
:return: str, dict
|
|
"""
|
|
|
|
# remove the vector store index
|
|
self.vector_store.reset()
|
|
|
|
# create a query that: deletes the nodes of the graph_store
|
|
cypher = f"""
|
|
MATCH (n {self.node_label} {{user_id: $user_id}})
|
|
DETACH DELETE n
|
|
"""
|
|
params = {"user_id": filters["user_id"]}
|
|
|
|
logger.debug(f"delete_all query={cypher}")
|
|
return cypher, params
|
|
|
|
def _get_all_cypher(self, filters, limit):
|
|
"""
|
|
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
|
|
|
|
:param filters: search filters
|
|
:param limit: return limit
|
|
:return: str, dict
|
|
"""
|
|
|
|
cypher = f"""
|
|
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
|
LIMIT $limit
|
|
"""
|
|
params = {"user_id": filters["user_id"], "limit": limit}
|
|
return cypher, params
|
|
|
|
def _search_graph_db_cypher(self, n_embedding, filters, limit):
|
|
"""
|
|
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
|
|
|
|
:param n_embedding: node vector
|
|
:param filters: search filters
|
|
:param limit: return limit
|
|
:return: str, dict
|
|
"""
|
|
|
|
# search vector store for applicable nodes using cosine similarity
|
|
search_nodes = self.vector_store.search(
|
|
query="",
|
|
vectors=n_embedding,
|
|
limit=self.vector_store_limit,
|
|
filters=filters,
|
|
)
|
|
|
|
ids = [n.id for n in search_nodes]
|
|
|
|
cypher_query = f"""
|
|
MATCH (n {self.node_label})-[r]->(m)
|
|
WHERE n.user_id = $user_id AND id(n) IN $n_ids
|
|
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
|
|
UNION
|
|
MATCH (m)-[r]->(n {self.node_label})
|
|
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
|
|
LIMIT $limit
|
|
"""
|
|
params = {
|
|
"n_ids": ids,
|
|
"user_id": filters["user_id"],
|
|
"limit": limit,
|
|
}
|
|
logger.debug(f"_search_graph_db\n query={cypher_query}")
|
|
|
|
return cypher_query, params
|