Files
mem0/utils/factory.py
2026-03-06 21:11:10 +08:00

284 lines
12 KiB
Python

import importlib
from typing import Dict, Optional, Union
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.anthropic import AnthropicConfig
from mem0.configs.llms.azure import AzureOpenAIConfig
from mem0.configs.llms.base import BaseLlmConfig
from mem0.configs.llms.deepseek import DeepSeekConfig
from mem0.configs.llms.lmstudio import LMStudioConfig
from mem0.configs.llms.ollama import OllamaConfig
from mem0.configs.llms.openai import OpenAIConfig
from mem0.configs.llms.vllm import VllmConfig
from mem0.configs.rerankers.base import BaseRerankerConfig
from mem0.configs.rerankers.cohere import CohereRerankerConfig
from mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig
from mem0.configs.rerankers.zero_entropy import ZeroEntropyRerankerConfig
from mem0.configs.rerankers.llm import LLMRerankerConfig
from mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
from mem0.embeddings.mock import MockEmbeddings
def load_class(class_type):
    """
    Import and return the object named by a fully qualified dotted path.

    Args:
        class_type (str): Path of the form ``"package.module.ClassName"``. The part
            before the last dot is imported as a module and the final segment is
            looked up on it.

    Returns:
        The attribute (typically a class) found on the imported module.

    Raises:
        ValueError: If ``class_type`` contains no dot.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, sep, class_name = class_type.rpartition(".")
    if not sep:
        # Without this check, tuple unpacking would fail with an unhelpful message.
        raise ValueError(f"Invalid class path {class_type!r}: expected 'module.ClassName'")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
class LlmFactory:
    """
    Factory for creating LLM instances with appropriate configurations.

    Supports both old-style BaseLlmConfig and new provider-specific configs.
    ``provider_to_class`` maps each provider name to a ``(dotted class path,
    config class)`` pair. The LLM class is imported via ``load_class`` only when
    that provider is requested.
    """

    # Provider mappings with their config classes
    provider_to_class = {
        "ollama": ("mem0.llms.ollama.OllamaLLM", OllamaConfig),
        "openai": ("mem0.llms.openai.OpenAILLM", OpenAIConfig),
        "groq": ("mem0.llms.groq.GroqLLM", BaseLlmConfig),
        "together": ("mem0.llms.together.TogetherLLM", BaseLlmConfig),
        "aws_bedrock": ("mem0.llms.aws_bedrock.AWSBedrockLLM", BaseLlmConfig),
        "litellm": ("mem0.llms.litellm.LiteLLM", BaseLlmConfig),
        "azure_openai": ("mem0.llms.azure_openai.AzureOpenAILLM", AzureOpenAIConfig),
        "openai_structured": ("mem0.llms.openai_structured.OpenAIStructuredLLM", OpenAIConfig),
        "anthropic": ("mem0.llms.anthropic.AnthropicLLM", AnthropicConfig),
        "azure_openai_structured": ("mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM", AzureOpenAIConfig),
        "gemini": ("mem0.llms.gemini.GeminiLLM", BaseLlmConfig),
        "deepseek": ("mem0.llms.deepseek.DeepSeekLLM", DeepSeekConfig),
        "xai": ("mem0.llms.xai.XAILLM", BaseLlmConfig),
        "sarvam": ("mem0.llms.sarvam.SarvamLLM", BaseLlmConfig),
        "lmstudio": ("mem0.llms.lmstudio.LMStudioLLM", LMStudioConfig),
        "vllm": ("mem0.llms.vllm.VllmLLM", VllmConfig),
        "langchain": ("mem0.llms.langchain.LangchainLLM", BaseLlmConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseLlmConfig, Dict]] = None, **kwargs):
        """
        Create an LLM instance with the appropriate configuration.

        Args:
            provider_name (str): The provider name (e.g., 'openai', 'anthropic').
            config: Configuration object or dict.
                - ``None``: a default config is built from ``kwargs``.
                - dict: the caller's dict is not mutated, and ``kwargs`` override
                  its entries.
                - ``BaseLlmConfig``: used as-is if it is already an instance of the
                  provider's config class and no ``kwargs`` are given. Otherwise it
                  is converted to the provider's config class from the shared base
                  fields.
            **kwargs: Additional configuration parameters.

        Returns:
            Configured LLM instance.

        Raises:
            ValueError: If provider is not supported.
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        class_type, config_class = cls.provider_to_class[provider_name]
        llm_class = load_class(class_type)

        if config is None:
            # Create default config with kwargs
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge into a fresh dict so the caller's config is left untouched;
            # kwargs take precedence over keys already present in the dict.
            config = config_class(**{**config, **kwargs})
        elif isinstance(config, BaseLlmConfig):
            needs_conversion = config_class is not BaseLlmConfig and (
                kwargs or not isinstance(config, config_class)
            )
            if needs_conversion:
                # Rebuild as the provider-specific config from the shared base fields.
                # Attributes that exist only on provider-specific configs are not carried over.
                config_dict = {
                    "model": config.model,
                    "temperature": config.temperature,
                    "api_key": config.api_key,
                    "max_tokens": config.max_tokens,
                    "top_p": config.top_p,
                    "top_k": config.top_k,
                    "enable_vision": config.enable_vision,
                    "vision_details": config.vision_details,
                    # NOTE(review): the base config's ``http_client`` attribute is passed
                    # under the ``http_client_proxies`` key; confirm the provider config
                    # classes expect that value there.
                    "http_client_proxies": config.http_client,
                }
                config_dict.update(kwargs)
                config = config_class(**config_dict)
            # Otherwise the config is used as-is. Either the provider uses BaseLlmConfig
            # directly, or the caller already passed the provider's own config type;
            # rebuilding that would drop its provider-specific settings.
        # Any other object is assumed to already be the correct config type.
        return llm_class(config)

    @classmethod
    def register_provider(cls, name: str, class_path: str, config_class=None):
        """
        Register a new provider, or replace an existing one with the same name.

        Args:
            name (str): Provider name.
            class_path (str): Full dotted path to the LLM class.
            config_class: Configuration class for the provider (defaults to BaseLlmConfig).
        """
        if config_class is None:
            config_class = BaseLlmConfig
        cls.provider_to_class[name] = (class_path, config_class)

    @classmethod
    def get_supported_providers(cls) -> list:
        """
        Get list of supported providers.

        Returns:
            list: List of supported provider names.
        """
        return list(cls.provider_to_class.keys())
