first commit
This commit is contained in:
0
configs/llms/__init__.py
Normal file
0
configs/llms/__init__.py
Normal file
56
configs/llms/anthropic.py
Normal file
56
configs/llms/anthropic.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class AnthropicConfig(BaseLlmConfig):
    """Configuration for the Anthropic LLM provider.

    Extends BaseLlmConfig with the single Anthropic-specific setting
    ``anthropic_base_url``; every other parameter is forwarded verbatim
    to the base class.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Anthropic-specific parameters
        anthropic_base_url: Optional[str] = None,
    ):
        """Build an Anthropic configuration.

        Args:
            model: Anthropic model identifier, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: Anthropic API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            anthropic_base_url: Override for the Anthropic API base URL, defaults to None.
        """
        # Provider-specific field; independent of the base initialisation.
        self.anthropic_base_url = anthropic_base_url

        # Hand the shared parameters to the common base class.
        super().__init__(
            model=model,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            top_p=top_p,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
|
||||
192
configs/llms/aws_bedrock.py
Normal file
192
configs/llms/aws_bedrock.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class AWSBedrockConfig(BaseLlmConfig):
    """
    Configuration class for AWS Bedrock LLM integration.

    Supports all available Bedrock models with automatic provider detection
    based on the ``provider.model-name`` identifier convention.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 2000,
        top_p: float = 0.9,
        top_k: int = 1,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: str = "",
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Initialize AWS Bedrock configuration.

        Args:
            model: Bedrock model identifier (e.g., "amazon.nova-3-mini-20241119-v1:0");
                defaults to "anthropic.claude-3-5-sonnet-20240620-v1:0" when None
            temperature: Controls randomness (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter (0.0 to 1.0)
            top_k: Top-k sampling parameter (1 to 40)
            aws_access_key_id: AWS access key (optional, uses env vars if not provided)
            aws_secret_access_key: AWS secret key (optional, uses env vars if not provided)
            aws_region: AWS region for Bedrock service; falls back to the
                AWS_REGION env var, then "us-west-2"
            aws_session_token: AWS session token for temporary credentials
            aws_profile: AWS profile name for credentials
            model_kwargs: Additional model-specific parameters
            **kwargs: Additional arguments passed to base class
        """
        super().__init__(
            model=model or "anthropic.claude-3-5-sonnet-20240620-v1:0",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            **kwargs,
        )

        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        # Region is resolved eagerly: explicit value > AWS_REGION env > default.
        self.aws_region = aws_region or os.getenv("AWS_REGION", "us-west-2")
        self.aws_session_token = aws_session_token
        self.aws_profile = aws_profile
        self.model_kwargs = model_kwargs or {}

    @property
    def provider(self) -> str:
        """Provider prefix of the model identifier (e.g. "anthropic"), or "unknown"."""
        if not self.model or "." not in self.model:
            return "unknown"
        return self.model.split(".")[0]

    @property
    def model_name(self) -> str:
        """Model identifier with the provider prefix stripped.

        Returns the raw ``model`` value unchanged (possibly None) when it
        does not follow the ``provider.model-name`` convention.
        """
        if not self.model or "." not in self.model:
            return self.model
        return ".".join(self.model.split(".")[1:])

    def get_model_config(self) -> Dict[str, Any]:
        """Sampling parameters merged with any custom ``model_kwargs``.

        Custom kwargs take precedence over the base sampling values.
        """
        base_config = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        base_config.update(self.model_kwargs)
        return base_config

    def get_aws_config(self) -> Dict[str, Any]:
        """
        Keyword arguments suitable for constructing a boto3 session/client.

        Explicit constructor values take precedence; otherwise the standard
        AWS environment variables are consulted. Keys with no value at all
        are omitted so boto3's default credential chain can still apply.
        """
        config: Dict[str, Any] = {"region_name": self.aws_region}

        # BUG FIX: resolve the environment fallback *before* testing for a
        # value. Previously each `os.getenv(...)` sat behind an
        # `if self.<attr>:` guard and was dead code — the env fallback could
        # never contribute.
        access_key = self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        if access_key:
            config["aws_access_key_id"] = access_key

        secret_key = self.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
        if secret_key:
            config["aws_secret_access_key"] = secret_key

        session_token = self.aws_session_token or os.getenv("AWS_SESSION_TOKEN")
        if session_token:
            config["aws_session_token"] = session_token

        profile = self.aws_profile or os.getenv("AWS_PROFILE")
        if profile:
            config["profile_name"] = profile

        return config

    def validate_model_format(self) -> bool:
        """
        Validate that the model identifier follows Bedrock naming convention.

        Returns:
            True if the model is a non-empty ``provider.model-name`` string
            with a recognised provider prefix, False otherwise.
        """
        if not self.model or "." not in self.model:
            return False

        provider, model_name = self.model.split(".", 1)

        valid_providers = [
            "ai21", "amazon", "anthropic", "cohere", "meta", "mistral",
            "stability", "writer", "deepseek", "gpt-oss", "perplexity",
            "snowflake", "titan", "command", "j2", "llama"
        ]

        return provider in valid_providers and bool(model_name)

    def get_supported_regions(self) -> List[str]:
        """Get list of AWS regions that support Bedrock."""
        return [
            "us-east-1",
            "us-west-2",
            "us-east-2",
            "eu-west-1",
            "ap-southeast-1",
            "ap-northeast-1",
        ]

    def get_model_capabilities(self) -> Dict[str, Any]:
        """Capability flags for the configured model's provider.

        Unknown providers keep the all-False defaults.
        """
        capabilities = {
            "supports_tools": False,
            "supports_vision": False,
            "supports_streaming": False,
            "supports_multimodal": False,
        }

        # Per-provider overrides; "anthropic" and "amazon" share the full
        # feature set (the original had two identical branches for them).
        overrides: Dict[str, Dict[str, bool]] = {
            "anthropic": {
                "supports_tools": True,
                "supports_vision": True,
                "supports_streaming": True,
                "supports_multimodal": True,
            },
            "cohere": {"supports_tools": True, "supports_streaming": True},
            "meta": {"supports_vision": True, "supports_streaming": True},
            "mistral": {"supports_vision": True, "supports_streaming": True},
        }
        overrides["amazon"] = overrides["anthropic"]

        capabilities.update(overrides.get(self.provider, {}))
        return capabilities
|
||||
57
configs/llms/azure.py
Normal file
57
configs/llms/azure.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from mem0.configs.base import AzureConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class AzureOpenAIConfig(BaseLlmConfig):
    """Configuration for Azure OpenAI deployments.

    Extends BaseLlmConfig with an ``azure_kwargs`` bundle that is parsed
    into an AzureConfig instance; all other parameters are forwarded to
    the base class unchanged.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Azure OpenAI-specific parameters
        azure_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Build an Azure OpenAI configuration.

        Args:
            model: Azure OpenAI model to use, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: Azure OpenAI API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            azure_kwargs: Mapping of Azure-specific options passed to
                AzureConfig, defaults to None (treated as empty).
        """
        # Shared parameters go to the common base class first.
        super().__init__(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )

        # Normalise None to an empty mapping, then validate via AzureConfig.
        azure_settings = azure_kwargs or {}
        self.azure_kwargs = AzureConfig(**azure_settings)
|
||||
62
configs/llms/base.py
Normal file
62
configs/llms/base.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from abc import ABC
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class BaseLlmConfig(ABC):
    """Common configuration shared by every LLM provider.

    Holds only the parameters that apply across all providers;
    provider-specific options belong in the dedicated subclass configs.
    """

    def __init__(
        self,
        model: Optional[Union[str, Dict]] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[Union[Dict, str]] = None,
    ):
        """Store the common LLM settings.

        Args:
            model: Model identifier (e.g. "gpt-4.1-nano-2025-04-14",
                "claude-3-5-sonnet-20240620"); None means a
                provider-specific config is expected to choose it.
            temperature: Output randomness, range 0.0 to 2.0; lower values
                are more deterministic. Defaults to 0.1.
            api_key: Provider API key; None defers to environment
                variables. Defaults to None.
            max_tokens: Cap on generated tokens, range 1 to 4096 (varies
                by model). Defaults to 2000.
            top_p: Nucleus sampling cutoff, range 0.0 to 1.0. Defaults to 0.1.
            top_k: Top-k sampling cutoff, range 1 to 40. Defaults to 1.
            enable_vision: Enable vision features on capable models.
                Defaults to False.
            vision_details: Vision detail level: "low", "high", or "auto".
                Defaults to "auto".
            http_client_proxies: Proxy settings (dict or string) for the
                underlying HTTP client; None skips client creation.
        """
        self.model = model
        self.temperature = temperature
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.enable_vision = enable_vision
        self.vision_details = vision_details

        # Only construct an httpx client when proxy settings were supplied.
        # NOTE(review): httpx deprecated and later removed the `proxies`
        # argument (replaced by `proxy`/`mounts`) — confirm the pinned
        # httpx version still accepts it.
        if http_client_proxies:
            self.http_client = httpx.Client(proxies=http_client_proxies)
        else:
            self.http_client = None
|
||||
56
configs/llms/deepseek.py
Normal file
56
configs/llms/deepseek.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class DeepSeekConfig(BaseLlmConfig):
    """Configuration for the DeepSeek LLM provider.

    Adds the DeepSeek-specific ``deepseek_base_url`` on top of the common
    BaseLlmConfig parameters.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # DeepSeek-specific parameters
        deepseek_base_url: Optional[str] = None,
    ):
        """Build a DeepSeek configuration.

        Args:
            model: DeepSeek model to use, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: DeepSeek API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            deepseek_base_url: Override for the DeepSeek API base URL, defaults to None.
        """
        # Collect the shared parameters once, then delegate to the base class.
        base_kwargs = dict(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
        super().__init__(**base_kwargs)

        # DeepSeek-only setting.
        self.deepseek_base_url = deepseek_base_url
|
||||
59
configs/llms/lmstudio.py
Normal file
59
configs/llms/lmstudio.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class LMStudioConfig(BaseLlmConfig):
    """Configuration for a local LM Studio server.

    Adds the LM Studio endpoint and response-format settings on top of the
    common BaseLlmConfig parameters.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # LM Studio-specific parameters
        lmstudio_base_url: Optional[str] = None,
        lmstudio_response_format: Optional[Dict[str, Any]] = None,
    ):
        """Build an LM Studio configuration.

        Args:
            model: LM Studio model to use, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: LM Studio API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            lmstudio_base_url: LM Studio server URL; None selects the local
                default "http://localhost:1234/v1".
            lmstudio_response_format: Response-format mapping forwarded to
                LM Studio, defaults to None.
        """
        super().__init__(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )

        # Fall back to the conventional local LM Studio endpoint.
        if lmstudio_base_url:
            self.lmstudio_base_url = lmstudio_base_url
        else:
            self.lmstudio_base_url = "http://localhost:1234/v1"
        self.lmstudio_response_format = lmstudio_response_format
|
||||
56
configs/llms/ollama.py
Normal file
56
configs/llms/ollama.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class OllamaConfig(BaseLlmConfig):
    """Configuration for a local Ollama server.

    Adds the Ollama-specific ``ollama_base_url`` on top of the common
    BaseLlmConfig parameters.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Ollama-specific parameters
        ollama_base_url: Optional[str] = None,
    ):
        """Build an Ollama configuration.

        Args:
            model: Ollama model to use, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: Ollama API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            ollama_base_url: Ollama server URL, defaults to None.
        """
        # Shared settings are packed once and delegated to the base class.
        shared = {
            "model": model,
            "temperature": temperature,
            "api_key": api_key,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "enable_vision": enable_vision,
            "vision_details": vision_details,
            "http_client_proxies": http_client_proxies,
        }
        super().__init__(**shared)

        # Ollama-only setting.
        self.ollama_base_url = ollama_base_url
