first commit
This commit is contained in:
0
configs/vector_stores/__init__.py
Normal file
0
configs/vector_stores/__init__.py
Normal file
57
configs/vector_stores/azure_ai_search.py
Normal file
57
configs/vector_stores/azure_ai_search.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class AzureAISearchConfig(BaseModel):
    """Configuration for the Azure AI Search vector store backend."""

    collection_name: str = Field("mem0", description="Name of the collection")
    # Fixed: both fields default to None, so the annotation must be Optional[str]
    # (a bare `str` annotation with a None default is inconsistent and rejects
    # the default under strict validation).
    service_name: Optional[str] = Field(None, description="Azure AI Search service name")
    api_key: Optional[str] = Field(None, description="API key for the Azure AI Search service")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    compression_type: Optional[str] = Field(
        None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
    )
    use_float16: bool = Field(
        False,
        description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
    )
    hybrid_search: bool = Field(
        False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
    )
    vector_filter_mode: Optional[str] = Field(
        "preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
    )

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject unknown fields and validate compression settings.

        Raises:
            ValueError: If an unknown (or the deprecated ``use_compression``)
                field is passed, or ``compression_type`` has an invalid value.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields

        # The deprecated flag gets a targeted migration hint instead of the
        # generic extra-field error.
        if "use_compression" in extra_fields:
            raise ValueError(
                "The parameter 'use_compression' is no longer supported. "
                "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
                "or 'compression_type=None' instead of 'use_compression=False'."
            )

        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )

        # Validate compression_type values (compared case-insensitively).
        if "compression_type" in values and values["compression_type"] is not None:
            valid_types = ["scalar", "binary"]
            if values["compression_type"].lower() not in valid_types:
                raise ValueError(
                    f"Invalid compression_type: {values['compression_type']}. "
                    f"Must be one of: {', '.join(valid_types)}, or None"
                )

        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
84
configs/vector_stores/azure_mysql.py
Normal file
84
configs/vector_stores/azure_mysql.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class AzureMySQLConfig(BaseModel):
    """Configuration for Azure MySQL vector database."""

    host: str = Field(..., description="MySQL server host (e.g., myserver.mysql.database.azure.com)")
    port: int = Field(3306, description="MySQL server port")
    user: str = Field(..., description="Database user")
    password: Optional[str] = Field(None, description="Database password (not required if using Azure credential)")
    database: str = Field(..., description="Database name")
    collection_name: str = Field("mem0", description="Collection/table name")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    use_azure_credential: bool = Field(
        False,
        description="Use Azure DefaultAzureCredential for authentication instead of password"
    )
    ssl_ca: Optional[str] = Field(None, description="Path to SSL CA certificate")
    ssl_disabled: bool = Field(False, description="Disable SSL connection (not recommended for production)")
    minconn: int = Field(1, description="Minimum number of connections in the pool")
    maxconn: int = Field(5, description="Maximum number of connections in the pool")
    connection_pool: Optional[Any] = Field(
        None,
        description="Pre-configured connection pool object (overrides other connection parameters)"
    )

    @model_validator(mode="before")
    @classmethod
    def check_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate authentication parameters.

        Raises:
            ValueError: If neither a password nor Azure credential auth is configured.
        """
        # If connection_pool is provided, it carries its own credentials.
        if values.get("connection_pool") is not None:
            return values

        use_azure_credential = values.get("use_azure_credential", False)
        password = values.get("password")

        # Either password or Azure credential must be provided.
        if not use_azure_credential and not password:
            raise ValueError(
                "Either 'password' must be provided or 'use_azure_credential' must be set to True"
            )

        return values

    @model_validator(mode="before")
    @classmethod
    def check_required_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate required connection fields.

        Raises:
            ValueError: If host/user/database are missing without a connection_pool.
        """
        # A pre-configured pool supersedes the individual connection parameters.
        if values.get("connection_pool") is not None:
            return values

        required_fields = ["host", "user", "database"]
        missing_fields = [field for field in required_fields if not values.get(field)]

        if missing_fields:
            raise ValueError(
                f"Missing required fields: {', '.join(missing_fields)}. "
                f"These fields are required when not using a pre-configured connection_pool."
            )

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that no extra fields are provided."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields

        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )

        return values

    # Consistency fix: this is a pydantic v2 model (it uses model_validator),
    # so use the v2 `model_config` mapping instead of the legacy v1-style
    # inner `class Config` used nowhere else in this package.
    model_config = {"arbitrary_types_allowed": True}
