first commit

This commit is contained in:
2026-03-06 21:11:10 +08:00
commit 927b8a6cac
144 changed files with 26301 additions and 0 deletions

View File

View File

@@ -0,0 +1,57 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class AzureAISearchConfig(BaseModel):
    """Configuration for the Azure AI Search vector store."""

    collection_name: str = Field("mem0", description="Name of the collection")
    # Optional[str]: both default to None, so the previous bare `str`
    # annotations were incorrect.
    service_name: Optional[str] = Field(None, description="Azure AI Search service name")
    api_key: Optional[str] = Field(None, description="API key for the Azure AI Search service")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    compression_type: Optional[str] = Field(
        None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
    )
    use_float16: bool = Field(
        False,
        description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
    )
    hybrid_search: bool = Field(
        False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
    )
    vector_filter_mode: Optional[str] = Field(
        "preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
    )

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject unknown keys and validate compression_type before model construction."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        # Check for the removed use_compression flag to give a migration hint
        if "use_compression" in extra_fields:
            raise ValueError(
                "The parameter 'use_compression' is no longer supported. "
                "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
                "or 'compression_type=None' instead of 'use_compression=False'."
            )
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        # Validate compression_type values (case-insensitive)
        if "compression_type" in values and values["compression_type"] is not None:
            valid_types = ["scalar", "binary"]
            if values["compression_type"].lower() not in valid_types:
                raise ValueError(
                    f"Invalid compression_type: {values['compression_type']}. "
                    f"Must be one of: {', '.join(valid_types)}, or None"
                )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,84 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class AzureMySQLConfig(BaseModel):
    """Configuration for Azure MySQL vector database.

    Connections can be configured either with individual parameters
    (host/user/database plus password or Azure credential) or by supplying a
    pre-configured ``connection_pool``, which overrides everything else.
    """

    # host/user/database are logically required unless a connection_pool is
    # supplied; they are declared Optional so that a pool-only configuration
    # can pass field validation (previously Field(...) made them
    # unconditionally required, contradicting check_required_fields which
    # skips them when connection_pool is set). check_required_fields still
    # enforces their presence in the non-pool case.
    host: Optional[str] = Field(None, description="MySQL server host (e.g., myserver.mysql.database.azure.com)")
    port: int = Field(3306, description="MySQL server port")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password (not required if using Azure credential)")
    database: Optional[str] = Field(None, description="Database name")
    collection_name: str = Field("mem0", description="Collection/table name")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    use_azure_credential: bool = Field(
        False,
        description="Use Azure DefaultAzureCredential for authentication instead of password"
    )
    ssl_ca: Optional[str] = Field(None, description="Path to SSL CA certificate")
    ssl_disabled: bool = Field(False, description="Disable SSL connection (not recommended for production)")
    minconn: int = Field(1, description="Minimum number of connections in the pool")
    maxconn: int = Field(5, description="Maximum number of connections in the pool")
    connection_pool: Optional[Any] = Field(
        None,
        description="Pre-configured connection pool object (overrides other connection parameters)"
    )

    @model_validator(mode="before")
    @classmethod
    def check_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate authentication parameters."""
        # If connection_pool is provided, skip validation
        if values.get("connection_pool") is not None:
            return values
        use_azure_credential = values.get("use_azure_credential", False)
        password = values.get("password")
        # Either password or Azure credential must be provided
        if not use_azure_credential and not password:
            raise ValueError(
                "Either 'password' must be provided or 'use_azure_credential' must be set to True"
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_required_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate required fields."""
        # If connection_pool is provided, skip validation of individual parameters
        if values.get("connection_pool") is not None:
            return values
        required_fields = ["host", "user", "database"]
        missing_fields = [field for field in required_fields if not values.get(field)]
        if missing_fields:
            raise ValueError(
                f"Missing required fields: {', '.join(missing_fields)}. "
                f"These fields are required when not using a pre-configured connection_pool."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that no extra fields are provided."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    class Config:
        arbitrary_types_allowed = True

View File

@@ -0,0 +1,27 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class BaiduDBConfig(BaseModel):
    """Configuration for the Baidu VectorDB vector store."""

    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
    account: str = Field("root", description="Account for Baidu VectorDB")
    # Optional[str]: the default is None, so the previous bare `str`
    # annotation was incorrect.
    api_key: Optional[str] = Field(None, description="API Key for Baidu VectorDB")
    database_name: str = Field("mem0", description="Name of the database")
    table_name: str = Field("mem0", description="Name of the table")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,77 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
class CassandraConfig(BaseModel):
    """Configuration for Apache Cassandra vector database."""

    # Optional with a None default so a secure_connect_bundle-only
    # configuration (DataStax Astra) is accepted; previously Field(...) made
    # this unconditionally required, contradicting check_connection_config,
    # which allows either contact_points OR secure_connect_bundle.
    contact_points: Optional[List[str]] = Field(
        None,
        description="List of contact point addresses (e.g., ['127.0.0.1', '127.0.0.2'])"
    )
    port: int = Field(9042, description="Cassandra port")
    username: Optional[str] = Field(None, description="Database username")
    password: Optional[str] = Field(None, description="Database password")
    keyspace: str = Field("mem0", description="Keyspace name")
    collection_name: str = Field("memories", description="Table name")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    secure_connect_bundle: Optional[str] = Field(
        None,
        description="Path to secure connect bundle for DataStax Astra DB"
    )
    protocol_version: int = Field(4, description="CQL protocol version")
    load_balancing_policy: Optional[Any] = Field(
        None,
        description="Custom load balancing policy object"
    )

    @model_validator(mode="before")
    @classmethod
    def check_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate authentication parameters."""
        username = values.get("username")
        password = values.get("password")
        # Both username and password must be provided together or not at all
        if (username and not password) or (password and not username):
            raise ValueError(
                "Both 'username' and 'password' must be provided together for authentication"
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_connection_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate connection configuration."""
        secure_connect_bundle = values.get("secure_connect_bundle")
        contact_points = values.get("contact_points")
        # Either secure_connect_bundle or contact_points must be provided
        if not secure_connect_bundle and not contact_points:
            raise ValueError(
                "Either 'contact_points' or 'secure_connect_bundle' must be provided"
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that no extra fields are provided."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    class Config:
        arbitrary_types_allowed = True

View File

@@ -0,0 +1,58 @@
from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class ChromaDbConfig(BaseModel):
    """Configuration for the ChromaDB vector store (local, server, or Cloud)."""

    # Import inside the class body so the helpful install hint surfaces when
    # this config class is loaded rather than at package import time.
    try:
        from chromadb.api.client import Client
    except ImportError:
        raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")
    # Expose the imported type on the class for use in the field annotation below.
    Client: ClassVar[type] = Client

    collection_name: str = Field("mem0", description="Default name for the collection/database")
    client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
    path: Optional[str] = Field(None, description="Path to the database directory")
    host: Optional[str] = Field(None, description="Database connection remote host")
    port: Optional[int] = Field(None, description="Database connection remote port")
    # ChromaDB Cloud configuration
    api_key: Optional[str] = Field(None, description="ChromaDB Cloud API key")
    tenant: Optional[str] = Field(None, description="ChromaDB Cloud tenant ID")

    @model_validator(mode="before")
    def check_connection_config(cls, values):
        """Require exactly one of Cloud (api_key + tenant) or local (path or host+port) configuration."""
        host, port, path = values.get("host"), values.get("port"), values.get("path")
        api_key, tenant = values.get("api_key"), values.get("tenant")
        # Check if cloud configuration is provided
        cloud_config = bool(api_key and tenant)
        # If cloud configuration is provided, remove any default path that might have been added
        if cloud_config and path == "/tmp/chroma":
            values.pop("path", None)
            # NOTE(review): this early return skips the mutual-exclusion checks
            # below; presumably intentional because "/tmp/chroma" is an injected
            # default rather than a user choice — confirm against callers.
            return values
        # Check if local/server configuration is provided (excluding default tmp path for cloud config)
        local_config = bool(path and path != "/tmp/chroma") or bool(host and port)
        if not cloud_config and not local_config:
            raise ValueError("Either ChromaDB Cloud configuration (api_key, tenant) or local configuration (path or host/port) must be provided.")
        if cloud_config and local_config:
            raise ValueError("Cannot specify both cloud configuration and local configuration. Choose one.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,61 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
from databricks.sdk.service.vectorsearch import EndpointType, VectorIndexType, PipelineType
class DatabricksConfig(BaseModel):
    """Configuration for Databricks Vector Search vector store."""

    workspace_url: str = Field(..., description="Databricks workspace URL")
    access_token: Optional[str] = Field(None, description="Personal access token for authentication")
    client_id: Optional[str] = Field(None, description="Databricks Service principal client ID")
    client_secret: Optional[str] = Field(None, description="Databricks Service principal client secret")
    azure_client_id: Optional[str] = Field(None, description="Azure AD application client ID (for Azure Databricks)")
    azure_client_secret: Optional[str] = Field(
        None, description="Azure AD application client secret (for Azure Databricks)"
    )
    endpoint_name: str = Field(..., description="Vector search endpoint name")
    catalog: str = Field(..., description="The Unity Catalog catalog name")
    # NOTE(review): a field named "schema" shadows pydantic's BaseModel.schema()
    # helper; renaming would break callers, so it is kept — confirm intended.
    # Description typo fixed: "schama" -> "schema".
    schema: str = Field(..., description="The Unity Catalog schema name")
    table_name: str = Field(..., description="Source Delta table name")
    collection_name: str = Field("mem0", description="Vector search index name")
    index_type: VectorIndexType = Field("DELTA_SYNC", description="Index type: DELTA_SYNC or DIRECT_ACCESS")
    embedding_model_endpoint_name: Optional[str] = Field(
        None, description="Embedding model endpoint for Databricks-computed embeddings"
    )
    embedding_dimension: int = Field(1536, description="Vector embedding dimensions")
    endpoint_type: EndpointType = Field("STANDARD", description="Endpoint type: STANDARD or STORAGE_OPTIMIZED")
    pipeline_type: PipelineType = Field("TRIGGERED", description="Sync pipeline type: TRIGGERED or CONTINUOUS")
    warehouse_name: Optional[str] = Field(None, description="Databricks SQL warehouse Name")
    query_type: str = Field("ANN", description="Query type: `ANN` and `HYBRID`")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    @model_validator(mode="after")
    def validate_authentication(self):
        """Validate that either access_token or service principal credentials are provided."""
        has_token = self.access_token is not None
        has_service_principal = (self.client_id is not None and self.client_secret is not None) or (
            self.azure_client_id is not None and self.azure_client_secret is not None
        )
        if not has_token and not has_service_principal:
            raise ValueError(
                "Either access_token or both client_id/client_secret or azure_client_id/azure_client_secret must be provided"
            )
        return self

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,65 @@
from collections.abc import Callable
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
class ElasticsearchConfig(BaseModel):
    """Configuration for the Elasticsearch vector store."""

    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="Elasticsearch host")
    port: int = Field(9200, description="Elasticsearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    cloud_id: Optional[str] = Field(None, description="Cloud ID for Elastic Cloud")
    api_key: Optional[str] = Field(None, description="API key for authentication")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(True, description="Verify SSL certificates")
    use_ssl: bool = Field(True, description="Use SSL for connection")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
    custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
        None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
    )
    headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to include in requests")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate connection target and authentication.

        Bug fix: this validator runs before field defaults are applied, so an
        absent "host" key previously raised even though the field defaults to
        "localhost". The default is now honored here; only an explicitly
        empty/None host (with no cloud_id) is rejected.
        """
        if not values.get("cloud_id") and not values.get("host", "localhost"):
            raise ValueError("Either cloud_id or host must be provided")
        # Some form of authentication is always required.
        if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
            raise ValueError("Either api_key or user/password must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate headers format and content"""
        headers = values.get("headers")
        if headers is not None:
            # Check if headers is a dictionary
            if not isinstance(headers, dict):
                raise ValueError("headers must be a dictionary")
            # Check if all keys and values are strings
            for key, value in headers.items():
                if not isinstance(key, str) or not isinstance(value, str):
                    raise ValueError("All header keys and values must be strings")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

View File

@@ -0,0 +1,37 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class FAISSConfig(BaseModel):
    """Configuration for the FAISS vector store."""

    collection_name: str = Field("mem0", description="Default name for the collection")
    path: Optional[str] = Field(None, description="Path to store FAISS index and metadata")
    distance_strategy: str = Field(
        "euclidean", description="Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'"
    )
    normalize_L2: bool = Field(
        False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
    )
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")

    @model_validator(mode="before")
    @classmethod
    def validate_distance_strategy(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure distance_strategy, when supplied, is one of the supported options."""
        strategy = values.get("distance_strategy")
        if strategy and strategy not in ("euclidean", "inner_product", "cosine"):
            raise ValueError("Invalid distance_strategy. Must be one of: 'euclidean', 'inner_product', 'cosine'")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,30 @@
from typing import Any, ClassVar, Dict
from pydantic import BaseModel, ConfigDict, Field, model_validator
class LangchainConfig(BaseModel):
    """Configuration wrapping a pre-built LangChain VectorStore instance."""

    # Import inside the class body so the helpful install hint surfaces when
    # this config class is loaded rather than at package import time.
    try:
        from langchain_community.vectorstores import VectorStore
    except ImportError:
        raise ImportError(
            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
        )
    # Expose the imported type on the class for use in the field annotation below.
    VectorStore: ClassVar[type] = VectorStore

    client: VectorStore = Field(description="Existing VectorStore instance")
    collection_name: str = Field("mem0", description="Name of the collection to use")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,42 @@
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class MetricType(str, Enum):
    """Similarity-metric constants for a Milvus / Zilliz server."""

    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

    def __str__(self) -> str:
        # Render as the bare metric name (e.g. "L2"), not "MetricType.L2".
        return str(self.value)
class MilvusDBConfig(BaseModel):
    """Configuration for a Milvus / Zilliz vector store."""

    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
    # Optional[str]: the default is None (local setups need no token), so the
    # previous bare `str` annotation was incorrect.
    token: Optional[str] = Field(None, description="Token for Zilliz server / local setup defaults to None.")
    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")
    db_name: str = Field("", description="Name of the database")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,25 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class MongoDBConfig(BaseModel):
    """Configuration for MongoDB vector database."""

    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fail fast when the caller supplies keys that are not model fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please provide only the following fields: {', '.join(allowed_fields)}."
            )
        return values

View File

@@ -0,0 +1,27 @@
"""
Configuration for Amazon Neptune Analytics vector store.
This module provides configuration settings for integrating with Amazon Neptune Analytics
as a vector store backend for Mem0's memory layer.
"""
from pydantic import BaseModel, Field
class NeptuneAnalyticsConfig(BaseModel):
    """
    Configuration class for Amazon Neptune Analytics vector store.

    Amazon Neptune Analytics is a graph analytics engine that can be used as a vector store
    for storing and retrieving memory embeddings in Mem0.

    Attributes:
        collection_name (str): Name of the collection to store vectors. Defaults to "mem0".
        endpoint (str): Neptune Analytics graph endpoint URL or Graph ID for the runtime.
    """

    collection_name: str = Field("mem0", description="Default name for the collection")
    # NOTE(review): the default "endpoint" looks like a placeholder rather
    # than a usable value — presumably callers always supply it; confirm.
    endpoint: str = Field("endpoint", description="Graph ID for the runtime")

    # Only plain string fields here, so arbitrary types stay disallowed.
    model_config = {
        "arbitrary_types_allowed": False,
    }

View File

@@ -0,0 +1,41 @@
from typing import Any, Dict, Optional, Type, Union
from pydantic import BaseModel, Field, model_validator
class OpenSearchConfig(BaseModel):
    """Configuration for the OpenSearch vector store."""

    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
    connection_class: Optional[Union[str, Type]] = Field(
        "RequestsHttpConnection", description="Connection class for OpenSearch"
    )
    pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the connection target.

        Bug fix: this validator runs before field defaults are applied, so an
        absent "host" key previously raised even though the field defaults to
        "localhost". The default is now honored; only an explicitly
        empty/None host is rejected.
        """
        if not values.get("host", "localhost"):
            raise ValueError("Host must be provided for OpenSearch")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
            )
        return values

View File

@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class PGVectorConfig(BaseModel):
    """Configuration for the pgvector (PostgreSQL) vector store."""

    dbname: str = Field("postgres", description="Default name for the database")
    collection_name: str = Field("mem0", description="Default name for the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password")
    host: Optional[str] = Field(None, description="Database host. Default is localhost")
    # Description fix: previously claimed "Default is 1536" (copy-paste from
    # embedding dims); PostgreSQL's conventional default port is 5432.
    port: Optional[int] = Field(None, description="Database port. Default is 5432")
    diskann: Optional[bool] = Field(False, description="Use diskann for approximate nearest neighbors search")
    hnsw: Optional[bool] = Field(True, description="Use hnsw for faster search")
    minconn: Optional[int] = Field(1, description="Minimum number of connections in the pool")
    maxconn: Optional[int] = Field(5, description="Maximum number of connections in the pool")
    # New SSL and connection options
    sslmode: Optional[str] = Field(None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')")
    connection_string: Optional[str] = Field(None, description="PostgreSQL connection string (overrides individual connection parameters)")
    connection_pool: Optional[Any] = Field(None, description="psycopg connection pool object (overrides connection string and individual parameters)")

    @model_validator(mode="before")
    @classmethod
    def check_auth_and_connection(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require complete connection details unless a pool or connection string is given."""
        # If connection_pool is provided, skip validation of individual connection parameters
        if values.get("connection_pool") is not None:
            return values
        # If connection_string is provided, skip validation of individual connection parameters
        if values.get("connection_string") is not None:
            return values
        # Otherwise, validate individual connection parameters.
        # Bug fix: these checks previously used `and`, so they only fired when
        # BOTH values were missing — a lone user (or lone host) slipped
        # through despite the "Both ... must be provided" error messages.
        user, password = values.get("user"), values.get("password")
        host, port = values.get("host"), values.get("port")
        if not user or not password:
            raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.")
        if not host or not port:
            raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

View File

@@ -0,0 +1,55 @@
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class PineconeConfig(BaseModel):
    """Configuration for Pinecone vector database."""

    collection_name: str = Field("mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(None, description="API key for Pinecone")
    environment: Optional[str] = Field(None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
    metric: str = Field("cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
    namespace: Optional[str] = Field(None, description="Namespace for the collection")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require some way to reach Pinecone: an api_key, a client, or the env var."""
        has_credentials = (
            values.get("api_key") or values.get("client") or "PINECONE_API_KEY" in os.environ
        )
        if not has_credentials:
            raise ValueError(
                "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Pod and serverless deployments are mutually exclusive."""
        if values.get("pod_config") and values.get("serverless_config"):
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,47 @@
from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class QdrantConfig(BaseModel):
    """Configuration for the Qdrant vector store."""

    # Import inside the class body, wrapped with a helpful install hint for
    # consistency with ChromaDbConfig / LangchainConfig; previously a raw
    # ImportError surfaced with no guidance.
    try:
        from qdrant_client import QdrantClient
    except ImportError:
        raise ImportError("The 'qdrant_client' library is required. Please install it using 'pip install qdrant-client'.")
    # Expose the imported type on the class for use in the field annotation below.
    QdrantClient: ClassVar[type] = QdrantClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
    host: Optional[str] = Field(None, description="Host address for Qdrant server")
    port: Optional[int] = Field(None, description="Port for Qdrant server")
    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
    on_disk: Optional[bool] = Field(False, description="Enables persistent storage")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require at least one complete connection mechanism: path, host+port, or url+api_key."""
        host, port, path, url, api_key = (
            values.get("host"),
            values.get("port"),
            values.get("path"),
            values.get("url"),
            values.get("api_key"),
        )
        if not path and not (host and port) and not (url and api_key):
            raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,24 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field, model_validator
# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    """Configuration for a Redis-backed vector store."""

    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        permitted = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - permitted
        if unexpected:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unexpected)}. "
                f"Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,28 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class S3VectorsConfig(BaseModel):
    """Configuration for an Amazon S3 Vectors store."""

    vector_bucket_name: str = Field(description="Name of the S3 Vector bucket")
    collection_name: str = Field("mem0", description="Name of the vector index")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    distance_metric: str = Field(
        "cosine",
        description="Distance metric for similarity search. Options: 'cosine', 'euclidean'",
    )
    region_name: Optional[str] = Field(None, description="AWS region for the S3 Vectors client")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        permitted = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - permitted
        if unexpected:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unexpected)}. "
                f"Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,44 @@
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class IndexMethod(str, Enum):
    """pgvector index build strategies for the Supabase backend."""

    AUTO = "auto"  # let the backend pick between HNSW and IVFFlat
    HNSW = "hnsw"
    IVFFLAT = "ivfflat"
class IndexMeasure(str, Enum):
    """Distance measures supported by the pgvector index."""

    COSINE = "cosine_distance"
    L2 = "l2_distance"
    L1 = "l1_distance"
    MAX_INNER_PRODUCT = "max_inner_product"
class SupabaseConfig(BaseModel):
    """Configuration for a Supabase (PostgreSQL + pgvector) vector store."""

    connection_string: str = Field(..., description="PostgreSQL connection string")
    collection_name: str = Field("mem0", description="Name for the vector collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")

    @model_validator(mode="before")
    @classmethod  # consistency fix: every sibling "before" validator is a classmethod
    def check_connection_string(cls, values):
        """Require a postgresql:// connection string."""
        conn_str = values.get("connection_string")
        if not conn_str or not conn_str.startswith("postgresql://"):
            raise ValueError("A valid PostgreSQL connection string must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

View File

@@ -0,0 +1,34 @@
import os
from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
try:
from upstash_vector import Index
except ImportError:
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
class UpstashVectorConfig(BaseModel):
    """Configuration for an Upstash Vector index.

    Either an existing ``Index`` client or a url+token pair (directly or via
    the UPSTASH_VECTOR_REST_URL / UPSTASH_VECTOR_REST_TOKEN environment
    variables) must be provided.
    """

    # Exposed as a ClassVar so pydantic does not treat the name as a field.
    Index: ClassVar[type] = Index

    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
    collection_name: str = Field("mem0", description="Namespace to use for the index")
    enable_embeddings: bool = Field(
        # Doc fix: the description previously claimed the default is True,
        # but the actual default is False.
        False, description="Whether to use built-in upstash embeddings or not. Default is False."
    )

    @model_validator(mode="before")
    @classmethod
    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require either an existing client or a complete url+token pair."""
        client = values.get("client")
        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")
        if not client and not (url and token):
            raise ValueError("Either a client or URL and token must be provided.")
        return values

    # Required so pydantic accepts the arbitrary Index type as a field value.
    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel
class ValkeyConfig(BaseModel):
    """Configuration for Valkey vector store."""

    # Consistency: use Field(description=...) like every other vector-store
    # config in this package; defaults and types are unchanged.
    valkey_url: str = Field(..., description="Valkey server URL")
    collection_name: str = Field(..., description="Name of the collection")
    embedding_model_dims: int = Field(..., description="Dimensions of the embedding model")
    timezone: str = Field("UTC", description="Timezone used for timestamps")
    index_type: str = Field("hnsw", description="Index type. Options: 'hnsw', 'flat'")
    # HNSW-specific parameters with recommended defaults.
    hnsw_m: int = Field(16, description="Number of connections per layer (default from Valkey docs)")
    hnsw_ef_construction: int = Field(200, description="Search width during index construction")
    hnsw_ef_runtime: int = Field(10, description="Search width during queries")

View File

@@ -0,0 +1,28 @@
from typing import Dict, Optional
from pydantic import BaseModel, ConfigDict, Field
class GoogleMatchingEngineConfig(BaseModel):
    """Configuration for Vertex AI Vector Search (Matching Engine)."""

    project_id: str = Field(description="Google Cloud project ID")
    project_number: str = Field(description="Google Cloud project number")
    region: str = Field(description="Google Cloud region")
    endpoint_id: str = Field(description="Vertex AI Vector Search endpoint ID")
    index_id: str = Field(description="Vertex AI Vector Search index ID")
    deployment_index_id: str = Field(description="Deployment-specific index ID")
    collection_name: Optional[str] = Field(None, description="Collection name, defaults to index_id")
    credentials_path: Optional[str] = Field(None, description="Path to service account credentials JSON file")
    service_account_json: Optional[Dict] = Field(
        None, description="Service account credentials as dictionary (alternative to credentials_path)"
    )
    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, _context) -> None:
        """Default collection_name to index_id when not provided.

        Replaces previously duplicated fallback logic in both a custom
        __init__ override and model_post_init; the combined effect (any
        falsy value — None or "" — falls back to index_id) is preserved.
        """
        if not self.collection_name:
            self.collection_name = self.index_id

View File

@@ -0,0 +1,41 @@
from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class WeaviateConfig(BaseModel):
    """Connection settings for a Weaviate vector store."""

    # Imported inside the class body so the annotation resolves; exposed as a
    # ClassVar so pydantic does not treat the name as a field.
    from weaviate import WeaviateClient

    WeaviateClient: ClassVar[type] = WeaviateClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    cluster_url: Optional[str] = Field(None, description="URL for Weaviate server")
    auth_client_secret: Optional[str] = Field(None, description="API key for Weaviate authentication")
    additional_headers: Optional[Dict[str, str]] = Field(None, description="Additional headers for requests")

    @model_validator(mode="before")
    @classmethod
    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a cluster URL before any other validation runs."""
        if not values.get("cluster_url"):
            raise ValueError("'cluster_url' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        permitted = set(cls.model_fields.keys())
        unexpected = set(values.keys()) - permitted
        if unexpected:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unexpected)}. "
                f"Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)