Files
mem0/memory/utils.py
2026-03-06 21:11:10 +08:00

209 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import hashlib
import re
from mem0.configs.prompts import (
FACT_RETRIEVAL_PROMPT,
USER_MEMORY_EXTRACTION_PROMPT,
AGENT_MEMORY_EXTRACTION_PROMPT,
)
def get_fact_retrieval_messages(message, is_agent_memory=False):
    """Build the (system, user) prompt pair used for fact extraction.

    Args:
        message: The message content to extract facts from.
        is_agent_memory: When truthy, the agent memory extraction prompt is
            used as the system prompt; otherwise the user memory extraction
            prompt is used.

    Returns:
        tuple: ``(system_prompt, user_prompt)``. The user prompt is the
        message placed on the line after an ``Input:`` header.
    """
    system_prompt = AGENT_MEMORY_EXTRACTION_PROMPT if is_agent_memory else USER_MEMORY_EXTRACTION_PROMPT
    user_prompt = f"Input:\n{message}"
    return system_prompt, user_prompt
def get_fact_retrieval_messages_legacy(message):
    """Return the legacy fact retrieval prompt pair.

    Kept for backward compatibility.

    Args:
        message: The message content to extract facts from.

    Returns:
        tuple: ``(system_prompt, user_prompt)``. The system prompt is always
        the legacy fact retrieval prompt; the user prompt is the message
        placed on the line after an ``Input:`` header.
    """
    user_prompt = f"Input:\n{message}"
    return FACT_RETRIEVAL_PROMPT, user_prompt
def parse_messages(messages):
    """Flatten chat messages into a ``role: content`` transcript.

    Args:
        messages: Iterable of dicts with a ``"role"`` key and, for the roles
            that are rendered, a ``"content"`` key.

    Returns:
        str: One newline-terminated ``"<role>: <content>"`` line per message
        whose role is ``system``, ``user`` or ``assistant``, in input order.
        Messages with any other role are skipped.
    """
    rendered_roles = ("system", "user", "assistant")
    transcript_lines = [
        f"{entry['role']}: {entry['content']}\n"
        for entry in messages
        # Filter first so ``content`` is never read for skipped roles.
        if entry["role"] in rendered_roles
    ]
    return "".join(transcript_lines)
def format_entities(entities):
    """Render relationship triples as ``source -- relationship -- destination`` lines.

    Args:
        entities: Sequence of dicts with ``"source"``, ``"relationship"`` and
            ``"destination"`` keys. May be empty or ``None``.

    Returns:
        str: One line per entity joined with newlines (no trailing newline),
        or an empty string when ``entities`` is falsy.
    """
    if not entities:
        return ""
    return "\n".join(
        f"{item['source']} -- {item['relationship']} -- {item['destination']}"
        for item in entities
    )