class EmbedderFactory:
    """Factory that instantiates embedding models by provider name."""

    # provider name -> dotted path of the embedding class (imported lazily via load_class)
    provider_to_class = {
        "openai": "mem0.embeddings.openai.OpenAIEmbedding",
        "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
        "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
        "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
        "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
        "together": "mem0.embeddings.together.TogetherEmbedding",
        "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
        "langchain": "mem0.embeddings.langchain.LangchainEmbedding",
        "aws_bedrock": "mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
        "fastembed": "mem0.embeddings.fastembed.FastEmbedEmbedding",
    }

    @classmethod
    def create(cls, provider_name, config, vector_config: Optional[dict] = None):
        """
        Create an embedder for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``. The special value
                ``"upstash_vector"`` returns ``MockEmbeddings`` when the vector config
                enables embeddings.
            config: Keyword arguments for ``BaseEmbedderConfig``; ``None`` means defaults.
            vector_config: Vector store configuration, as a dict with an
                ``"enable_embeddings"`` key or an object with an ``enable_embeddings``
                attribute. Only consulted when ``provider_name == "upstash_vector"``.

        Returns:
            An embedder instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        if provider_name == "upstash_vector" and vector_config:
            # Accept both dicts (matching the type hint) and config objects.
            if isinstance(vector_config, dict):
                enable_embeddings = vector_config.get("enable_embeddings")
            else:
                enable_embeddings = vector_config.enable_embeddings
            if enable_embeddings:
                return MockEmbeddings()

        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        embedder_class = load_class(class_type)
        base_config = BaseEmbedderConfig(**(config if config is not None else {}))
        return embedder_class(base_config)
class VectorStoreFactory:
    """Factory that builds vector store clients from a provider name and its config."""

    # provider name -> dotted path of the implementing class (imported lazily)
    provider_to_class = {
        "qdrant": "mem0.vector_stores.qdrant.Qdrant",
        "chroma": "mem0.vector_stores.chroma.ChromaDB",
        "pgvector": "mem0.vector_stores.pgvector.PGVector",
        "milvus": "mem0.vector_stores.milvus.MilvusDB",
        "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
        "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
        "azure_mysql": "mem0.vector_stores.azure_mysql.AzureMySQL",
        "pinecone": "mem0.vector_stores.pinecone.PineconeDB",
        "mongodb": "mem0.vector_stores.mongodb.MongoDB",
        "redis": "mem0.vector_stores.redis.RedisDB",
        "valkey": "mem0.vector_stores.valkey.ValkeyDB",
        "databricks": "mem0.vector_stores.databricks.Databricks",
        "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
        "vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
        "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
        "supabase": "mem0.vector_stores.supabase.Supabase",
        "weaviate": "mem0.vector_stores.weaviate.Weaviate",
        "faiss": "mem0.vector_stores.faiss.FAISS",
        "langchain": "mem0.vector_stores.langchain.Langchain",
        "s3_vectors": "mem0.vector_stores.s3_vectors.S3Vectors",
        "baidu": "mem0.vector_stores.baidu.BaiduDB",
        "cassandra": "mem0.vector_stores.cassandra.CassandraDB",
        "neptune": "mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Instantiate the vector store registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``.
            config: Either a dict of constructor keyword arguments or an object
                exposing ``model_dump()`` (presumably a pydantic model; verify
                against callers), whose dumped dict is used instead.

        Returns:
            The constructed vector store instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        target_path = cls.provider_to_class.get(provider_name)
        if not target_path:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")

        store_kwargs = config if isinstance(config, dict) else config.model_dump()
        store_cls = load_class(target_path)
        return store_cls(**store_kwargs)

    @classmethod
    def reset(cls, instance):
        """Call ``instance.reset()`` and hand the same instance back to the caller."""
        instance.reset()
        return instance
class GraphStoreFactory:
    """
    Factory for creating MemoryGraph instances for different graph store providers.

    Usage: GraphStoreFactory.create(provider_name, config)
    """

    # provider name -> dotted path of the MemoryGraph implementation
    provider_to_class = {
        "memgraph": "mem0.memory.memgraph_memory.MemoryGraph",
        "neptune": "mem0.graphs.neptune.neptunegraph.MemoryGraph",
        "neptunedb": "mem0.graphs.neptune.neptunedb.MemoryGraph",
        "kuzu": "mem0.memory.kuzu_memory.MemoryGraph",
        "default": "mem0.memory.graph_memory.MemoryGraph",
    }

    @classmethod
    def create(cls, provider_name, config):
        """
        Instantiate the MemoryGraph class registered for ``provider_name``.

        Unknown provider names fall back to the ``"default"`` entry rather than raising.

        Args:
            provider_name: Key into ``provider_to_class``.
            config: Passed unchanged to the MemoryGraph constructor.

        Returns:
            The constructed MemoryGraph instance.

        Raises:
            ImportError: If the provider's module or class cannot be loaded.
        """
        class_type = cls.provider_to_class.get(provider_name, cls.provider_to_class["default"])
        try:
            graph_class = load_class(class_type)
        except (ImportError, AttributeError) as e:
            # Chain explicitly so the original import failure is reported as the cause.
            raise ImportError(f"Could not import MemoryGraph for provider '{provider_name}': {e}") from e
        return graph_class(config)
class RerankerFactory:
    """
    Factory for creating reranker instances with appropriate configurations.

    Supports provider-specific configs following the same pattern as other factories.
    """

    # Provider mappings with their config classes
    provider_to_class = {
        "cohere": ("mem0.reranker.cohere_reranker.CohereReranker", CohereRerankerConfig),
        "sentence_transformer": ("mem0.reranker.sentence_transformer_reranker.SentenceTransformerReranker", SentenceTransformerRerankerConfig),
        "zero_entropy": ("mem0.reranker.zero_entropy_reranker.ZeroEntropyReranker", ZeroEntropyRerankerConfig),
        "llm_reranker": ("mem0.reranker.llm_reranker.LLMReranker", LLMRerankerConfig),
        "huggingface": ("mem0.reranker.huggingface_reranker.HuggingFaceReranker", HuggingFaceRerankerConfig),
    }

    @classmethod
    def create(cls, provider_name: str, config: Optional[Union[BaseRerankerConfig, Dict]] = None, **kwargs):
        """
        Create a reranker instance based on the provider and configuration.

        Args:
            provider_name: The reranker provider (e.g., 'cohere', 'sentence_transformer').
            config: Configuration object or dictionary.
                - ``None``: a default config is built from ``kwargs``.
                - dict: merged with ``kwargs``, where ``kwargs`` win on duplicate keys.
                - ``BaseRerankerConfig``: passed through unchanged, and ``kwargs``
                  are not applied.
            **kwargs: Additional configuration parameters.

        Returns:
            Reranker instance configured for the specified provider.

        Raises:
            ImportError: If the provider class cannot be imported.
            ValueError: If the provider is not supported, or ``config`` is neither
                a dict nor a ``BaseRerankerConfig``.
        """
        if provider_name not in cls.provider_to_class:
            raise ValueError(f"Unsupported reranker provider: {provider_name}")

        class_path, config_class = cls.provider_to_class[provider_name]

        if config is None:
            config = config_class(**kwargs)
        elif isinstance(config, dict):
            # Merge into a new dict: a key present in both config and kwargs previously
            # raised "got multiple values for keyword argument"; kwargs now override.
            config = config_class(**{**config, **kwargs})
        elif not isinstance(config, BaseRerankerConfig):
            raise ValueError(f"Config must be a {config_class.__name__} instance or dict")

        # Import and create the reranker class
        try:
            reranker_class = load_class(class_path)
        except (ImportError, AttributeError) as e:
            # Chain explicitly so the original import failure is reported as the cause.
            raise ImportError(f"Could not import reranker for provider '{provider_name}': {e}") from e

        return reranker_class(config)