import importlib
from typing import Dict, Optional, Union

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.anthropic import AnthropicConfig
from mem0.configs.llms.azure import AzureOpenAIConfig
from mem0.configs.llms.base import BaseLlmConfig
from mem0.configs.llms.deepseek import DeepSeekConfig
from mem0.configs.llms.lmstudio import LMStudioConfig
from mem0.configs.llms.ollama import OllamaConfig
from mem0.configs.llms.openai import OpenAIConfig
from mem0.configs.llms.vllm import VllmConfig
from mem0.configs.rerankers.base import BaseRerankerConfig
from mem0.configs.rerankers.cohere import CohereRerankerConfig
from mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
from mem0.configs.rerankers.llm import LLMRerankerConfig
from mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig
from mem0.configs.rerankers.zero_entropy import ZeroEntropyRerankerConfig
from mem0.embeddings.mock import MockEmbeddings


def load_class(class_type: str):
    """Import and return the class named by a dotted path.

    Args:
        class_type: Fully-qualified path such as ``"pkg.module.ClassName"``.

    Returns:
        The class object resolved from the module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the class is not found in the module.
    """
    module_path, class_name = class_type.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


class LlmFactory:
    """
    Factory for creating LLM instances with appropriate configurations.
    Supports both old-style BaseLlmConfig and new provider-specific configs.
    """

    # Provider name -> (dotted path of the LLM class, config class used to build it).
    # Providers mapped to BaseLlmConfig have no provider-specific config yet.
    provider_to_class = {
        "ollama": ("mem0.llms.ollama.OllamaLLM", OllamaConfig),
        "openai": ("mem0.llms.openai.OpenAILLM", OpenAIConfig),
        "groq": ("mem0.llms.groq.GroqLLM", BaseLlmConfig),
        "together": ("mem0.llms.together.TogetherLLM", BaseLlmConfig),
        "aws_bedrock": ("mem0.llms.aws_bedrock.AWSBedrockLLM", BaseLlmConfig),
        "litellm": ("mem0.llms.litellm.LiteLLM", BaseLlmConfig),
        "azure_openai": ("mem0.llms.azure_openai.AzureOpenAILLM", AzureOpenAIConfig),
        "openai_structured": ("mem0.llms.openai_structured.OpenAIStructuredLLM", OpenAIConfig),
        "anthropic": ("mem0.llms.anthropic.AnthropicLLM", AnthropicConfig),
        "azure_openai_structured": ("mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM", AzureOpenAIConfig),
        "gemini": ("mem0.llms.gemini.GeminiLLM", BaseLlmConfig),
        "deepseek": ("mem0.llms.deepseek.DeepSeekLLM", DeepSeekConfig),
        "xai": ("mem0.llms.xai.XAILLM", BaseLlmConfig),
        "sarvam": ("mem0.llms.sarvam.SarvamLLM", BaseLlmConfig),
        "lmstudio": ("mem0.llms.lmstudio.LMStudioLLM", LMStudioConfig),
        "vllm": ("mem0.llms.vllm.VllmLLM", VllmConfig),
        "langchain": ("mem0.llms.langchain.LangchainLLM", BaseLlmConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseLlmConfig, Dict]] = None, **kwargs):
        """
        Create an LLM instance with the appropriate configuration.

        Args:
            provider_name (str): The provider name (e.g., 'openai', 'anthropic')
            config: Configuration object or dict. If None, will create default config
            **kwargs: Additional configuration parameters

        Returns:
            Configured LLM instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        class_type, config_class = cls.provider_to_class[provider_name]
        llm_class = load_class(class_type)

        if config is None:
            # No config supplied: build a default from kwargs alone.
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge without mutating the caller's dict; kwargs take precedence.
            config = config_class(**{**config, **kwargs})
        elif isinstance(config, BaseLlmConfig) and config_class is not BaseLlmConfig:
            # Promote a generic base config to the provider-specific config class,
            # copying over the common fields and letting kwargs override them.
            config_dict = {
                "model": config.model,
                "temperature": config.temperature,
                "api_key": config.api_key,
                "max_tokens": config.max_tokens,
                "top_p": config.top_p,
                "top_k": config.top_k,
                "enable_vision": config.enable_vision,
                "vision_details": config.vision_details,
                "http_client_proxies": config.http_client,
            }
            config_dict.update(kwargs)
            config = config_class(**config_dict)
        # Otherwise: config is already a usable config object; pass it through as-is.

        return llm_class(config)

    @classmethod
    def register_provider(cls, name: str, class_path: str, config_class=None):
        """
        Register a new provider.

        Args:
            name (str): Provider name
            class_path (str): Full path to LLM class
            config_class: Configuration class for the provider (defaults to BaseLlmConfig)
        """
        if config_class is None:
            config_class = BaseLlmConfig
        cls.provider_to_class[name] = (class_path, config_class)

    @classmethod
    def get_supported_providers(cls) -> list:
        """
        Get list of supported providers.

        Returns:
            list: List of supported provider names
        """
        return list(cls.provider_to_class.keys())


class EmbedderFactory:
    """Factory for creating embedder instances from a provider name and config dict."""

    # Provider name -> dotted path of the embedder class.
    provider_to_class = {
        "openai": "mem0.embeddings.openai.OpenAIEmbedding",
        "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
        "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
        "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
        "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
        "together": "mem0.embeddings.together.TogetherEmbedding",
        "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
        "langchain": "mem0.embeddings.langchain.LangchainEmbedding",
        "aws_bedrock": "mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
        "fastembed": "mem0.embeddings.fastembed.FastEmbedEmbedding",
    }

    @classmethod
    def create(cls, provider_name, config, vector_config: Optional[dict]):
        """
        Create an embedder instance for the given provider.

        Args:
            provider_name: Embedder provider name (e.g., 'openai', 'ollama').
            config: Dict of keyword arguments for BaseEmbedderConfig.
            vector_config: Vector-store config; when the store is 'upstash_vector'
                and it has embeddings enabled, the store embeds server-side and a
                MockEmbeddings stand-in is returned instead of a real embedder.
                NOTE(review): annotated as Optional[dict] but accessed via attribute
                (`.enable_embeddings`) — looks like it is actually a config object;
                verify against callers.

        Returns:
            Configured embedder instance (or MockEmbeddings, see above).

        Raises:
            ValueError: If the provider is not supported.
        """
        if provider_name == "upstash_vector" and vector_config and vector_config.enable_embeddings:
            return MockEmbeddings()
        class_type = cls.provider_to_class.get(provider_name)
        if class_type:
            embedder_instance = load_class(class_type)
            base_config = BaseEmbedderConfig(**config)
            return embedder_instance(base_config)
        else:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")


class VectorStoreFactory:
    """Factory for creating vector-store instances from a provider name and config."""

    # Provider name -> dotted path of the vector-store class.
    provider_to_class = {
        "qdrant": "mem0.vector_stores.qdrant.Qdrant",
        "chroma": "mem0.vector_stores.chroma.ChromaDB",
        "pgvector": "mem0.vector_stores.pgvector.PGVector",
        "milvus": "mem0.vector_stores.milvus.MilvusDB",
        "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
        "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
        "azure_mysql": "mem0.vector_stores.azure_mysql.AzureMySQL",
        "pinecone": "mem0.vector_stores.pinecone.PineconeDB",
        "mongodb": "mem0.vector_stores.mongodb.MongoDB",
        "redis": "mem0.vector_stores.redis.RedisDB",
        "valkey": "mem0.vector_stores.valkey.ValkeyDB",
        "databricks": "mem0.vector_stores.databricks.Databricks",
        "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
        "vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
        "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
        "supabase": "mem0.vector_stores.supabase.Supabase",
        "weaviate": "mem0.vector_stores.weaviate.Weaviate",
        "faiss": "mem0.vector_stores.faiss.FAISS",
        "langchain": "mem0.vector_stores.langchain.Langchain",
        "s3_vectors": "mem0.vector_stores.s3_vectors.S3Vectors",
        "baidu": "mem0.vector_stores.baidu.BaiduDB",
        "cassandra": "mem0.vector_stores.cassandra.CassandraDB",
        "neptune": "mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Create a vector-store instance for the given provider.

        Args:
            provider_name: Vector-store provider name (e.g., 'qdrant', 'chroma').
            config: Dict of constructor kwargs, or a pydantic-style object
                exposing ``model_dump()``.

        Returns:
            Configured vector-store instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if class_type:
            if not isinstance(config, dict):
                # Accept pydantic-style config objects by dumping to a dict.
                config = config.model_dump()
            vector_store_instance = load_class(class_type)
            return vector_store_instance(**config)
        else:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")

    @classmethod
    def reset(cls, instance):
        """Reset the given vector-store instance in place and return it."""
        instance.reset()
        return instance


class GraphStoreFactory:
    """
    Factory for creating MemoryGraph instances for different graph store providers.
    Usage: GraphStoreFactory.create(provider_name, config)
    """

    # Provider name -> dotted path of the MemoryGraph implementation.
    # Unknown providers fall back to the "default" entry.
    provider_to_class = {
        "memgraph": "mem0.memory.memgraph_memory.MemoryGraph",
        "neptune": "mem0.graphs.neptune.neptunegraph.MemoryGraph",
        "neptunedb": "mem0.graphs.neptune.neptunedb.MemoryGraph",
        "kuzu": "mem0.memory.kuzu_memory.MemoryGraph",
        "default": "mem0.memory.graph_memory.MemoryGraph",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Create a MemoryGraph instance for the given provider.

        Args:
            provider_name: Graph-store provider name; unrecognized names use
                the default graph memory implementation.
            config: Configuration passed through to the MemoryGraph constructor.

        Returns:
            Configured MemoryGraph instance.

        Raises:
            ImportError: If the provider's MemoryGraph class cannot be imported.
        """
        class_type = cls.provider_to_class.get(provider_name, cls.provider_to_class["default"])
        try:
            GraphClass = load_class(class_type)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import MemoryGraph for provider '{provider_name}': {e}")
        return GraphClass(config)


