first commit
This commit is contained in:
192
configs/llms/aws_bedrock.py
Normal file
192
configs/llms/aws_bedrock.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class AWSBedrockConfig(BaseLlmConfig):
|
||||
"""
|
||||
Configuration class for AWS Bedrock LLM integration.
|
||||
|
||||
Supports all available Bedrock models with automatic provider detection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 2000,
|
||||
top_p: float = 0.9,
|
||||
top_k: int = 1,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region: str = "",
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_profile: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize AWS Bedrock configuration.
|
||||
|
||||
Args:
|
||||
model: Bedrock model identifier (e.g., "amazon.nova-3-mini-20241119-v1:0")
|
||||
temperature: Controls randomness (0.0 to 2.0)
|
||||
max_tokens: Maximum tokens to generate
|
||||
top_p: Nucleus sampling parameter (0.0 to 1.0)
|
||||
top_k: Top-k sampling parameter (1 to 40)
|
||||
aws_access_key_id: AWS access key (optional, uses env vars if not provided)
|
||||
aws_secret_access_key: AWS secret key (optional, uses env vars if not provided)
|
||||
aws_region: AWS region for Bedrock service
|
||||
aws_session_token: AWS session token for temporary credentials
|
||||
aws_profile: AWS profile name for credentials
|
||||
model_kwargs: Additional model-specific parameters
|
||||
**kwargs: Additional arguments passed to base class
|
||||
"""
|
||||
super().__init__(
|
||||
model=model or "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.aws_access_key_id = aws_access_key_id
|
||||
self.aws_secret_access_key = aws_secret_access_key
|
||||
self.aws_region = aws_region or os.getenv("AWS_REGION", "us-west-2")
|
||||
self.aws_session_token = aws_session_token
|
||||
self.aws_profile = aws_profile
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider from the model identifier."""
|
||||
if not self.model or "." not in self.model:
|
||||
return "unknown"
|
||||
return self.model.split(".")[0]
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Get the model name without provider prefix."""
|
||||
if not self.model or "." not in self.model:
|
||||
return self.model
|
||||
return ".".join(self.model.split(".")[1:])
|
||||
|
||||
def get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration parameters."""
|
||||
base_config = {
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
# Add custom model kwargs
|
||||
base_config.update(self.model_kwargs)
|
||||
|
||||
return base_config
|
||||
|
||||
def get_aws_config(self) -> Dict[str, Any]:
|
||||
"""Get AWS configuration parameters."""
|
||||
config = {
|
||||
"region_name": self.aws_region,
|
||||
}
|
||||
|
||||
if self.aws_access_key_id:
|
||||
config["aws_access_key_id"] = self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
|
||||
if self.aws_secret_access_key:
|
||||
config["aws_secret_access_key"] = self.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
|
||||
|
||||
if self.aws_session_token:
|
||||
config["aws_session_token"] = self.aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
|
||||
if self.aws_profile:
|
||||
config["profile_name"] = self.aws_profile or os.getenv("AWS_PROFILE")
|
||||
|
||||
return config
|
||||
|
||||
def validate_model_format(self) -> bool:
|
||||
"""
|
||||
Validate that the model identifier follows Bedrock naming convention.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
if not self.model:
|
||||
return False
|
||||
|
||||
# Check if model follows provider.model-name format
|
||||
if "." not in self.model:
|
||||
return False
|
||||
|
||||
provider, model_name = self.model.split(".", 1)
|
||||
|
||||
# Validate provider
|
||||
valid_providers = [
|
||||
"ai21", "amazon", "anthropic", "cohere", "meta", "mistral",
|
||||
"stability", "writer", "deepseek", "gpt-oss", "perplexity",
|
||||
"snowflake", "titan", "command", "j2", "llama"
|
||||
]
|
||||
|
||||
if provider not in valid_providers:
|
||||
return False
|
||||
|
||||
# Validate model name is not empty
|
||||
if not model_name:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_supported_regions(self) -> List[str]:
|
||||
"""Get list of AWS regions that support Bedrock."""
|
||||
return [
|
||||
"us-east-1",
|
||||
"us-west-2",
|
||||
"us-east-2",
|
||||
"eu-west-1",
|
||||
"ap-southeast-1",
|
||||
"ap-northeast-1",
|
||||
]
|
||||
|
||||
def get_model_capabilities(self) -> Dict[str, Any]:
|
||||
"""Get model capabilities based on provider."""
|
||||
capabilities = {
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
"supports_streaming": False,
|
||||
"supports_multimodal": False,
|
||||
}
|
||||
|
||||
if self.provider == "anthropic":
|
||||
capabilities.update({
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
"supports_streaming": True,
|
||||
"supports_multimodal": True,
|
||||
})
|
||||
elif self.provider == "amazon":
|
||||
capabilities.update({
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
"supports_streaming": True,
|
||||
"supports_multimodal": True,
|
||||
})
|
||||
elif self.provider == "cohere":
|
||||
capabilities.update({
|
||||
"supports_tools": True,
|
||||
"supports_streaming": True,
|
||||
})
|
||||
elif self.provider == "meta":
|
||||
capabilities.update({
|
||||
"supports_vision": True,
|
||||
"supports_streaming": True,
|
||||
})
|
||||
elif self.provider == "mistral":
|
||||
capabilities.update({
|
||||
"supports_vision": True,
|
||||
"supports_streaming": True,
|
||||
})
|
||||
|
||||
return capabilities
|
||||
Reference in New Issue
Block a user