|
||||
27
configs/vector_stores/baidu.py
Normal file
27
configs/vector_stores/baidu.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class BaiduDBConfig(BaseModel):
    """Configuration for the Baidu VectorDB vector store backend."""

    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
    account: str = Field("root", description="Account for Baidu VectorDB")
    # Fixed: the field defaults to None, so the annotation must be Optional[str]
    # (a bare `str` annotation with a None default rejects its own default
    # under strict validation).
    api_key: Optional[str] = Field(None, description="API Key for Baidu VectorDB")
    database_name: str = Field("mem0", description="Name of the database")
    table_name: str = Field("mem0", description="Name of the table")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model.

        Raises:
            ValueError: If unknown fields are passed.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
77
configs/vector_stores/cassandra.py
Normal file
77
configs/vector_stores/cassandra.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class CassandraConfig(BaseModel):
    """Configuration for Apache Cassandra vector database."""

    # Fixed: this field was declared required (`...`), which contradicted
    # check_connection_config below — a secure_connect_bundle-only (Astra DB)
    # configuration could never pass field validation. Optional with a None
    # default keeps all existing callers working and makes the bundle-only
    # path reachable.
    contact_points: Optional[List[str]] = Field(
        None,
        description="List of contact point addresses (e.g., ['127.0.0.1', '127.0.0.2'])"
    )
    port: int = Field(9042, description="Cassandra port")
    username: Optional[str] = Field(None, description="Database username")
    password: Optional[str] = Field(None, description="Database password")
    keyspace: str = Field("mem0", description="Keyspace name")
    collection_name: str = Field("memories", description="Table name")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    secure_connect_bundle: Optional[str] = Field(
        None,
        description="Path to secure connect bundle for DataStax Astra DB"
    )
    protocol_version: int = Field(4, description="CQL protocol version")
    load_balancing_policy: Optional[Any] = Field(
        None,
        description="Custom load balancing policy object"
    )

    @model_validator(mode="before")
    @classmethod
    def check_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate authentication parameters.

        Raises:
            ValueError: If only one of username/password is provided.
        """
        username = values.get("username")
        password = values.get("password")

        # Both username and password must be provided together or not at all.
        if (username and not password) or (password and not username):
            raise ValueError(
                "Both 'username' and 'password' must be provided together for authentication"
            )

        return values

    @model_validator(mode="before")
    @classmethod
    def check_connection_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate connection configuration.

        Raises:
            ValueError: If neither contact points nor a secure connect bundle
                is provided.
        """
        secure_connect_bundle = values.get("secure_connect_bundle")
        contact_points = values.get("contact_points")

        # Either secure_connect_bundle or contact_points must be provided.
        if not secure_connect_bundle and not contact_points:
            raise ValueError(
                "Either 'contact_points' or 'secure_connect_bundle' must be provided"
            )

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that no extra fields are provided."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields

        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )

        return values

    # Consistency fix: pydantic v2 `model_config` mapping instead of the
    # legacy v1-style inner `class Config`.
    model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