class RerankerFactory:
    """
    Factory for creating reranker instances with appropriate configurations.
    Supports provider-specific configs following the same pattern as other factories.
    """

    # Provider name -> (dotted path of the reranker class, config class used to build it).
    provider_to_class = {
        "cohere": ("mem0.reranker.cohere_reranker.CohereReranker", CohereRerankerConfig),
        "sentence_transformer": (
            "mem0.reranker.sentence_transformer_reranker.SentenceTransformerReranker",
            SentenceTransformerRerankerConfig,
        ),
        "zero_entropy": ("mem0.reranker.zero_entropy_reranker.ZeroEntropyReranker", ZeroEntropyRerankerConfig),
        "llm_reranker": ("mem0.reranker.llm_reranker.LLMReranker", LLMRerankerConfig),
        "huggingface": ("mem0.reranker.huggingface_reranker.HuggingFaceReranker", HuggingFaceRerankerConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseRerankerConfig, Dict]] = None, **kwargs):
        """
        Create a reranker instance based on the provider and configuration.

        Args:
            provider_name: The reranker provider (e.g., 'cohere', 'sentence_transformer')
            config: Configuration object or dictionary
            **kwargs: Additional configuration parameters

        Returns:
            Reranker instance configured for the specified provider

        Raises:
            ImportError: If the provider class cannot be imported
            ValueError: If the provider is not supported
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported reranker provider: {provider_name}")

        class_path, config_class = cls.provider_to_class[provider_name]

        # Handle configuration
        if config is None:
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge without mutating the caller's dict; kwargs take precedence
            # (same semantics as LlmFactory — avoids a TypeError on duplicate keys).
            config = config_class(**{**config, **kwargs})
        elif not isinstance(config, BaseRerankerConfig):
            raise ValueError(f"Config must be a {config_class.__name__} instance or dict")

        # Import and create the reranker class
        try:
            reranker_class = load_class(class_path)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import reranker for provider '{provider_name}': {e}")

        return reranker_class(config)