first commit
This commit is contained in:
0
embeddings/__init__.py
Normal file
0
embeddings/__init__.py
Normal file
100
embeddings/aws_bedrock.py
Normal file
100
embeddings/aws_bedrock.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    This class uses AWS Bedrock's embedding models (e.g. Amazon Titan,
    Cohere) via the ``bedrock-runtime`` client.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Initialize the Bedrock client.

        Args:
            config (Optional[BaseEmbedderConfig]): Embedder configuration;
                a default config is used when None.
        """
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Get AWS config from environment variables or use defaults
        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")

        # Explicit values in the config take precedence over the environment.
        if hasattr(self.config, "aws_access_key_id"):
            aws_access_key = self.config.aws_access_key_id
        if hasattr(self.config, "aws_secret_access_key"):
            aws_secret_key = self.config.aws_secret_access_key
        # Fix: the session token was previously only read from the environment,
        # unlike the access key / secret key. Honor the config here too, but
        # only when it carries a truthy value so an unset config attribute
        # cannot clobber a token supplied via AWS_SESSION_TOKEN.
        if hasattr(self.config, "aws_session_token") and self.config.aws_session_token:
            aws_session_token = self.config.aws_session_token

        # AWS region is always set in config - see BaseEmbedderConfig
        aws_region = self.config.aws_region or "us-west-2"

        # Passing None for a credential lets boto3 fall back to its own
        # credential resolution chain (profiles, instance roles, etc.).
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
            aws_session_token=aws_session_token if aws_session_token else None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector.

        NOTE(review): a zero vector yields a zero norm and therefore NaNs;
        callers are expected to pass non-zero embeddings.
        """
        emb = np.array(embeddings)
        norm_emb = emb / np.linalg.norm(emb)
        return norm_emb.tolist()

    def _get_embedding(self, text):
        """Call out to Bedrock embedding endpoint and return the raw vector.

        Raises:
            ValueError: if the Bedrock invocation or response parsing fails
                (chained from the underlying exception).
        """
        # The request/response schema differs per provider; the provider is
        # the first dotted component of the model id (e.g. "cohere.embed-...").
        provider = self.config.model.split(".")[0]
        input_body = {}

        if provider == "cohere":
            input_body["input_type"] = "search_document"
            input_body["texts"] = [text]
        else:
            # Amazon and other providers
            input_body["inputText"] = text

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            # Cohere returns a list under "embeddings"; others a single "embedding".
            if provider == "cohere":
                embeddings = response_body.get("embeddings")[0]
            else:
                embeddings = response_body.get("embedding")

            return embeddings
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}") from e

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
|
||||
55
embeddings/azure_openai.py
Normal file
55
embeddings/azure_openai.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
SCOPE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
|
||||
class AzureOpenAIEmbedding(EmbeddingBase):
    """Embedder backed by an Azure OpenAI deployment.

    Authenticates with an API key when one is configured, otherwise falls
    back to Azure AD token-based authentication.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        azure_cfg = self.config.azure_kwargs
        api_key = azure_cfg.api_key or os.getenv("EMBEDDING_AZURE_OPENAI_API_KEY")
        azure_deployment = azure_cfg.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
        azure_endpoint = azure_cfg.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
        api_version = azure_cfg.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
        default_headers = azure_cfg.default_headers

        # No usable key (missing, empty, or the docs placeholder): switch to
        # DefaultAzureCredential and hand the client a bearer-token provider.
        azure_ad_token_provider = None
        if not api_key or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(self.credential, SCOPE)
            api_key = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Azure OpenAI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        # Newlines can degrade embedding quality; flatten to single-line text.
        flattened = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[flattened], model=self.config.model)
        return response.data[0].embedding
|
||||
31
embeddings/base.py
Normal file
31
embeddings/base.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Optional
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
|
||||
|
||||
class EmbeddingBase(ABC):
    """Abstract base class shared by all embedding backends.

    :param config: Embedding configuration option class, defaults to None
    :type config: Optional[BaseEmbedderConfig], optional
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        # Substitute a default configuration when the caller supplies none.
        self.config = BaseEmbedderConfig() if config is None else config

    @abstractmethod
    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]]):
        """
        Get the embedding for the given text.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
|
||||
31
embeddings/configs.py
Normal file
31
embeddings/configs.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class EmbedderConfig(BaseModel):
    """Top-level embedder configuration: a provider name plus a dict of
    provider-specific options."""

    provider: str = Field(
        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
        default="openai",
    )
    config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

    @field_validator("config")
    def validate_config(cls, v, values):
        # Reject any provider outside the supported backend set.
        supported_providers = {
            "openai",
            "ollama",
            "huggingface",
            "azure_openai",
            "gemini",
            "vertexai",
            "together",
            "lmstudio",
            "langchain",
            "aws_bedrock",
            "fastembed",
        }
        provider = values.data.get("provider")
        if provider not in supported_providers:
            raise ValueError(f"Unsupported embedding provider: {provider}")
        return v
|
||||
29
embeddings/fastembed.py
Normal file
29
embeddings/fastembed.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Optional, Literal
|
||||
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
raise ImportError("FastEmbed is not installed. Please install it using `pip install fastembed`")
|
||||
|
||||
class FastEmbedEmbedding(EmbeddingBase):
    """Embedder that runs FastEmbed models locally via the ONNX runtime."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        # Default to gte-large when no model is configured.
        if not self.config.model:
            self.config.model = "thenlper/gte-large"
        self.dense_model = TextEmbedding(model_name=self.config.model)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Convert the text to embeddings using FastEmbed running in the Onnx runtime.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        flattened = text.replace("\n", " ")
        # TextEmbedding.embed yields vectors lazily; take the first (only) one.
        return next(iter(self.dense_model.embed(flattened)))