58
configs/vector_stores/chroma.py
Normal file
58
configs/vector_stores/chroma.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class ChromaDbConfig(BaseModel):
    """Configuration for the ChromaDB vector store.

    Supports three mutually exclusive connection modes: an existing client
    instance, a local/server setup (path or host+port), or ChromaDB Cloud
    (api_key + tenant).
    """

    # Import inside the class body so the dependency error surfaces as soon
    # as this config type is defined.
    try:
        from chromadb.api.client import Client
    except ImportError:
        raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")
    Client: ClassVar[type] = Client

    collection_name: str = Field("mem0", description="Default name for the collection/database")
    client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
    path: Optional[str] = Field(None, description="Path to the database directory")
    host: Optional[str] = Field(None, description="Database connection remote host")
    port: Optional[int] = Field(None, description="Database connection remote port")
    # ChromaDB Cloud configuration
    api_key: Optional[str] = Field(None, description="ChromaDB Cloud API key")
    tenant: Optional[str] = Field(None, description="ChromaDB Cloud tenant ID")

    @model_validator(mode="before")
    def check_connection_config(cls, values):
        """Ensure exactly one connection mode (cloud or local) is configured."""
        path = values.get("path")
        host = values.get("host")
        port = values.get("port")

        # Cloud mode requires both the API key and the tenant ID.
        is_cloud = bool(values.get("api_key") and values.get("tenant"))

        # Cloud mode with the default tmp path: drop the path so it does not
        # masquerade as a local configuration.
        if is_cloud and path == "/tmp/chroma":
            values.pop("path", None)
            return values

        # Local/server mode: an explicit path (other than the cloud default)
        # or a host+port pair.
        is_local = bool(host and port) or bool(path and path != "/tmp/chroma")

        if not is_cloud and not is_local:
            raise ValueError("Either ChromaDB Cloud configuration (api_key, tenant) or local configuration (path or host/port) must be provided.")

        if is_cloud and is_local:
            raise ValueError("Cannot specify both cloud configuration and local configuration. Choose one.")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
61
configs/vector_stores/databricks.py
Normal file
61
configs/vector_stores/databricks.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from databricks.sdk.service.vectorsearch import EndpointType, VectorIndexType, PipelineType
|
||||
|
||||
|
||||
class DatabricksConfig(BaseModel):
    """Configuration for Databricks Vector Search vector store."""

    workspace_url: str = Field(..., description="Databricks workspace URL")
    access_token: Optional[str] = Field(None, description="Personal access token for authentication")
    client_id: Optional[str] = Field(None, description="Databricks Service principal client ID")
    client_secret: Optional[str] = Field(None, description="Databricks Service principal client secret")
    azure_client_id: Optional[str] = Field(None, description="Azure AD application client ID (for Azure Databricks)")
    azure_client_secret: Optional[str] = Field(
        None, description="Azure AD application client secret (for Azure Databricks)"
    )
    endpoint_name: str = Field(..., description="Vector search endpoint name")
    catalog: str = Field(..., description="The Unity Catalog catalog name")
    # NOTE(review): this field shadows pydantic's deprecated `schema` name —
    # renaming would break callers, so it is kept; confirm no warning matters.
    # Fixed typo in the description: "schama" -> "schema".
    schema: str = Field(..., description="The Unity Catalog schema name")
    table_name: str = Field(..., description="Source Delta table name")
    collection_name: str = Field("mem0", description="Vector search index name")
    index_type: VectorIndexType = Field("DELTA_SYNC", description="Index type: DELTA_SYNC or DIRECT_ACCESS")
    embedding_model_endpoint_name: Optional[str] = Field(
        None, description="Embedding model endpoint for Databricks-computed embeddings"
    )
    embedding_dimension: int = Field(1536, description="Vector embedding dimensions")
    endpoint_type: EndpointType = Field("STANDARD", description="Endpoint type: STANDARD or STORAGE_OPTIMIZED")
    pipeline_type: PipelineType = Field("TRIGGERED", description="Sync pipeline type: TRIGGERED or CONTINUOUS")
    warehouse_name: Optional[str] = Field(None, description="Databricks SQL warehouse Name")
    query_type: str = Field("ANN", description="Query type: `ANN` and `HYBRID`")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    @model_validator(mode="after")
    def validate_authentication(self):
        """Validate that either access_token or service principal credentials are provided.

        Raises:
            ValueError: If no complete credential set is configured.
        """
        has_token = self.access_token is not None
        # A service principal needs both halves of either credential pair.
        has_service_principal = (self.client_id is not None and self.client_secret is not None) or (
            self.azure_client_id is not None and self.azure_client_secret is not None
        )

        if not has_token and not has_service_principal:
            raise ValueError(
                "Either access_token or both client_id/client_secret or azure_client_id/azure_client_secret must be provided"
            )

        return self

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
65
configs/vector_stores/elasticsearch.py
Normal file
65
configs/vector_stores/elasticsearch.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ElasticsearchConfig(BaseModel):
    """Configuration for the Elasticsearch vector store backend."""

    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="Elasticsearch host")
    port: int = Field(9200, description="Elasticsearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    cloud_id: Optional[str] = Field(None, description="Cloud ID for Elastic Cloud")
    api_key: Optional[str] = Field(None, description="API key for authentication")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(True, description="Verify SSL certificates")
    use_ssl: bool = Field(True, description="Use SSL for connection")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
    custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
        None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
    )
    headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to include in requests")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a reachable endpoint and some form of authentication."""
        # An endpoint must be given as either a cloud_id or a host.
        has_endpoint = values.get("cloud_id") or values.get("host")
        if not has_endpoint:
            raise ValueError("Either cloud_id or host must be provided")

        # Credentials: an API key, or a complete user/password pair.
        has_basic_auth = values.get("user") and values.get("password")
        if not values.get("api_key") and not has_basic_auth:
            raise ValueError("Either api_key or user/password must be provided")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate headers format and content"""
        headers = values.get("headers")
        if headers is None:
            return values

        # Headers must be a string-to-string mapping.
        if not isinstance(headers, dict):
            raise ValueError("headers must be a dictionary")
        for name, content in headers.items():
            if not isinstance(name, str) or not isinstance(content, str):
                raise ValueError("All header keys and values must be strings")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(permitted)}"
            )
        return values
|
||||
37
configs/vector_stores/faiss.py
Normal file
37
configs/vector_stores/faiss.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class FAISSConfig(BaseModel):
    """Configuration for the FAISS vector store backend."""

    collection_name: str = Field("mem0", description="Default name for the collection")
    path: Optional[str] = Field(None, description="Path to store FAISS index and metadata")
    distance_strategy: str = Field(
        "euclidean", description="Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'"
    )
    normalize_L2: bool = Field(
        False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
    )
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")

    @model_validator(mode="before")
    @classmethod
    def validate_distance_strategy(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure the distance strategy, when given, is a supported one."""
        strategy = values.get("distance_strategy")
        supported = ("euclidean", "inner_product", "cosine")
        if strategy and strategy not in supported:
            raise ValueError("Invalid distance_strategy. Must be one of: 'euclidean', 'inner_product', 'cosine'")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