def remove_code_blocks(content: str) -> str:
    """Strip ``<think>`` sections and an enclosing Markdown code fence.

    Processing order:
        1. Every ``<think>...</think>`` section is removed (it may span
           several lines).
        2. If what remains is entirely wrapped in a fence that opens with
           three backticks plus an optional alphanumeric language tag and
           closes with three backticks on its own line, only the inner text
           is kept.
        3. Otherwise the remaining text is returned unchanged.

    Think sections are removed *before* fence detection because the fence
    pattern is anchored at the start of the string: a response shaped like
    ``<think>...</think>`` followed by a fenced payload would otherwise never
    match and would be returned with its fence markers intact.

    Args:
        content: Raw text, typically an LLM response.

    Returns:
        str: The cleaned text with surrounding whitespace stripped.
    """
    without_think = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
    fenced = re.match(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$", without_think)
    if fenced is None:
        return without_think
    return fenced.group(1).strip()
def extract_json(text):
    """Pull the JSON payload out of a possibly fenced string.

    The first triple-backtick block found anywhere in ``text`` (with an
    optional ``json`` tag) wins; its contents are returned with surrounding
    whitespace removed. When no such block exists the stripped input is
    returned as-is, on the assumption that it is already raw JSON.

    Args:
        text: String that may contain a fenced JSON block.

    Returns:
        str: The extracted JSON text. It is not parsed or validated here.
    """
    stripped = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", stripped, re.DOTALL)
    if fenced is None:
        # No fence present: treat the whole input as raw JSON.
        return stripped
    return fenced.group(1)
def get_image_description(image_obj, llm, vision_details):
    """Ask the LLM for a description of an image.

    Args:
        image_obj: Either a string, which is placed in the ``url`` field of
            an ``image_url`` content part, or a ready-made message object
            that is forwarded to the LLM unchanged.
        llm: Object exposing ``generate_response(messages=...)``.
        vision_details: Value placed in the ``detail`` field of the
            ``image_url`` part when ``image_obj`` is a string.

    Returns:
        Whatever ``llm.generate_response`` returns.
    """
    if not isinstance(image_obj, str):
        # Already a complete message: send it through untouched.
        messages = [image_obj]
    else:
        instruction_part = {
            "type": "text",
            "text": "A user is providing an image. Provide a high level description of the image and do not include any additional text.",
        }
        image_part = {"type": "image_url", "image_url": {"url": image_obj, "detail": vision_details}}
        messages = [{"role": "user", "content": [instruction_part, image_part]}]
    return llm.generate_response(messages=messages)
def parse_vision_messages(messages, llm=None, vision_details="auto"):
    """Replace image content in chat messages with LLM-generated descriptions.

    Handling per message:
        - ``system`` messages are passed through untouched.
        - List content: the whole message is handed to
          ``get_image_description`` and the content is replaced by its result.
        - Dict content with ``type == "image_url"``: the URL is handed to
          ``get_image_description`` and the content is replaced by its result.
        - Anything else is passed through untouched.

    Args:
        messages: Iterable of message dicts with ``"role"`` and (for
            non-system messages) ``"content"`` keys.
        llm: LLM object forwarded to ``get_image_description``.
        vision_details: Detail setting forwarded to ``get_image_description``.

    Returns:
        list: Messages in input order, with image content replaced by
        descriptions.

    Raises:
        Exception: If describing a single ``image_url`` dict fails. The
            triggering error is attached as ``__cause__`` so the root failure
            stays visible in tracebacks.
    """
    returned_messages = []
    for msg in messages:
        if msg["role"] == "system":
            returned_messages.append(msg)
            continue

        content = msg["content"]
        if isinstance(content, list):
            # List content: the full message is sent to the LLM as-is.
            description = get_image_description(msg, llm, vision_details)
            returned_messages.append({"role": msg["role"], "content": description})
        elif isinstance(content, dict) and content.get("type") == "image_url":
            # Single image content.
            image_url = content["image_url"]["url"]
            try:
                description = get_image_description(image_url, llm, vision_details)
            except Exception as exc:
                # Same exception type and message as before; now explicitly chained.
                raise Exception(f"Error while downloading {image_url}.") from exc
            returned_messages.append({"role": msg["role"], "content": description})
        else:
            # Regular text content.
            returned_messages.append(msg)
    return returned_messages
def process_telemetry_filters(filters):
    """Summarise filters for telemetry without exposing raw identifiers.

    Args:
        filters: Mapping of filter names to values, or ``None``. The values
            under ``user_id``, ``agent_id`` and ``run_id`` (when present)
            must support ``.encode()``.

    Returns:
        tuple: ``(filter_keys, encoded_ids)`` where ``filter_keys`` is the
        list of all keys in ``filters`` and ``encoded_ids`` maps each present
        ID key to the MD5 hex digest of its value. For ``filters is None``
        the result is ``([], {})``.
    """
    if filters is None:
        # Return the same two-element shape as the normal path so callers
        # that unpack the result do not fail when no filters were supplied.
        return [], {}

    encoded_ids = {
        id_key: hashlib.md5(filters[id_key].encode()).hexdigest()
        for id_key in ("user_id", "agent_id", "run_id")
        if id_key in filters
    }
    return list(filters.keys()), encoded_ids
def sanitize_relationship_for_cypher(relationship) -> str:
    """Sanitize relationship text for Cypher queries by replacing problematic characters.

    Each character (or sequence) in the table below is replaced by a
    descriptive ``_name_`` token. Runs of underscores are then collapsed to a
    single underscore and leading/trailing underscores are removed.

    Args:
        relationship: Relationship text to sanitize.

    Returns:
        str: The sanitized text.
    """
    # Non-ASCII keys are written as explicit escapes so they cannot be lost
    # or confused with look-alike ASCII characters. An empty-string key must
    # never appear here: ``str.replace("", x)`` inserts ``x`` between every
    # character of the input.
    # NOTE(review): the non-ASCII keys were not readable in the file as
    # received; the code points below are inferred from the replacement
    # names -- confirm against version control.
    char_map = {
        "...": "_ellipsis_",
        "\u2026": "_ellipsis_",  # horizontal ellipsis
        "\u3002": "_period_",  # ideographic full stop
        "\uff0c": "_comma_",  # fullwidth comma
        "\uff1b": "_semicolon_",  # fullwidth semicolon
        "\uff1a": "_colon_",  # fullwidth colon
        "\uff01": "_exclamation_",  # fullwidth exclamation mark
        "\uff1f": "_question_",  # fullwidth question mark
        "\uff08": "_lparen_",  # fullwidth left parenthesis
        "\uff09": "_rparen_",  # fullwidth right parenthesis
        "\u3010": "_lbracket_",  # left black lenticular bracket
        "\u3011": "_rbracket_",  # right black lenticular bracket
        "\u300a": "_langle_",  # left double angle bracket
        "\u300b": "_rangle_",  # right double angle bracket
        "'": "_apostrophe_",
        '"': "_quote_",
        "\\": "_backslash_",
        "/": "_slash_",
        "|": "_pipe_",
        "&": "_ampersand_",
        "=": "_equals_",
        "+": "_plus_",
        "*": "_asterisk_",
        "^": "_caret_",
        "%": "_percent_",
        "$": "_dollar_",
        "#": "_hash_",
        "@": "_at_",
        "!": "_bang_",
        "?": "_question_",
        "(": "_lparen_",
        ")": "_rparen_",
        "[": "_lbracket_",
        "]": "_rbracket_",
        "{": "_lbrace_",
        "}": "_rbrace_",
        "<": "_langle_",
        ">": "_rangle_",
    }

    # Apply replacements and clean up.
    sanitized = relationship
    for old, new in char_map.items():
        sanitized = sanitized.replace(old, new)
    return re.sub(r"_+", "_", sanitized).strip("_")