|
||||
39
embeddings/gemini.py
Normal file
39
embeddings/gemini.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class GoogleGenAIEmbedding(EmbeddingBase):
    """Embedder backed by Google Generative AI embedding models."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "models/text-embedding-004"
        # Dimension preference order: explicit embedding_dims, then
        # output_dimensionality, then 768.
        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768

        key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
        self.client = genai.Client(api_key=key)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Google Generative AI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        flattened = text.replace("\n", " ")

        # Embedding parameters (output dimensionality) travel in a config object.
        embed_config = types.EmbedContentConfig(output_dimensionality=self.config.embedding_dims)
        result = self.client.models.embed_content(model=self.config.model, contents=flattened, config=embed_config)

        return result.embeddings[0].values
|
||||
44
embeddings/huggingface.py
Normal file
44
embeddings/huggingface.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
logging.getLogger("transformers").setLevel(logging.WARNING)
|
||||
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
|
||||
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class HuggingFaceEmbedding(EmbeddingBase):
    """Embedder backed by either a local SentenceTransformer model or a
    remote Hugging Face TEI endpoint exposing an OpenAI-compatible API."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Initialize the local model or the remote TEI client.

        Args:
            config (Optional[BaseEmbedderConfig]): Embedder configuration;
                a default config is used when None.
        """
        super().__init__(config)

        # BUG FIX: read the base URL from self.config (guaranteed non-None by
        # the base class) instead of the raw `config` argument, which defaults
        # to None and would raise AttributeError.
        if self.config.huggingface_base_url:
            # Remote TEI endpoint, spoken to via the OpenAI-compatible client.
            self.client = OpenAI(base_url=self.config.huggingface_base_url)
            self.config.model = self.config.model or "tei"
        else:
            # Local inference via sentence-transformers.
            self.config.model = self.config.model or "multi-qa-MiniLM-L6-cos-v1"
            self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
            self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Hugging Face.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        if self.config.huggingface_base_url:
            response = self.client.embeddings.create(
                input=text, model=self.config.model, **self.config.model_kwargs
            )
            return response.data[0].embedding
        return self.model.encode(text, convert_to_numpy=True).tolist()
|
||||
35
embeddings/langchain.py
Normal file
35
embeddings/langchain.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
try:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
except ImportError:
|
||||
raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
|
||||
|
||||
|
||||
class LangchainEmbedding(EmbeddingBase):
    """Adapter that wraps a LangChain ``Embeddings`` instance."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        model = self.config.model
        # The wrapped model is mandatory and must be a LangChain Embeddings object.
        if model is None:
            raise ValueError("`model` parameter is required")
        if not isinstance(model, Embeddings):
            raise ValueError("`model` must be an instance of Embeddings")

        self.langchain_model = model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Langchain.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        # memory_action is accepted for interface compatibility but unused here.
        return self.langchain_model.embed_query(text)
|
||||
29
embeddings/lmstudio.py
Normal file
29
embeddings/lmstudio.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class LMStudioEmbedding(EmbeddingBase):
    """Embedder served by a local LM Studio instance (OpenAI-compatible API)."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        cfg = self.config
        cfg.model = cfg.model or "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf"
        cfg.embedding_dims = cfg.embedding_dims or 1536
        # LM Studio ignores the key's value, but the OpenAI client requires one.
        cfg.api_key = cfg.api_key or "lm-studio"

        self.client = OpenAI(base_url=cfg.lmstudio_base_url, api_key=cfg.api_key)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using LM Studio.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        flattened = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[flattened], model=self.config.model)
        return response.data[0].embedding
|
||||
11
embeddings/mock.py
Normal file
11
embeddings/mock.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class MockEmbeddings(EmbeddingBase):
    """Deterministic stand-in embedder for tests."""

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Generate a mock embedding with dimension of 10.
        """
        # Fixed ramp 0.1 .. 1.0; the input text is deliberately ignored.
        return [i / 10 for i in range(1, 11)]