30
configs/vector_stores/langchain.py
Normal file
30
configs/vector_stores/langchain.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any, ClassVar, Dict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class LangchainConfig(BaseModel):
    """Configuration that wraps an existing LangChain VectorStore instance."""

    # Import inside the class body so the dependency error surfaces as soon
    # as this config type is defined.
    try:
        from langchain_community.vectorstores import VectorStore
    except ImportError:
        raise ImportError(
            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
        )
    VectorStore: ClassVar[type] = VectorStore

    client: VectorStore = Field(description="Existing VectorStore instance")
    collection_name: str = Field("mem0", description="Name of the collection to use")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. Please input only the following fields: {', '.join(permitted)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
42
configs/vector_stores/milvus.py
Normal file
42
configs/vector_stores/milvus.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class MetricType(str, Enum):
    """
    Metric Constant for milvus/ zilliz server.
    """

    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

    def __str__(self) -> str:
        # Render as the bare metric name rather than the enum repr.
        return str(self.value)
||||
|
||||
|
||||
class MilvusDBConfig(BaseModel):
    """Configuration for the Milvus / Zilliz vector store backend."""

    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
    # Fixed: the field defaults to None, so the annotation must be Optional[str]
    # (a bare `str` annotation with a None default rejects its own default
    # under strict validation).
    token: Optional[str] = Field(None, description="Token for Zilliz server / local setup defaults to None.")
    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")
    db_name: str = Field("", description="Name of the database")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model.

        Raises:
            ValueError: If unknown fields are passed.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
25
configs/vector_stores/mongodb.py
Normal file
25
configs/vector_stores/mongodb.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MongoDBConfig(BaseModel):
    """Configuration for MongoDB vector database."""

    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please provide only the following fields: {', '.join(permitted)}."
            )
        return values