|
||||
79
configs/llms/openai.py
Normal file
79
configs/llms/openai.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class OpenAIConfig(BaseLlmConfig):
    """
    Configuration class for OpenAI and OpenRouter-specific parameters.
    Inherits from BaseLlmConfig and adds OpenAI-specific settings.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # OpenAI-specific parameters
        openai_base_url: Optional[str] = None,
        models: Optional[List[str]] = None,
        route: Optional[str] = "fallback",
        openrouter_base_url: Optional[str] = None,
        site_url: Optional[str] = None,
        app_name: Optional[str] = None,
        store: bool = False,
        # Response monitoring callback
        response_callback: Optional[Callable[[Any, dict, dict], None]] = None,
    ):
        """
        Initialize OpenAI configuration.

        Args:
            model: OpenAI model to use, defaults to None
            temperature: Controls randomness, defaults to 0.1
            api_key: OpenAI API key, defaults to None
            max_tokens: Maximum tokens to generate, defaults to 2000
            top_p: Nucleus sampling parameter, defaults to 0.1
            top_k: Top-k sampling parameter, defaults to 1
            enable_vision: Enable vision capabilities, defaults to False
            vision_details: Vision detail level, defaults to "auto"
            http_client_proxies: HTTP client proxy settings, defaults to None
            openai_base_url: OpenAI API base URL, defaults to None
            models: List of models for OpenRouter, defaults to None
            route: OpenRouter route strategy, defaults to "fallback"
            openrouter_base_url: OpenRouter base URL, defaults to None
            site_url: Site URL for OpenRouter, defaults to None
            app_name: Application name for OpenRouter, defaults to None
            store: Whether responses should be stored, defaults to False
                (presumably forwarded as the OpenAI ``store`` flag — confirm
                against the consuming LLM client)
            response_callback: Optional callback for monitoring LLM responses;
                called with three arguments ``(response, dict, dict)`` whose
                exact contents are defined by the caller. Defaults to None.
        """
        # Initialize base parameters shared by all providers.
        super().__init__(
            model=model,
            temperature=temperature,
            api_key=api_key,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )

        # OpenAI-specific parameters.
        self.openai_base_url = openai_base_url
        # OpenRouter routing settings (models/route/base URL/attribution).
        self.models = models
        self.route = route
        self.openrouter_base_url = openrouter_base_url
        self.site_url = site_url
        self.app_name = app_name
        self.store = store

        # Response monitoring hook.
        self.response_callback = response_callback
|
||||
56
configs/llms/vllm.py
Normal file
56
configs/llms/vllm.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class VllmConfig(BaseLlmConfig):
    """Configuration for a vLLM inference server.

    Adds the vLLM-specific ``vllm_base_url`` on top of the common
    BaseLlmConfig parameters.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # vLLM-specific parameters
        vllm_base_url: Optional[str] = None,
    ):
        """Build a vLLM configuration.

        Args:
            model: vLLM model to use, defaults to None.
            temperature: Sampling randomness control, defaults to 0.1.
            api_key: vLLM API key, defaults to None.
            max_tokens: Upper bound on generated tokens, defaults to 2000.
            top_p: Nucleus sampling cutoff, defaults to 0.1.
            top_k: Top-k sampling cutoff, defaults to 1.
            enable_vision: Whether vision support is enabled, defaults to False.
            vision_details: Vision detail level, defaults to "auto".
            http_client_proxies: Proxy settings for the HTTP client, defaults to None.
            vllm_base_url: vLLM server URL; None selects the local default
                "http://localhost:8000/v1".
        """
        # vLLM-only setting, resolved before delegating to the base class;
        # falls back to the conventional local vLLM endpoint.
        self.vllm_base_url = vllm_base_url or "http://localhost:8000/v1"

        super().__init__(
            model=model,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            top_k=top_k,
            top_p=top_p,
            enable_vision=enable_vision,
            vision_details=vision_details,
            http_client_proxies=http_client_proxies,
        )
|
||||
Reference in New Issue
Block a user