|
||||
53
embeddings/ollama.py
Normal file
53
embeddings/ollama.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Literal, Optional
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
try:
|
||||
from ollama import Client
|
||||
except ImportError:
|
||||
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
|
||||
if user_input.lower() == "y":
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
|
||||
from ollama import Client
|
||||
except subprocess.CalledProcessError:
|
||||
print("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("The required 'ollama' library is not installed.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class OllamaEmbedding(EmbeddingBase):
    """Embedder served by a local Ollama instance."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "nomic-embed-text"
        self.config.embedding_dims = self.config.embedding_dims or 512

        self.client = Client(host=self.config.ollama_base_url)
        self._ensure_model_exists()

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        wanted = self.config.model
        local_models = self.client.list()["models"]
        # Depending on the ollama client version, entries expose the model id
        # under either "name" or "model".
        present = any(entry.get("name") == wanted or entry.get("model") == wanted for entry in local_models)
        if not present:
            self.client.pull(wanted)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Ollama.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        result = self.client.embeddings(model=self.config.model, prompt=text)
        return result["embedding"]
|
||||
49
embeddings/openai.py
Normal file
49
embeddings/openai.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Literal, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class OpenAIEmbedding(EmbeddingBase):
    """Embedder backed by the OpenAI embeddings API."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "text-embedding-3-small"
        self.config.embedding_dims = self.config.embedding_dims or 1536

        # Warn about the legacy base-URL environment variable before resolving it.
        if os.environ.get("OPENAI_API_BASE"):
            warnings.warn(
                "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                "Please use 'OPENAI_BASE_URL' instead.",
                DeprecationWarning,
            )

        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        # Resolution order: explicit config, legacy env var, current env var, default.
        base_url = (
            self.config.openai_base_url
            or os.getenv("OPENAI_API_BASE")
            or os.getenv("OPENAI_BASE_URL")
            or "https://api.openai.com/v1"
        )

        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using OpenAI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        flattened = text.replace("\n", " ")
        response = self.client.embeddings.create(
            input=[flattened], model=self.config.model, dimensions=self.config.embedding_dims
        )
        return response.data[0].embedding
|
||||
31
embeddings/together.py
Normal file
31
embeddings/together.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from together import Together
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class TogetherEmbedding(EmbeddingBase):
    """Embedder backed by the Together AI embeddings API."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "togethercomputer/m2-bert-80M-8k-retrieval"
        # TODO: check if this is correct
        self.config.embedding_dims = self.config.embedding_dims or 768

        key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
        self.client = Together(api_key=key)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Together AI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        response = self.client.embeddings.create(model=self.config.model, input=text)
        return response.data[0].embedding
|
||||
64
embeddings/vertexai.py
Normal file
64
embeddings/vertexai.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
from mem0.utils.gcp_auth import GCPAuthenticator
|
||||
|
||||
|
||||
class VertexAIEmbedding(EmbeddingBase):
    """Embedder backed by Vertex AI text-embedding models.

    Maps memory actions ("add"/"update"/"search") to Vertex AI embedding task
    types and handles GCP credential setup with a legacy fallback path.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "text-embedding-004"
        self.config.embedding_dims = self.config.embedding_dims or 256

        # Per-action task types; config overrides fall back to sensible defaults
        # (documents for add/update, queries for search).
        self.embedding_types = {
            "add": self.config.memory_add_embedding_type or "RETRIEVAL_DOCUMENT",
            "update": self.config.memory_update_embedding_type or "RETRIEVAL_DOCUMENT",
            "search": self.config.memory_search_embedding_type or "RETRIEVAL_QUERY",
        }

        # Set up authentication using centralized GCP authenticator
        # This supports multiple authentication methods while preserving environment variable support
        try:
            GCPAuthenticator.setup_vertex_ai(
                service_account_json=getattr(self.config, 'google_service_account_json', None),
                credentials_path=self.config.vertex_credentials_json,
                project_id=getattr(self.config, 'google_project_id', None)
            )
        except Exception:
            # Fall back to original behavior for backward compatibility:
            # export the credentials path for the SDK to pick up, or fail
            # loudly if no credentials can be found at all.
            credentials_path = self.config.vertex_credentials_json
            if credentials_path:
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path
            elif not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
                raise ValueError(
                    "Google application credentials JSON is not provided. Please provide a valid JSON path or set the 'GOOGLE_APPLICATION_CREDENTIALS' environment variable."
                )

        self.model = TextEmbeddingModel.from_pretrained(self.config.model)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Vertex AI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.

        Raises:
            ValueError: If memory_action is given but not one of the known actions.
        """
        # Default task type when no memory action is supplied.
        embedding_type = "SEMANTIC_SIMILARITY"
        if memory_action is not None:
            if memory_action not in self.embedding_types:
                raise ValueError(f"Invalid memory action: {memory_action}")

            embedding_type = self.embedding_types[memory_action]

        text_input = TextEmbeddingInput(text=text, task_type=embedding_type)
        embeddings = self.model.get_embeddings(texts=[text_input], output_dimensionality=self.config.embedding_dims)

        return embeddings[0].values
|
||||
Reference in New Issue
Block a user