|
||||
27
configs/vector_stores/neptune.py
Normal file
27
configs/vector_stores/neptune.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Configuration for Amazon Neptune Analytics vector store.
|
||||
|
||||
This module provides configuration settings for integrating with Amazon Neptune Analytics
|
||||
as a vector store backend for Mem0's memory layer.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NeptuneAnalyticsConfig(BaseModel):
    """
    Configuration class for Amazon Neptune Analytics vector store.

    Amazon Neptune Analytics is a graph analytics engine that can be used as a
    vector store for storing and retrieving memory embeddings in Mem0.

    Attributes:
        collection_name (str): Name of the collection to store vectors. Defaults to "mem0".
        endpoint (str): Neptune Analytics graph endpoint URL or Graph ID for the runtime.
    """

    collection_name: str = Field("mem0", description="Default name for the collection")
    endpoint: str = Field("endpoint", description="Graph ID for the runtime")

    model_config = {"arbitrary_types_allowed": False}
|
||||
41
configs/vector_stores/opensearch.py
Normal file
41
configs/vector_stores/opensearch.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseModel):
    """Configuration for the OpenSearch vector store backend."""

    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
    connection_class: Optional[Union[str, Type]] = Field(
        "RequestsHttpConnection", description="Connection class for OpenSearch"
    )
    pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a host to connect to."""
        if not values.get("host"):
            raise ValueError("Host must be provided for OpenSearch")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. Allowed fields: {', '.join(permitted)}"
            )
        return values
|
||||
52
configs/vector_stores/pgvector.py
Normal file
52
configs/vector_stores/pgvector.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class PGVectorConfig(BaseModel):
    """Configuration for the PostgreSQL/pgvector vector store backend.

    Connection can be supplied three ways, in order of precedence:
    a ready-made connection pool, a connection string, or individual
    host/port/user/password parameters.
    """

    dbname: str = Field("postgres", description="Default name for the database")
    collection_name: str = Field("mem0", description="Default name for the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password")
    host: Optional[str] = Field(None, description="Database host. Default is localhost")
    port: Optional[int] = Field(None, description="Database port. Default is 1536")
    diskann: Optional[bool] = Field(False, description="Use diskann for approximate nearest neighbors search")
    hnsw: Optional[bool] = Field(True, description="Use hnsw for faster search")
    minconn: Optional[int] = Field(1, description="Minimum number of connections in the pool")
    maxconn: Optional[int] = Field(5, description="Maximum number of connections in the pool")
    # New SSL and connection options
    sslmode: Optional[str] = Field(None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')")
    connection_string: Optional[str] = Field(None, description="PostgreSQL connection string (overrides individual connection parameters)")
    connection_pool: Optional[Any] = Field(None, description="psycopg connection pool object (overrides connection string and individual parameters)")

    @model_validator(mode="before")
    def check_auth_and_connection(cls, values):
        """Validate connection parameters unless a pool or connection string is given."""
        # A pre-built pool or a connection string supersedes the individual
        # parameters entirely.
        if values.get("connection_pool") is not None:
            return values
        if values.get("connection_string") is not None:
            return values

        # NOTE(review): these conditions only fire when BOTH halves of a pair
        # are missing, although the messages say both are required — a
        # user-without-password (or host-without-port) slips through. Left
        # unchanged here to preserve behavior; confirm intent.
        if not values.get("user") and not values.get("password"):
            raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.")
        if not values.get("host") and not values.get("port"):
            raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any fields not declared on the model."""
        permitted = set(cls.model_fields.keys())
        unknown = set(values.keys()) - permitted
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. Please input only the following fields: {', '.join(permitted)}"
            )
        return values
