first commit
This commit is contained in:
0
graphs/__init__.py
Normal file
0
graphs/__init__.py
Normal file
114
graphs/configs.py
Normal file
114
graphs/configs.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from mem0.llms.configs import LlmConfig
|
||||
|
||||
|
||||
class Neo4jConfig(BaseModel):
|
||||
url: Optional[str] = Field(None, description="Host address for the graph database")
|
||||
username: Optional[str] = Field(None, description="Username for the graph database")
|
||||
password: Optional[str] = Field(None, description="Password for the graph database")
|
||||
database: Optional[str] = Field(None, description="Database for the graph database")
|
||||
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_host_port_or_path(cls, values):
|
||||
url, username, password = (
|
||||
values.get("url"),
|
||||
values.get("username"),
|
||||
values.get("password"),
|
||||
)
|
||||
if not url or not username or not password:
|
||||
raise ValueError("Please provide 'url', 'username' and 'password'.")
|
||||
return values
|
||||
|
||||
|
||||
class MemgraphConfig(BaseModel):
|
||||
url: Optional[str] = Field(None, description="Host address for the graph database")
|
||||
username: Optional[str] = Field(None, description="Username for the graph database")
|
||||
password: Optional[str] = Field(None, description="Password for the graph database")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_host_port_or_path(cls, values):
|
||||
url, username, password = (
|
||||
values.get("url"),
|
||||
values.get("username"),
|
||||
values.get("password"),
|
||||
)
|
||||
if not url or not username or not password:
|
||||
raise ValueError("Please provide 'url', 'username' and 'password'.")
|
||||
return values
|
||||
|
||||
|
||||
class NeptuneConfig(BaseModel):
|
||||
app_id: Optional[str] = Field("Mem0", description="APP_ID for the connection")
|
||||
endpoint: Optional[str] = (
|
||||
Field(
|
||||
None,
|
||||
description="Endpoint to connect to a Neptune-DB Cluster as 'neptune-db://<host>' or Neptune Analytics Server as 'neptune-graph://<graphid>'",
|
||||
),
|
||||
)
|
||||
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
|
||||
collection_name: Optional[str] = Field(None, description="vector_store collection name to store vectors when using Neptune-DB Clusters")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_host_port_or_path(cls, values):
|
||||
endpoint = values.get("endpoint")
|
||||
if not endpoint:
|
||||
raise ValueError("Please provide 'endpoint' with the format as 'neptune-db://<endpoint>' or 'neptune-graph://<graphid>'.")
|
||||
if endpoint.startswith("neptune-db://"):
|
||||
# This is a Neptune DB Graph
|
||||
return values
|
||||
elif endpoint.startswith("neptune-graph://"):
|
||||
# This is a Neptune Analytics Graph
|
||||
graph_identifier = endpoint.replace("neptune-graph://", "")
|
||||
if not graph_identifier.startswith("g-"):
|
||||
raise ValueError("Provide a valid 'graph_identifier'.")
|
||||
values["graph_identifier"] = graph_identifier
|
||||
return values
|
||||
else:
|
||||
raise ValueError(
|
||||
"You must provide an endpoint to create a NeptuneServer as either neptune-db://<endpoint> or neptune-graph://<graphid>"
|
||||
)
|
||||
|
||||
|
||||
class KuzuConfig(BaseModel):
|
||||
db: Optional[str] = Field(":memory:", description="Path to a Kuzu database file")
|
||||
|
||||
|
||||
class GraphStoreConfig(BaseModel):
|
||||
provider: str = Field(
|
||||
description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune', 'kuzu')",
|
||||
default="neo4j",
|
||||
)
|
||||
config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig, KuzuConfig] = Field(
|
||||
description="Configuration for the specific data store", default=None
|
||||
)
|
||||
llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
|
||||
custom_prompt: Optional[str] = Field(
|
||||
description="Custom prompt to fetch entities from the given text", default=None
|
||||
)
|
||||
threshold: float = Field(
|
||||
description="Threshold for embedding similarity when matching nodes during graph ingestion. "
|
||||
"Range: 0.0 to 1.0. Higher values require closer matches. "
|
||||
"Use lower values (e.g., 0.5-0.7) for distinct entities with similar embeddings. "
|
||||
"Use higher values (e.g., 0.9+) when you want stricter matching.",
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
@field_validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
provider = values.data.get("provider")
|
||||
if provider == "neo4j":
|
||||
return Neo4jConfig(**v.model_dump())
|
||||
elif provider == "memgraph":
|
||||
return MemgraphConfig(**v.model_dump())
|
||||
elif provider == "neptune" or provider == "neptunedb":
|
||||
return NeptuneConfig(**v.model_dump())
|
||||
elif provider == "kuzu":
|
||||
return KuzuConfig(**v.model_dump())
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph store provider: {provider}")
|
||||
0
graphs/neptune/__init__.py
Normal file
0
graphs/neptune/__init__.py
Normal file
497
graphs/neptune/base.py
Normal file
497
graphs/neptune/base.py
Normal file
@@ -0,0 +1,497 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from mem0.memory.utils import format_entities
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
DELETE_MEMORY_TOOL_GRAPH,
|
||||
EXTRACT_ENTITIES_STRUCT_TOOL,
|
||||
EXTRACT_ENTITIES_TOOL,
|
||||
RELATIONS_STRUCT_TOOL,
|
||||
RELATIONS_TOOL,
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeptuneBase(ABC):
|
||||
"""
|
||||
Abstract base class for neptune (neptune analytics and neptune db) calls using OpenCypher
|
||||
to store/retrieve data
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_model(config):
|
||||
"""
|
||||
:return: the Embedder model used for memory store
|
||||
"""
|
||||
return EmbedderFactory.create(
|
||||
config.embedder.provider,
|
||||
config.embedder.config,
|
||||
{"enable_embeddings": True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_llm(config, llm_provider):
|
||||
"""
|
||||
:return: the llm model used for memory store
|
||||
"""
|
||||
return LlmFactory.create(llm_provider, config.llm.config)
|
||||
|
||||
@staticmethod
|
||||
def _create_vector_store(vector_store_provider, config):
|
||||
"""
|
||||
:param vector_store_provider: name of vector store
|
||||
:param config: the vector_store configuration
|
||||
:return:
|
||||
"""
|
||||
return VectorStoreFactory.create(vector_store_provider, config.vector_store.config)
|
||||
|
||||
def add(self, data, filters):
|
||||
"""
|
||||
Adds data to the graph.
|
||||
|
||||
Args:
|
||||
data (str): The data to add to the graph.
|
||||
filters (dict): A dictionary containing filters to be applied during the addition.
|
||||
"""
|
||||
entity_type_map = self._retrieve_nodes_from_data(data, filters)
|
||||
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
|
||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
|
||||
|
||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
||||
|
||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||
|
||||
def _retrieve_nodes_from_data(self, data, filters):
|
||||
"""
|
||||
Extract all entities mentioned in the query.
|
||||
"""
|
||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
],
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
entity_type_map = {}
|
||||
|
||||
try:
|
||||
for tool_call in search_results["tool_calls"]:
|
||||
if tool_call["name"] != "extract_entities":
|
||||
continue
|
||||
for item in tool_call["arguments"]["entities"]:
|
||||
entity_type_map[item["entity"]] = item["entity_type"]
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
|
||||
)
|
||||
|
||||
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
|
||||
return entity_type_map
|
||||
|
||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||
"""
|
||||
Establish relations among the extracted nodes.
|
||||
"""
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
|
||||
},
|
||||
]
|
||||
|
||||
_tools = [RELATIONS_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [RELATIONS_STRUCT_TOOL]
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
entities = []
|
||||
if extracted_entities["tool_calls"]:
|
||||
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
|
||||
entities = self._remove_spaces_from_entities(entities)
|
||||
logger.debug(f"Extracted entities: {entities}")
|
||||
return entities
|
||||
|
||||
def _remove_spaces_from_entities(self, entity_list):
|
||||
for item in entity_list:
|
||||
item["source"] = item["source"].lower().replace(" ", "_")
|
||||
item["relationship"] = item["relationship"].lower().replace(" ", "_")
|
||||
item["destination"] = item["destination"].lower().replace(" ", "_")
|
||||
return entity_list
|
||||
|
||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||
"""
|
||||
Get the entities to be deleted from the search output.
|
||||
"""
|
||||
|
||||
search_output_string = format_entities(search_output)
|
||||
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
|
||||
|
||||
_tools = [DELETE_MEMORY_TOOL_GRAPH]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
]
|
||||
|
||||
memory_updates = self.llm.generate_response(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
to_be_deleted = []
|
||||
for item in memory_updates["tool_calls"]:
|
||||
if item["name"] == "delete_graph_memory":
|
||||
to_be_deleted.append(item["arguments"])
|
||||
# in case if it is not in the correct format
|
||||
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||
return to_be_deleted
|
||||
|
||||
def _delete_entities(self, to_be_deleted, user_id):
|
||||
"""
|
||||
Delete the entities from the graph.
|
||||
"""
|
||||
|
||||
results = []
|
||||
for item in to_be_deleted:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relationship = item["relationship"]
|
||||
|
||||
# Delete the specific relationship between nodes
|
||||
cypher, params = self._delete_entities_cypher(source, destination, relationship, user_id)
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
@abstractmethod
|
||||
def _delete_entities_cypher(self, source, destination, relationship, user_id):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for deleting entities in the graph DB
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
||||
"""
|
||||
Add the new entities to the graph. Merge the nodes if they already exist.
|
||||
"""
|
||||
|
||||
results = []
|
||||
for item in to_be_added:
|
||||
# entities
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relationship = item["relationship"]
|
||||
|
||||
# types
|
||||
source_type = entity_type_map.get(source, "__User__")
|
||||
destination_type = entity_type_map.get(destination, "__User__")
|
||||
|
||||
# embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
|
||||
# search for the nodes with the closest embeddings
|
||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=self.threshold)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=self.threshold)
|
||||
|
||||
cypher, params = self._add_entities_cypher(
|
||||
source_node_search_result,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_search_result,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
)
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def _add_entities_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_list,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
"""
|
||||
if not destination_node_list and source_node_list:
|
||||
return self._add_entities_by_source_cypher(
|
||||
source_node_list,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id)
|
||||
elif destination_node_list and not source_node_list:
|
||||
return self._add_entities_by_destination_cypher(
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id)
|
||||
elif source_node_list and destination_node_list:
|
||||
return self._add_relationship_entities_cypher(
|
||||
source_node_list,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id)
|
||||
# else source_node_list and destination_node_list are empty
|
||||
return self._add_new_entities_cypher(
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id)
|
||||
|
||||
@abstractmethod
|
||||
def _add_entities_by_source_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _add_entities_by_destination_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _add_relationship_entities_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _add_new_entities_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
pass
|
||||
|
||||
def search(self, query, filters, limit=100):
|
||||
"""
|
||||
Search for memories and related graph data.
|
||||
|
||||
Args:
|
||||
query (str): Query to search for.
|
||||
filters (dict): A dictionary containing filters to be applied during the search.
|
||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing:
|
||||
- "contexts": List of search results from the base data store.
|
||||
- "entities": List of related graph data based on the query.
|
||||
"""
|
||||
|
||||
entity_type_map = self._retrieve_nodes_from_data(query, filters)
|
||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||
|
||||
if not search_output:
|
||||
return []
|
||||
|
||||
search_outputs_sequence = [
|
||||
[item["source"], item["relationship"], item["destination"]] for item in search_output
|
||||
]
|
||||
bm25 = BM25Okapi(search_outputs_sequence)
|
||||
|
||||
tokenized_query = query.split(" ")
|
||||
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
|
||||
|
||||
search_results = []
|
||||
for item in reranked_results:
|
||||
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
|
||||
|
||||
return search_results
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
cypher, params = self._search_source_node_cypher(source_embedding, user_id, threshold)
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for source nodes
|
||||
"""
|
||||
pass
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
cypher, params = self._search_destination_node_cypher(destination_embedding, user_id, threshold)
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for destination nodes
|
||||
"""
|
||||
pass
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher, params = self._delete_all_cypher(filters)
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
@abstractmethod
|
||||
def _delete_all_cypher(self, filters):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all(self, filters, limit=100):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on filtering criteria.
|
||||
|
||||
Args:
|
||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
||||
Returns:
|
||||
list: A list of dictionaries, each containing:
|
||||
- 'contexts': The base data store response for each memory.
|
||||
- 'entities': A list of strings representing the nodes and relationships
|
||||
"""
|
||||
|
||||
# return all nodes and relationships
|
||||
query, params = self._get_all_cypher(filters, limit)
|
||||
results = self.graph.query(query, params=params)
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
final_results.append(
|
||||
{
|
||||
"source": result["source"],
|
||||
"relationship": result["relationship"],
|
||||
"target": result["target"],
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
@abstractmethod
|
||||
def _get_all_cypher(self, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
|
||||
"""
|
||||
pass
|
||||
|
||||
def _search_graph_db(self, node_list, filters, limit=100):
|
||||
"""
|
||||
Search similar nodes among and their respective incoming and outgoing relations.
|
||||
"""
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
cypher_query, params = self._search_graph_db_cypher(n_embedding, filters, limit)
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
@abstractmethod
|
||||
def _search_graph_db_cypher(self, n_embedding, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
|
||||
"""
|
||||
pass
|
||||
|
||||
# Reset is not defined in base.py
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the graph by clearing all nodes and relationships.
|
||||
|
||||
link: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/neptune-graph/client/reset_graph.html
|
||||
"""
|
||||
|
||||
logger.warning("Clearing graph...")
|
||||
graph_id = self.graph.graph_identifier
|
||||
self.graph.client.reset_graph(
|
||||
graphIdentifier=graph_id,
|
||||
skipSnapshot=True,
|
||||
)
|
||||
waiter = self.graph.client.get_waiter("graph_available")
|
||||
waiter.wait(graphIdentifier=graph_id, WaiterConfig={"Delay": 10, "MaxAttempts": 60})
|
||||
512
graphs/neptune/neptunedb.py
Normal file
512
graphs/neptune/neptunedb.py
Normal file
@@ -0,0 +1,512 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
from .base import NeptuneBase
|
||||
|
||||
try:
|
||||
from langchain_aws import NeptuneGraph
|
||||
except ImportError:
|
||||
raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MemoryGraph(NeptuneBase):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the Neptune DB memory store.
|
||||
"""
|
||||
|
||||
self.config = config
|
||||
|
||||
self.graph = None
|
||||
endpoint = self.config.graph_store.config.endpoint
|
||||
if endpoint and endpoint.startswith("neptune-db://"):
|
||||
host = endpoint.replace("neptune-db://", "")
|
||||
port = 8182
|
||||
self.graph = NeptuneGraph(host, port)
|
||||
|
||||
if not self.graph:
|
||||
raise ValueError("Unable to create a Neptune-DB client: missing 'endpoint' in config")
|
||||
|
||||
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
|
||||
|
||||
self.embedding_model = NeptuneBase._create_embedding_model(self.config)
|
||||
|
||||
# Default to openai if no specific provider is configured
|
||||
self.llm_provider = "openai"
|
||||
if self.config.graph_store.llm:
|
||||
self.llm_provider = self.config.graph_store.llm.provider
|
||||
elif self.config.llm.provider:
|
||||
self.llm_provider = self.config.llm.provider
|
||||
|
||||
# fetch the vector store as a provider
|
||||
self.vector_store_provider = self.config.vector_store.provider
|
||||
if self.config.graph_store.config.collection_name:
|
||||
vector_store_collection_name = self.config.graph_store.config.collection_name
|
||||
else:
|
||||
vector_store_config = self.config.vector_store.config
|
||||
if vector_store_config.collection_name:
|
||||
vector_store_collection_name = vector_store_config.collection_name + "_neptune_vector_store"
|
||||
else:
|
||||
vector_store_collection_name = "mem0_neptune_vector_store"
|
||||
self.config.vector_store.config.collection_name = vector_store_collection_name
|
||||
self.vector_store = NeptuneBase._create_vector_store(self.vector_store_provider, self.config)
|
||||
|
||||
self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
|
||||
self.user_id = None
|
||||
# Use threshold from graph_store config, default to 0.7 for backward compatibility
|
||||
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
|
||||
self.vector_store_limit=5
|
||||
|
||||
def _delete_entities_cypher(self, source, destination, relationship, user_id):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for deleting entities in the graph DB
|
||||
|
||||
:param source: source node
|
||||
:param destination: destination node
|
||||
:param relationship: relationship label
|
||||
:param user_id: user_id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||
-[r:{relationship}]->
|
||||
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||
DELETE r
|
||||
RETURN
|
||||
n.name AS source,
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(f"_delete_entities\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _add_entities_by_source_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source_node_list: list of source nodes
|
||||
:param destination: destination name
|
||||
:param dest_embedding: destination embedding
|
||||
:param destination_type: destination node label
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
destination_id = str(uuid.uuid4())
|
||||
destination_payload = {
|
||||
"name": destination,
|
||||
"type": destination_type,
|
||||
"user_id": user_id,
|
||||
"created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
|
||||
}
|
||||
self.vector_store.insert(
|
||||
vectors=[dest_embedding],
|
||||
payloads=[destination_payload],
|
||||
ids=[destination_id],
|
||||
)
|
||||
|
||||
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source {{user_id: $user_id}})
|
||||
WHERE id(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{`~id`: $destination_id, name: $destination_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.updated = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1,
|
||||
destination.updated = timestamp()
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.updated = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1,
|
||||
r.updated = timestamp()
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target, id(destination) AS destination_id
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_list[0]["id(source_candidate)"],
|
||||
"destination_id": destination_id,
|
||||
"destination_name": destination,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"_add_entities:\n source_node_search_result={source_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_entities_by_destination_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source: source node name
|
||||
:param source_embedding: source node embedding
|
||||
:param source_type: source node label
|
||||
:param destination_node_list: list of dest nodes
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
source_id = str(uuid.uuid4())
|
||||
source_payload = {
|
||||
"name": source,
|
||||
"type": source_type,
|
||||
"user_id": user_id,
|
||||
"created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
|
||||
}
|
||||
self.vector_store.insert(
|
||||
vectors=[source_embedding],
|
||||
payloads=[source_payload],
|
||||
ids=[source_id],
|
||||
)
|
||||
|
||||
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (destination {{user_id: $user_id}})
|
||||
WHERE id(destination) = $destination_id
|
||||
SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1,
|
||||
destination.updated = timestamp()
|
||||
WITH destination
|
||||
MERGE (source {source_label} {{`~id`: $source_id, name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.updated = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1,
|
||||
source.updated = timestamp()
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.updated = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1,
|
||||
r.updated = timestamp()
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_list[0]["id(destination_candidate)"],
|
||||
"source_id": source_id,
|
||||
"source_name": source,
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_relationship_entities_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source_node_list: list of source node ids
|
||||
:param destination_node_list: list of dest node ids
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source {{user_id: $user_id}})
|
||||
WHERE id(source) = $source_id
|
||||
SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1,
|
||||
source.updated = timestamp()
|
||||
WITH source
|
||||
MATCH (destination {{user_id: $user_id}})
|
||||
WHERE id(destination) = $destination_id
|
||||
SET
|
||||
destination.mentions = coalesce(destination.mentions) + 1,
|
||||
destination.updated = timestamp()
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp(),
|
||||
r.updated_at = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_node_list[0]["id(source_candidate)"],
|
||||
"destination_id": destination_node_list[0]["id(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n source_node_search_result={source_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_new_entities_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source: source node name
|
||||
:param source_embedding: source node embedding
|
||||
:param source_type: source node label
|
||||
:param destination: destination name
|
||||
:param dest_embedding: destination embedding
|
||||
:param destination_type: destination node label
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
source_id = str(uuid.uuid4())
|
||||
source_payload = {
|
||||
"name": source,
|
||||
"type": source_type,
|
||||
"user_id": user_id,
|
||||
"created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
|
||||
}
|
||||
destination_id = str(uuid.uuid4())
|
||||
destination_payload = {
|
||||
"name": destination,
|
||||
"type": destination_type,
|
||||
"user_id": user_id,
|
||||
"created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
|
||||
}
|
||||
self.vector_store.insert(
|
||||
vectors=[source_embedding, dest_embedding],
|
||||
payloads=[source_payload, destination_payload],
|
||||
ids=[source_id, destination_id],
|
||||
)
|
||||
|
||||
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MERGE (n {source_label} {{name: $source_name, user_id: $user_id, `~id`: $source_id}})
|
||||
ON CREATE SET n.created = timestamp(),
|
||||
n.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
|
||||
WITH n
|
||||
MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id, `~id`: $dest_id}})
|
||||
ON CREATE SET m.created = timestamp(),
|
||||
m.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
|
||||
WITH n, m
|
||||
MERGE (n)-[rel:{relationship}]->(m)
|
||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_id,
|
||||
"dest_id": destination_id,
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"source_embedding": source_embedding,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_new_entities_cypher:\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for source nodes
|
||||
|
||||
:param source_embedding: source vector
|
||||
:param user_id: user_id to use
|
||||
:param threshold: the threshold for similarity
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
source_nodes = self.vector_store.search(
|
||||
query="",
|
||||
vectors=source_embedding,
|
||||
limit=self.vector_store_limit,
|
||||
filters={"user_id": user_id},
|
||||
)
|
||||
|
||||
ids = [n.id for n in filter(lambda s: s.score > threshold, source_nodes)]
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source_candidate {self.node_label})
|
||||
WHERE source_candidate.user_id = $user_id AND id(source_candidate) IN $ids
|
||||
RETURN id(source_candidate)
|
||||
"""
|
||||
|
||||
params = {
|
||||
"ids": ids,
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
logger.debug(f"_search_source_node\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for destination nodes
|
||||
|
||||
:param source_embedding: source vector
|
||||
:param user_id: user_id to use
|
||||
:param threshold: the threshold for similarity
|
||||
:return: str, dict
|
||||
"""
|
||||
destination_nodes = self.vector_store.search(
|
||||
query="",
|
||||
vectors=destination_embedding,
|
||||
limit=self.vector_store_limit,
|
||||
filters={"user_id": user_id},
|
||||
)
|
||||
|
||||
ids = [n.id for n in filter(lambda d: d.score > threshold, destination_nodes)]
|
||||
|
||||
cypher = f"""
|
||||
MATCH (destination_candidate {self.node_label})
|
||||
WHERE destination_candidate.user_id = $user_id AND id(destination_candidate) IN $ids
|
||||
RETURN id(destination_candidate)
|
||||
"""
|
||||
|
||||
params = {
|
||||
"ids": ids,
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
logger.debug(f"_search_destination_node\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _delete_all_cypher(self, filters):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
|
||||
|
||||
:param filters: search filters
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
# remove the vector store index
|
||||
self.vector_store.reset()
|
||||
|
||||
# create a query that: deletes the nodes of the graph_store
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
|
||||
logger.debug(f"delete_all query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _get_all_cypher(self, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
|
||||
|
||||
:param filters: search filters
|
||||
:param limit: return limit
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "limit": limit}
|
||||
return cypher, params
|
||||
|
||||
def _search_graph_db_cypher(self, n_embedding, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
|
||||
|
||||
:param n_embedding: node vector
|
||||
:param filters: search filters
|
||||
:param limit: return limit
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
# search vector store for applicable nodes using cosine similarity
|
||||
search_nodes = self.vector_store.search(
|
||||
query="",
|
||||
vectors=n_embedding,
|
||||
limit=self.vector_store_limit,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
ids = [n.id for n in search_nodes]
|
||||
|
||||
cypher_query = f"""
|
||||
MATCH (n {self.node_label})-[r]->(m)
|
||||
WHERE n.user_id = $user_id AND id(n) IN $n_ids
|
||||
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
|
||||
UNION
|
||||
MATCH (m)-[r]->(n {self.node_label})
|
||||
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {
|
||||
"n_ids": ids,
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
logger.debug(f"_search_graph_db\n query={cypher_query}")
|
||||
|
||||
return cypher_query, params
|
||||
475
graphs/neptune/neptunegraph.py
Normal file
475
graphs/neptune/neptunegraph.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import logging
|
||||
|
||||
from .base import NeptuneBase
|
||||
|
||||
try:
|
||||
from langchain_aws import NeptuneAnalyticsGraph
|
||||
from botocore.config import Config
|
||||
except ImportError:
|
||||
raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryGraph(NeptuneBase):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
self.graph = None
|
||||
endpoint = self.config.graph_store.config.endpoint
|
||||
app_id = self.config.graph_store.config.app_id
|
||||
if endpoint and endpoint.startswith("neptune-graph://"):
|
||||
graph_identifier = endpoint.replace("neptune-graph://", "")
|
||||
self.graph = NeptuneAnalyticsGraph(graph_identifier = graph_identifier,
|
||||
config = Config(user_agent_appid=app_id))
|
||||
|
||||
if not self.graph:
|
||||
raise ValueError("Unable to create a Neptune client: missing 'endpoint' in config")
|
||||
|
||||
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
|
||||
|
||||
self.embedding_model = NeptuneBase._create_embedding_model(self.config)
|
||||
|
||||
# Default to openai if no specific provider is configured
|
||||
self.llm_provider = "openai"
|
||||
if self.config.llm.provider:
|
||||
self.llm_provider = self.config.llm.provider
|
||||
if self.config.graph_store.llm:
|
||||
self.llm_provider = self.config.graph_store.llm.provider
|
||||
|
||||
self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
|
||||
self.user_id = None
|
||||
# Use threshold from graph_store config, default to 0.7 for backward compatibility
|
||||
self.threshold = self.config.graph_store.threshold if hasattr(self.config.graph_store, 'threshold') else 0.7
|
||||
|
||||
def _delete_entities_cypher(self, source, destination, relationship, user_id):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for deleting entities in the graph DB
|
||||
|
||||
:param source: source node
|
||||
:param destination: destination node
|
||||
:param relationship: relationship label
|
||||
:param user_id: user_id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||
-[r:{relationship}]->
|
||||
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||
DELETE r
|
||||
RETURN
|
||||
n.name AS source,
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(f"_delete_entities\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _add_entities_by_source_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source_node_list: list of source nodes
|
||||
:param destination: destination name
|
||||
:param dest_embedding: destination embedding
|
||||
:param destination_type: destination node label
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source {{user_id: $user_id}})
|
||||
WHERE id(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.updated = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1,
|
||||
destination.updated = timestamp()
|
||||
WITH source, destination, $dest_embedding as dest_embedding
|
||||
CALL neptune.algo.vectors.upsert(destination, dest_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.updated = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1,
|
||||
r.updated = timestamp()
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_list[0]["id(source_candidate)"],
|
||||
"destination_name": destination,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_entities:\n source_node_search_result={source_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_entities_by_destination_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source: source node name
|
||||
:param source_embedding: source node embedding
|
||||
:param source_type: source node label
|
||||
:param destination_node_list: list of dest nodes
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (destination {{user_id: $user_id}})
|
||||
WHERE id(destination) = $destination_id
|
||||
SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1,
|
||||
destination.updated = timestamp()
|
||||
WITH destination
|
||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.updated = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1,
|
||||
source.updated = timestamp()
|
||||
WITH source, destination, $source_embedding as source_embedding
|
||||
CALL neptune.algo.vectors.upsert(source, source_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.updated = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1,
|
||||
r.updated = timestamp()
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_list[0]["id(destination_candidate)"],
|
||||
"source_name": source,
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_relationship_entities_cypher(
|
||||
self,
|
||||
source_node_list,
|
||||
destination_node_list,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source_node_list: list of source node ids
|
||||
:param destination_node_list: list of dest node ids
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source {{user_id: $user_id}})
|
||||
WHERE id(source) = $source_id
|
||||
SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1,
|
||||
source.updated = timestamp()
|
||||
WITH source
|
||||
MATCH (destination {{user_id: $user_id}})
|
||||
WHERE id(destination) = $destination_id
|
||||
SET
|
||||
destination.mentions = coalesce(destination.mentions) + 1,
|
||||
destination.updated = timestamp()
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp(),
|
||||
r.updated_at = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_node_list[0]["id(source_candidate)"],
|
||||
"destination_id": destination_node_list[0]["id(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n source_node_search_result={source_node_list[0]}\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _add_new_entities_cypher(
|
||||
self,
|
||||
source,
|
||||
source_embedding,
|
||||
source_type,
|
||||
destination,
|
||||
dest_embedding,
|
||||
destination_type,
|
||||
relationship,
|
||||
user_id,
|
||||
):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters for adding entities in the graph DB
|
||||
|
||||
:param source: source node name
|
||||
:param source_embedding: source node embedding
|
||||
:param source_type: source node label
|
||||
:param destination: destination name
|
||||
:param dest_embedding: destination embedding
|
||||
:param destination_type: destination node label
|
||||
:param relationship: relationship label
|
||||
:param user_id: user id to use
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||
|
||||
cypher = f"""
|
||||
MERGE (n {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET n.created = timestamp(),
|
||||
n.updated = timestamp(),
|
||||
n.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
n.mentions = coalesce(n.mentions, 0) + 1,
|
||||
n.updated = timestamp()
|
||||
WITH n, $source_embedding as source_embedding
|
||||
CALL neptune.algo.vectors.upsert(n, source_embedding)
|
||||
WITH n
|
||||
MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
m.created = timestamp(),
|
||||
m.updated = timestamp(),
|
||||
m.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
m.updated = timestamp(),
|
||||
m.mentions = coalesce(m.mentions, 0) + 1
|
||||
WITH n, m, $dest_embedding as dest_embedding
|
||||
CALL neptune.algo.vectors.upsert(m, dest_embedding)
|
||||
WITH n, m
|
||||
MERGE (n)-[rel:{relationship}]->(m)
|
||||
ON CREATE SET
|
||||
rel.created = timestamp(),
|
||||
rel.updated = timestamp(),
|
||||
rel.mentions = 1
|
||||
ON MATCH SET
|
||||
rel.updated = timestamp(),
|
||||
rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"source_embedding": source_embedding,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
logger.debug(
|
||||
f"_add_new_entities_cypher:\n query={cypher}"
|
||||
)
|
||||
return cypher, params
|
||||
|
||||
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for source nodes
|
||||
|
||||
:param source_embedding: source vector
|
||||
:param user_id: user_id to use
|
||||
:param threshold: the threshold for similarity
|
||||
:return: str, dict
|
||||
"""
|
||||
cypher = f"""
|
||||
MATCH (source_candidate {self.node_label})
|
||||
WHERE source_candidate.user_id = $user_id
|
||||
|
||||
WITH source_candidate, $source_embedding as v_embedding
|
||||
CALL neptune.algo.vectors.distanceByEmbedding(
|
||||
v_embedding,
|
||||
source_candidate,
|
||||
{{metric:"CosineSimilarity"}}
|
||||
) YIELD distance
|
||||
WITH source_candidate, distance AS cosine_similarity
|
||||
WHERE cosine_similarity >= $threshold
|
||||
|
||||
WITH source_candidate, cosine_similarity
|
||||
ORDER BY cosine_similarity DESC
|
||||
LIMIT 1
|
||||
|
||||
RETURN id(source_candidate), cosine_similarity
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
logger.debug(f"_search_source_node\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for destination nodes
|
||||
|
||||
:param source_embedding: source vector
|
||||
:param user_id: user_id to use
|
||||
:param threshold: the threshold for similarity
|
||||
:return: str, dict
|
||||
"""
|
||||
cypher = f"""
|
||||
MATCH (destination_candidate {self.node_label})
|
||||
WHERE destination_candidate.user_id = $user_id
|
||||
|
||||
WITH destination_candidate, $destination_embedding as v_embedding
|
||||
CALL neptune.algo.vectors.distanceByEmbedding(
|
||||
v_embedding,
|
||||
destination_candidate,
|
||||
{{metric:"CosineSimilarity"}}
|
||||
) YIELD distance
|
||||
WITH destination_candidate, distance AS cosine_similarity
|
||||
WHERE cosine_similarity >= $threshold
|
||||
|
||||
WITH destination_candidate, cosine_similarity
|
||||
ORDER BY cosine_similarity DESC
|
||||
LIMIT 1
|
||||
|
||||
RETURN id(destination_candidate), cosine_similarity
|
||||
"""
|
||||
params = {
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
|
||||
logger.debug(f"_search_destination_node\n query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _delete_all_cypher(self, filters):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
|
||||
|
||||
:param filters: search filters
|
||||
:return: str, dict
|
||||
"""
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
|
||||
logger.debug(f"delete_all query={cypher}")
|
||||
return cypher, params
|
||||
|
||||
def _get_all_cypher(self, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
|
||||
|
||||
:param filters: search filters
|
||||
:param limit: return limit
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "limit": limit}
|
||||
return cypher, params
|
||||
|
||||
def _search_graph_db_cypher(self, n_embedding, filters, limit):
|
||||
"""
|
||||
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
|
||||
|
||||
:param n_embedding: node vector
|
||||
:param filters: search filters
|
||||
:param limit: return limit
|
||||
:return: str, dict
|
||||
"""
|
||||
|
||||
cypher_query = f"""
|
||||
MATCH (n {self.node_label})
|
||||
WHERE n.user_id = $user_id
|
||||
WITH n, $n_embedding as n_embedding
|
||||
CALL neptune.algo.vectors.distanceByEmbedding(
|
||||
n_embedding,
|
||||
n,
|
||||
{{metric:"CosineSimilarity"}}
|
||||
) YIELD distance
|
||||
WITH n, distance as similarity
|
||||
WHERE similarity >= $threshold
|
||||
CALL {{
|
||||
WITH n
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
|
||||
UNION ALL
|
||||
WITH n
|
||||
MATCH (m)-[r]->(n)
|
||||
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
|
||||
}}
|
||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
logger.debug(f"_search_graph_db\n query={cypher_query}")
|
||||
|
||||
return cypher_query, params
|
||||
371
graphs/tools.py
Normal file
371
graphs/tools.py
Normal file
@@ -0,0 +1,371 @@
|
||||
UPDATE_MEMORY_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "update_graph_memory",
|
||||
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ADD_MEMORY_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add_graph_memory",
|
||||
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
"source_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
"destination_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"destination",
|
||||
"relationship",
|
||||
"source_type",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NOOP_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "noop",
|
||||
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
RELATIONS_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "establish_relationships",
|
||||
"description": "Establish relationships among the entities based on the provided text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {"type": "string", "description": "The source entity of the relationship."},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The relationship between the source and destination entities.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The destination entity of the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"relationship",
|
||||
"destination",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXTRACT_ENTITIES_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "extract_entities",
|
||||
"description": "Extract entities and their types from the text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||
},
|
||||
"required": ["entity", "entity_type"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"description": "An array of entities with their types.",
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "update_graph_memory",
|
||||
"description": "Update the relationship key of an existing graph memory based on new information. This function should be called when there's a need to modify an existing relationship in the knowledge graph. The update should only be performed if the new information is more recent, more accurate, or provides additional context compared to the existing information. The source and destination nodes of the relationship must remain the same as in the existing graph memory; only the relationship itself can be updated.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
},
|
||||
"required": ["source", "destination", "relationship"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add_graph_memory",
|
||||
"description": "Add a new graph memory to the knowledge graph. This function creates a new relationship between two nodes, potentially creating new nodes if they don't exist.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
|
||||
},
|
||||
"source_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
"destination_type": {
|
||||
"type": "string",
|
||||
"description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"destination",
|
||||
"relationship",
|
||||
"source_type",
|
||||
"destination_type",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NOOP_STRUCT_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "noop",
|
||||
"description": "No operation should be performed to the graph entities. This function is called when the system determines that no changes or additions are necessary based on the current input or context. It serves as a placeholder action when no other actions are required, ensuring that the system can explicitly acknowledge situations where no modifications to the graph are needed.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RELATIONS_STRUCT_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "establish_relations",
|
||||
"description": "Establish relationships among the entities based on the provided text.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The source entity of the relationship.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The relationship between the source and destination entities.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The destination entity of the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"relationship",
|
||||
"destination",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
EXTRACT_ENTITIES_STRUCT_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "extract_entities",
|
||||
"description": "Extract entities and their types from the text.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string", "description": "The name or identifier of the entity."},
|
||||
"entity_type": {"type": "string", "description": "The type or category of the entity."},
|
||||
},
|
||||
"required": ["entity", "entity_type"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"description": "An array of entities with their types.",
|
||||
}
|
||||
},
|
||||
"required": ["entities"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "delete_graph_memory",
|
||||
"description": "Delete the relationship between two nodes. This function deletes the existing relationship.",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The existing relationship between the source and destination nodes that needs to be deleted.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"relationship",
|
||||
"destination",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
DELETE_MEMORY_TOOL_GRAPH = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "delete_graph_memory",
|
||||
"description": "Delete the relationship between two nodes. This function deletes the existing relationship.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the source node in the relationship.",
|
||||
},
|
||||
"relationship": {
|
||||
"type": "string",
|
||||
"description": "The existing relationship between the source and destination nodes that needs to be deleted.",
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the destination node in the relationship.",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"source",
|
||||
"relationship",
|
||||
"destination",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
97
graphs/utils.py
Normal file
97
graphs/utils.py
Normal file
@@ -0,0 +1,97 @@
|
||||
UPDATE_GRAPH_PROMPT = """
|
||||
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
|
||||
|
||||
Input:
|
||||
1. Existing Graph Memories: A list of current graph memories, each containing source, target, and relationship information.
|
||||
2. New Graph Memory: Fresh information to be integrated into the existing graph structure.
|
||||
|
||||
Guidelines:
|
||||
1. Identification: Use the source and target as primary identifiers when matching existing memories with new information.
|
||||
2. Conflict Resolution:
|
||||
- If new information contradicts an existing memory:
|
||||
a) For matching source and target but differing content, update the relationship of the existing memory.
|
||||
b) If the new memory provides more recent or accurate information, update the existing memory accordingly.
|
||||
3. Comprehensive Review: Thoroughly examine each existing graph memory against the new information, updating relationships as necessary. Multiple updates may be required.
|
||||
4. Consistency: Maintain a uniform and clear style across all memories. Each entry should be concise yet comprehensive.
|
||||
5. Semantic Coherence: Ensure that updates maintain or improve the overall semantic structure of the graph.
|
||||
6. Temporal Awareness: If timestamps are available, consider the recency of information when making updates.
|
||||
7. Relationship Refinement: Look for opportunities to refine relationship descriptions for greater precision or clarity.
|
||||
8. Redundancy Elimination: Identify and merge any redundant or highly similar relationships that may result from the update.
|
||||
|
||||
Memory Format:
|
||||
source -- RELATIONSHIP -- destination
|
||||
|
||||
Task Details:
|
||||
======= Existing Graph Memories:=======
|
||||
{existing_memories}
|
||||
|
||||
======= New Graph Memory:=======
|
||||
{new_memories}
|
||||
|
||||
Output:
|
||||
Provide a list of update instructions, each specifying the source, target, and the new relationship to be set. Only include memories that require updates.
|
||||
"""
|
||||
|
||||
EXTRACT_RELATIONS_PROMPT = """
|
||||
|
||||
You are an advanced algorithm designed to extract structured information from text to construct knowledge graphs. Your goal is to capture comprehensive and accurate information. Follow these key principles:
|
||||
|
||||
1. Extract only explicitly stated information from the text.
|
||||
2. Establish relationships among the entities provided.
|
||||
3. Use "USER_ID" as the source entity for any self-references (e.g., "I," "me," "my," etc.) in user messages.
|
||||
CUSTOM_PROMPT
|
||||
|
||||
Relationships:
|
||||
- Use consistent, general, and timeless relationship types.
|
||||
- Example: Prefer "professor" over "became_professor."
|
||||
- Relationships should only be established among the entities explicitly mentioned in the user message.
|
||||
|
||||
Entity Consistency:
|
||||
- Ensure that relationships are coherent and logically align with the context of the message.
|
||||
- Maintain consistent naming for entities across the extracted data.
|
||||
|
||||
Strive to construct a coherent and easily understandable knowledge graph by establishing all the relationships among the entities and adherence to the user’s context.
|
||||
|
||||
Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction."""
|
||||
|
||||
DELETE_RELATIONS_SYSTEM_PROMPT = """
|
||||
You are a graph memory manager specializing in identifying, managing, and optimizing relationships within graph-based memories. Your primary task is to analyze a list of existing relationships and determine which ones should be deleted based on the new information provided.
|
||||
Input:
|
||||
1. Existing Graph Memories: A list of current graph memories, each containing source, relationship, and destination information.
|
||||
2. New Text: The new information to be integrated into the existing graph structure.
|
||||
3. Use "USER_ID" as node for any self-references (e.g., "I," "me," "my," etc.) in user messages.
|
||||
|
||||
Guidelines:
|
||||
1. Identification: Use the new information to evaluate existing relationships in the memory graph.
|
||||
2. Deletion Criteria: Delete a relationship only if it meets at least one of these conditions:
|
||||
- Outdated or Inaccurate: The new information is more recent or accurate.
|
||||
- Contradictory: The new information conflicts with or negates the existing information.
|
||||
3. DO NOT DELETE if their is a possibility of same type of relationship but different destination nodes.
|
||||
4. Comprehensive Analysis:
|
||||
- Thoroughly examine each existing relationship against the new information and delete as necessary.
|
||||
- Multiple deletions may be required based on the new information.
|
||||
5. Semantic Integrity:
|
||||
- Ensure that deletions maintain or improve the overall semantic structure of the graph.
|
||||
- Avoid deleting relationships that are NOT contradictory/outdated to the new information.
|
||||
6. Temporal Awareness: Prioritize recency when timestamps are available.
|
||||
7. Necessity Principle: Only DELETE relationships that must be deleted and are contradictory/outdated to the new information to maintain an accurate and coherent memory graph.
|
||||
|
||||
Note: DO NOT DELETE if their is a possibility of same type of relationship but different destination nodes.
|
||||
|
||||
For example:
|
||||
Existing Memory: alice -- loves_to_eat -- pizza
|
||||
New Information: Alice also loves to eat burger.
|
||||
|
||||
Do not delete in the above example because there is a possibility that Alice loves to eat both pizza and burger.
|
||||
|
||||
Memory Format:
|
||||
source -- relationship -- destination
|
||||
|
||||
Provide a list of deletion instructions, each specifying the relationship to be deleted.
|
||||
"""
|
||||
|
||||
|
||||
def get_delete_messages(existing_memories_string, data, user_id):
|
||||
return DELETE_RELATIONS_SYSTEM_PROMPT.replace(
|
||||
"USER_ID", user_id
|
||||
), f"Here are the existing memories: {existing_memories_string} \n\n New Information: {data}"
|
||||
Reference in New Issue
Block a user