|
||||
55
configs/vector_stores/pinecone.py
Normal file
55
configs/vector_stores/pinecone.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class PineconeConfig(BaseModel):
    """Configuration for Pinecone vector database."""

    collection_name: str = Field("mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(None, description="API key for Pinecone")
    environment: Optional[str] = Field(None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
    metric: str = Field("cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
    namespace: Optional[str] = Field(None, description="Namespace for the collection")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require credentials from one of: explicit key, prebuilt client, or env var."""
        # Presence (even empty) of PINECONE_API_KEY in the environment counts,
        # matching the membership test rather than a truthiness test.
        if values.get("api_key") or values.get("client") or "PINECONE_API_KEY" in os.environ:
            return values
        raise ValueError(
            "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
        )

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Forbid specifying both deployment styles at once."""
        if values.get("pod_config") and values.get("serverless_config"):
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed = set(cls.model_fields)
        unknown = set(values) - allowed
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(allowed)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
47
configs/vector_stores/qdrant.py
Normal file
47
configs/vector_stores/qdrant.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
    from qdrant_client import QdrantClient

    # Exposed on the class so callers can reference the client type.
    QdrantClient: ClassVar[type] = QdrantClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
    host: Optional[str] = Field(None, description="Host address for Qdrant server")
    port: Optional[int] = Field(None, description="Port for Qdrant server")
    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
    on_disk: Optional[bool] = Field(False, description="Enables persistent storage")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require one complete way of reaching Qdrant: local path, host+port, or url+api_key."""
        local_path = values.get("path")
        server_pair = values.get("host") and values.get("port")
        cloud_pair = values.get("url") and values.get("api_key")
        if not (local_path or server_pair or cloud_pair):
            raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed = set(cls.model_fields)
        unknown = set(values) - allowed
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(allowed)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
24
configs/vector_stores/redis.py
Normal file
24
configs/vector_stores/redis.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    """Configuration for the Redis vector store."""

    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed = set(cls.model_fields)
        unknown = set(values) - allowed
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(allowed)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
28
configs/vector_stores/s3_vectors.py
Normal file
28
configs/vector_stores/s3_vectors.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class S3VectorsConfig(BaseModel):
    """Configuration for the Amazon S3 Vectors store."""

    vector_bucket_name: str = Field(description="Name of the S3 Vector bucket")
    collection_name: str = Field("mem0", description="Name of the vector index")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    distance_metric: str = Field(
        "cosine",
        description="Distance metric for similarity search. Options: 'cosine', 'euclidean'",
    )
    region_name: Optional[str] = Field(None, description="AWS region for the S3 Vectors client")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed = set(cls.model_fields)
        unknown = set(values) - allowed
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(allowed)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
44
configs/vector_stores/supabase.py
Normal file
44
configs/vector_stores/supabase.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class IndexMethod(str, Enum):
    """Vector index algorithms selectable for the Supabase store."""

    # Let the backend choose an appropriate index type.
    AUTO = "auto"
    # Hierarchical Navigable Small World graph index.
    HNSW = "hnsw"
    # Inverted-file flat index.
    IVFFLAT = "ivfflat"
|
||||
|
||||
|
||||
class IndexMeasure(str, Enum):
    """Distance/similarity measures selectable for the Supabase store."""

    COSINE = "cosine_distance"
    L2 = "l2_distance"
    L1 = "l1_distance"
    # Inner-product similarity (larger is more similar).
    MAX_INNER_PRODUCT = "max_inner_product"
|
||||
|
||||
|
||||
class SupabaseConfig(BaseModel):
    """Configuration for the Supabase (PostgreSQL + pgvector) vector store."""

    connection_string: str = Field(..., description="PostgreSQL connection string")
    collection_name: str = Field("mem0", description="Name for the vector collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")

    @model_validator(mode="before")
    @classmethod  # added for consistency with the other validators in these config classes
    def check_connection_string(cls, values):
        """Ensure a PostgreSQL DSN is present and uses a recognized URI scheme."""
        conn_str = values.get("connection_string")
        # Accept both scheme spellings understood by libpq
        # ("postgresql://" and the legacy "postgres://").
        if not conn_str or not conn_str.startswith(("postgresql://", "postgres://")):
            raise ValueError("A valid PostgreSQL connection string must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
|
||||
34
configs/vector_stores/upstash_vector.py
Normal file
34
configs/vector_stores/upstash_vector.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
try:
|
||||
from upstash_vector import Index
|
||||
except ImportError:
|
||||
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
|
||||
|
||||
|
||||
class UpstashVectorConfig(BaseModel):
    """Configuration for the Upstash Vector store.

    Credentials may be given either as an existing ``upstash_vector.Index``
    client or as a URL/token pair; the UPSTASH_VECTOR_REST_URL and
    UPSTASH_VECTOR_REST_TOKEN environment variables are used as fallbacks.
    """

    # Exposed on the class so callers can reference the client type.
    Index: ClassVar[type] = Index

    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
    collection_name: str = Field("mem0", description="Namespace to use for the index")
    # Fixed: the description previously said "Default is True." although the
    # declared default is False.
    enable_embeddings: bool = Field(
        False, description="Whether to use built-in upstash embeddings or not. Default is False."
    )

    @model_validator(mode="before")
    @classmethod
    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject configs that provide neither a client nor a complete URL/token pair."""
        client = values.get("client")
        # Fall back to the standard Upstash environment variables.
        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")

        if not client and not (url and token):
            raise ValueError("Either a client or URL and token must be provided.")
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
15
configs/vector_stores/valkey.py
Normal file
15
configs/vector_stores/valkey.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ValkeyConfig(BaseModel):
    """Configuration for Valkey vector store."""

    # Connection URL for the Valkey server.
    valkey_url: str
    # Name of the vector collection/index.
    collection_name: str
    # Dimensionality of stored embedding vectors.
    embedding_model_dims: int
    # NOTE(review): not obvious where this is consumed — confirm against the store impl.
    timezone: str = "UTC"
    index_type: str = "hnsw"  # Default to HNSW, can be 'hnsw' or 'flat'
    # HNSW specific parameters with recommended defaults
    hnsw_m: int = 16  # Number of connections per layer (default from Valkey docs)
    hnsw_ef_construction: int = 200  # Search width during construction
    hnsw_ef_runtime: int = 10  # Search width during queries
|
||||
28
configs/vector_stores/vertex_ai_vector_search.py
Normal file
28
configs/vector_stores/vertex_ai_vector_search.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class GoogleMatchingEngineConfig(BaseModel):
    """Configuration for Vertex AI Vector Search (formerly Matching Engine)."""

    project_id: str = Field(description="Google Cloud project ID")
    project_number: str = Field(description="Google Cloud project number")
    region: str = Field(description="Google Cloud region")
    endpoint_id: str = Field(description="Vertex AI Vector Search endpoint ID")
    index_id: str = Field(description="Vertex AI Vector Search index ID")
    deployment_index_id: str = Field(description="Deployment-specific index ID")
    collection_name: Optional[str] = Field(None, description="Collection name, defaults to index_id")
    credentials_path: Optional[str] = Field(None, description="Path to service account credentials JSON file")
    service_account_json: Optional[Dict] = Field(None, description="Service account credentials as dictionary (alternative to credentials_path)")
    vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, _context) -> None:
        """Default collection_name to index_id when it is unset.

        Consolidates what was previously duplicated logic: both a custom
        ``__init__`` override and ``model_post_init`` applied this default.
        The falsy check preserves the old combined behavior, where an empty
        string was replaced as well as None.
        """
        if not self.collection_name:
            self.collection_name = self.index_id
|
||||
41
configs/vector_stores/weaviate.py
Normal file
41
configs/vector_stores/weaviate.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
    from weaviate import WeaviateClient

    # Exposed on the class so callers can reference the client type.
    WeaviateClient: ClassVar[type] = WeaviateClient

    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    cluster_url: Optional[str] = Field(None, description="URL for Weaviate server")
    auth_client_secret: Optional[str] = Field(None, description="API key for Weaviate authentication")
    additional_headers: Optional[Dict[str, str]] = Field(None, description="Additional headers for requests")

    @model_validator(mode="before")
    @classmethod
    def check_connection_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a cluster URL to be configured."""
        if not values.get("cluster_url"):
            raise ValueError("'cluster_url' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input keys that are not declared model fields."""
        allowed = set(cls.model_fields)
        unknown = set(values) - allowed
        if unknown:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(unknown)}. "
                f"Please input only the following fields: {', '.join(allowed)}"
            )
        return values

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
Reference in New Issue
Block a user