first commit
This commit is contained in:
665
llms/aws_bedrock.py
Normal file
665
llms/aws_bedrock.py
Normal file
@@ -0,0 +1,665 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
except ImportError:
|
||||
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.aws_bedrock import AWSBedrockConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDERS = [
|
||||
"ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer",
|
||||
"deepseek", "gpt-oss", "perplexity", "snowflake", "titan", "command", "j2", "llama"
|
||||
]
|
||||
|
||||
|
||||
def extract_provider(model: str) -> str:
|
||||
"""Extract provider from model identifier."""
|
||||
for provider in PROVIDERS:
|
||||
if re.search(rf"\b{re.escape(provider)}\b", model):
|
||||
return provider
|
||||
raise ValueError(f"Unknown provider in model: {model}")
|
||||
|
||||
|
||||
class AWSBedrockLLM(LLMBase):
|
||||
"""
|
||||
AWS Bedrock LLM integration for Mem0.
|
||||
|
||||
Supports all available Bedrock models with automatic provider detection.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Union[AWSBedrockConfig, BaseLlmConfig, Dict]] = None):
|
||||
"""
|
||||
Initialize AWS Bedrock LLM.
|
||||
|
||||
Args:
|
||||
config: AWS Bedrock configuration object
|
||||
"""
|
||||
# Convert to AWSBedrockConfig if needed
|
||||
if config is None:
|
||||
config = AWSBedrockConfig()
|
||||
elif isinstance(config, dict):
|
||||
config = AWSBedrockConfig(**config)
|
||||
elif isinstance(config, BaseLlmConfig) and not isinstance(config, AWSBedrockConfig):
|
||||
# Convert BaseLlmConfig to AWSBedrockConfig
|
||||
config = AWSBedrockConfig(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
top_p=config.top_p,
|
||||
top_k=config.top_k,
|
||||
enable_vision=getattr(config, "enable_vision", False),
|
||||
)
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize AWS client
|
||||
self._initialize_aws_client()
|
||||
|
||||
# Get model configuration
|
||||
self.model_config = self.config.get_model_config()
|
||||
self.provider = extract_provider(self.config.model)
|
||||
|
||||
# Initialize provider-specific settings
|
||||
self._initialize_provider_settings()
|
||||
|
||||
def _initialize_aws_client(self):
|
||||
"""Initialize AWS Bedrock client with proper credentials."""
|
||||
try:
|
||||
aws_config = self.config.get_aws_config()
|
||||
|
||||
# Create Bedrock runtime client
|
||||
self.client = boto3.client("bedrock-runtime", **aws_config)
|
||||
|
||||
# Test connection
|
||||
self._test_connection()
|
||||
|
||||
except NoCredentialsError:
|
||||
raise ValueError(
|
||||
"AWS credentials not found. Please set AWS_ACCESS_KEY_ID, "
|
||||
"AWS_SECRET_ACCESS_KEY, and AWS_REGION environment variables, "
|
||||
"or provide them in the config."
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "UnauthorizedOperation":
|
||||
raise ValueError(
|
||||
f"Unauthorized access to Bedrock. Please ensure your AWS credentials "
|
||||
f"have permission to access Bedrock in region {self.config.aws_region}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"AWS Bedrock error: {e}")
|
||||
|
||||
def _test_connection(self):
|
||||
"""Test connection to AWS Bedrock service."""
|
||||
try:
|
||||
# List available models to test connection
|
||||
bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
|
||||
response = bedrock_client.list_foundation_models()
|
||||
self.available_models = [model["modelId"] for model in response["modelSummaries"]]
|
||||
|
||||
# Check if our model is available
|
||||
if self.config.model not in self.available_models:
|
||||
logger.warning(f"Model {self.config.model} may not be available in region {self.config.aws_region}")
|
||||
logger.info(f"Available models: {', '.join(self.available_models[:5])}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not verify model availability: {e}")
|
||||
self.available_models = []
|
||||
|
||||
def _initialize_provider_settings(self):
|
||||
"""Initialize provider-specific settings and capabilities."""
|
||||
# Determine capabilities based on provider and model
|
||||
self.supports_tools = self.provider in ["anthropic", "cohere", "amazon"]
|
||||
self.supports_vision = self.provider in ["anthropic", "amazon", "meta", "mistral"]
|
||||
self.supports_streaming = self.provider in ["anthropic", "cohere", "mistral", "amazon", "meta"]
|
||||
|
||||
# Set message formatting method
|
||||
if self.provider == "anthropic":
|
||||
self._format_messages = self._format_messages_anthropic
|
||||
elif self.provider == "cohere":
|
||||
self._format_messages = self._format_messages_cohere
|
||||
elif self.provider == "amazon":
|
||||
self._format_messages = self._format_messages_amazon
|
||||
elif self.provider == "meta":
|
||||
self._format_messages = self._format_messages_meta
|
||||
elif self.provider == "mistral":
|
||||
self._format_messages = self._format_messages_mistral
|
||||
else:
|
||||
self._format_messages = self._format_messages_generic
|
||||
|
||||
def _format_messages_anthropic(self, messages: List[Dict[str, str]]) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""Format messages for Anthropic models."""
|
||||
formatted_messages = []
|
||||
system_message = None
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Anthropic supports system messages as a separate parameter
|
||||
# see: https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
|
||||
system_message = content
|
||||
elif role == "user":
|
||||
# Use Converse API format
|
||||
formatted_messages.append({"role": "user", "content": [{"text": content}]})
|
||||
elif role == "assistant":
|
||||
# Use Converse API format
|
||||
formatted_messages.append({"role": "assistant", "content": [{"text": content}]})
|
||||
|
||||
return formatted_messages, system_message
|
||||
|
||||
def _format_messages_cohere(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format messages for Cohere models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(formatted_messages)
|
||||
|
||||
def _format_messages_amazon(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Amazon models (including Nova)."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Amazon models support system messages
|
||||
formatted_messages.append({"role": "system", "content": content})
|
||||
elif role == "user":
|
||||
formatted_messages.append({"role": "user", "content": content})
|
||||
elif role == "assistant":
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_messages_meta(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format messages for Meta models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(formatted_messages)
|
||||
|
||||
def _format_messages_mistral(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Mistral models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Mistral supports system messages
|
||||
formatted_messages.append({"role": "system", "content": content})
|
||||
elif role == "user":
|
||||
formatted_messages.append({"role": "user", "content": content})
|
||||
elif role == "assistant":
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_messages_generic(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Generic message formatting for other providers."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"\n\n{role}: {content}")
|
||||
|
||||
return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
|
||||
|
||||
def _prepare_input(self, prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare input for the current provider's model.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt to process
|
||||
|
||||
Returns:
|
||||
Prepared input dictionary
|
||||
"""
|
||||
# Base configuration
|
||||
input_body = {"prompt": prompt}
|
||||
|
||||
# Provider-specific parameter mappings
|
||||
provider_mappings = {
|
||||
"meta": {"max_tokens": "max_gen_len"},
|
||||
"ai21": {"max_tokens": "maxTokens", "top_p": "topP"},
|
||||
"mistral": {"max_tokens": "max_tokens"},
|
||||
"cohere": {"max_tokens": "max_tokens", "top_p": "p"},
|
||||
"amazon": {"max_tokens": "maxTokenCount", "top_p": "topP"},
|
||||
"anthropic": {"max_tokens": "max_tokens", "top_p": "top_p"},
|
||||
}
|
||||
|
||||
# Apply provider mappings
|
||||
if self.provider in provider_mappings:
|
||||
for old_key, new_key in provider_mappings[self.provider].items():
|
||||
if old_key in self.model_config:
|
||||
input_body[new_key] = self.model_config[old_key]
|
||||
|
||||
# Special handling for specific providers
|
||||
if self.provider == "cohere" and "cohere.command" in self.config.model:
|
||||
input_body["message"] = input_body.pop("prompt")
|
||||
elif self.provider == "amazon":
|
||||
# Amazon Nova and other Amazon models
|
||||
if "nova" in self.config.model.lower():
|
||||
# Nova models use the converse API format
|
||||
input_body = {
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": self.model_config.get("max_tokens", 5000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"top_p": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
else:
|
||||
# Legacy Amazon models
|
||||
input_body = {
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": {
|
||||
"maxTokenCount": self.model_config.get("max_tokens", 5000),
|
||||
"topP": self.model_config.get("top_p", 0.9),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
},
|
||||
}
|
||||
# Remove None values
|
||||
input_body["textGenerationConfig"] = {
|
||||
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
|
||||
}
|
||||
elif self.provider == "anthropic":
|
||||
input_body = {
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
|
||||
"max_tokens": self.model_config.get("max_tokens", 2000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"top_p": self.model_config.get("top_p", 0.9),
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
}
|
||||
elif self.provider == "meta":
|
||||
input_body = {
|
||||
"prompt": prompt,
|
||||
"max_gen_len": self.model_config.get("max_tokens", 5000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"top_p": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
elif self.provider == "mistral":
|
||||
input_body = {
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.model_config.get("max_tokens", 5000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"top_p": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
else:
|
||||
# Generic case - add all model config parameters
|
||||
input_body.update(self.model_config)
|
||||
|
||||
return input_body
|
||||
|
||||
def _convert_tool_format(self, original_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert tools to Bedrock-compatible format.
|
||||
|
||||
Args:
|
||||
original_tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Converted tools in Bedrock format
|
||||
"""
|
||||
new_tools = []
|
||||
|
||||
for tool in original_tools:
|
||||
if tool["type"] == "function":
|
||||
function = tool["function"]
|
||||
new_tool = {
|
||||
"toolSpec": {
|
||||
"name": function["name"],
|
||||
"description": function.get("description", ""),
|
||||
"inputSchema": {
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": function["parameters"].get("required", []),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# Add properties
|
||||
for prop, details in function["parameters"].get("properties", {}).items():
|
||||
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details
|
||||
|
||||
new_tools.append(new_tool)
|
||||
|
||||
return new_tools
|
||||
|
||||
def _parse_response(
|
||||
self, response: Dict[str, Any], tools: Optional[List[Dict]] = None
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""
|
||||
Parse response from Bedrock API.
|
||||
|
||||
Args:
|
||||
response: Raw API response
|
||||
tools: List of tools if used
|
||||
|
||||
Returns:
|
||||
Parsed response
|
||||
"""
|
||||
if tools:
|
||||
# Handle tool-enabled responses
|
||||
processed_response = {"tool_calls": []}
|
||||
|
||||
if response.get("output", {}).get("message", {}).get("content"):
|
||||
for item in response["output"]["message"]["content"]:
|
||||
if "toolUse" in item:
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": item["toolUse"]["name"],
|
||||
"arguments": json.loads(extract_json(json.dumps(item["toolUse"]["input"]))),
|
||||
}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
|
||||
# Handle regular text responses
|
||||
try:
|
||||
response_body = response.get("body").read().decode()
|
||||
response_json = json.loads(response_body)
|
||||
|
||||
# Provider-specific response parsing
|
||||
if self.provider == "anthropic":
|
||||
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "amazon":
|
||||
# Handle both Nova and legacy Amazon models
|
||||
if "nova" in self.config.model.lower():
|
||||
# Nova models return content in a different format
|
||||
if "content" in response_json:
|
||||
return response_json["content"][0]["text"]
|
||||
elif "completion" in response_json:
|
||||
return response_json["completion"]
|
||||
else:
|
||||
# Legacy Amazon models
|
||||
return response_json.get("completion", "")
|
||||
elif self.provider == "meta":
|
||||
return response_json.get("generation", "")
|
||||
elif self.provider == "mistral":
|
||||
return response_json.get("outputs", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "cohere":
|
||||
return response_json.get("generations", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "ai21":
|
||||
return response_json.get("completions", [{"data", {"text": ""}}])[0].get("data", {}).get("text", "")
|
||||
else:
|
||||
# Generic parsing - try common response fields
|
||||
for field in ["content", "text", "completion", "generation"]:
|
||||
if field in response_json:
|
||||
if isinstance(response_json[field], list) and response_json[field]:
|
||||
return response_json[field][0].get("text", "")
|
||||
elif isinstance(response_json[field], str):
|
||||
return response_json[field]
|
||||
|
||||
# Fallback
|
||||
return str(response_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse response: {e}")
|
||||
return "Error parsing response"
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_format: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: str = "auto",
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""
|
||||
Generate response using AWS Bedrock.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
response_format: Response format specification
|
||||
tools: List of tools for function calling
|
||||
tool_choice: Tool choice method
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Generated response
|
||||
"""
|
||||
try:
|
||||
if tools and self.supports_tools:
|
||||
# Use converse method for tool-enabled models
|
||||
return self._generate_with_tools(messages, tools, stream)
|
||||
else:
|
||||
# Use standard invoke_model method
|
||||
return self._generate_standard(messages, stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate response: {e}")
|
||||
raise RuntimeError(f"Failed to generate response: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _convert_tools_to_converse_format(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI-style tools to Converse API format."""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
converse_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function" and "function" in tool:
|
||||
func = tool["function"]
|
||||
converse_tool = {
|
||||
"toolSpec": {
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"inputSchema": {
|
||||
"json": func.get("parameters", {})
|
||||
}
|
||||
}
|
||||
}
|
||||
converse_tools.append(converse_tool)
|
||||
|
||||
return converse_tools
|
||||
|
||||
def _generate_with_tools(self, messages: List[Dict[str, str]], tools: List[Dict], stream: bool = False) -> Dict[str, Any]:
|
||||
"""Generate response with tool calling support using correct message format."""
|
||||
# Format messages for tool-enabled models
|
||||
system_message = None
|
||||
if self.provider == "anthropic":
|
||||
formatted_messages, system_message = self._format_messages_anthropic(messages)
|
||||
elif self.provider == "amazon":
|
||||
formatted_messages = self._format_messages_amazon(messages)
|
||||
else:
|
||||
formatted_messages = [{"role": "user", "content": [{"text": messages[-1]["content"]}]}]
|
||||
|
||||
# Prepare tool configuration in Converse API format
|
||||
tool_config = None
|
||||
if tools:
|
||||
converse_tools = self._convert_tools_to_converse_format(tools)
|
||||
if converse_tools:
|
||||
tool_config = {"tools": converse_tools}
|
||||
|
||||
# Prepare converse parameters
|
||||
converse_params = {
|
||||
"modelId": self.config.model,
|
||||
"messages": formatted_messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": self.model_config.get("max_tokens", 2000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"topP": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
}
|
||||
|
||||
# Add system message if present (for Anthropic)
|
||||
if system_message:
|
||||
converse_params["system"] = [{"text": system_message}]
|
||||
|
||||
# Add tool config if present
|
||||
if tool_config:
|
||||
converse_params["toolConfig"] = tool_config
|
||||
|
||||
# Make API call
|
||||
response = self.client.converse(**converse_params)
|
||||
|
||||
return self._parse_response(response, tools)
|
||||
|
||||
def _generate_standard(self, messages: List[Dict[str, str]], stream: bool = False) -> str:
|
||||
"""Generate standard text response using Converse API for Anthropic models."""
|
||||
# For Anthropic models, always use Converse API
|
||||
if self.provider == "anthropic":
|
||||
formatted_messages, system_message = self._format_messages_anthropic(messages)
|
||||
|
||||
# Prepare converse parameters
|
||||
converse_params = {
|
||||
"modelId": self.config.model,
|
||||
"messages": formatted_messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": self.model_config.get("max_tokens", 2000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"topP": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
converse_params["system"] = [{"text": system_message}]
|
||||
|
||||
# Use converse API for Anthropic models
|
||||
response = self.client.converse(**converse_params)
|
||||
|
||||
# Parse Converse API response
|
||||
if hasattr(response, 'output') and hasattr(response.output, 'message'):
|
||||
return response.output.message.content[0].text
|
||||
elif 'output' in response and 'message' in response['output']:
|
||||
return response['output']['message']['content'][0]['text']
|
||||
else:
|
||||
return str(response)
|
||||
|
||||
elif self.provider == "amazon" and "nova" in self.config.model.lower():
|
||||
# Nova models use converse API even without tools
|
||||
formatted_messages = self._format_messages_amazon(messages)
|
||||
input_body = {
|
||||
"messages": formatted_messages,
|
||||
"max_tokens": self.model_config.get("max_tokens", 5000),
|
||||
"temperature": self.model_config.get("temperature", 0.1),
|
||||
"top_p": self.model_config.get("top_p", 0.9),
|
||||
}
|
||||
|
||||
# Use converse API for Nova models
|
||||
response = self.client.converse(
|
||||
modelId=self.config.model,
|
||||
messages=input_body["messages"],
|
||||
inferenceConfig={
|
||||
"maxTokens": input_body["max_tokens"],
|
||||
"temperature": input_body["temperature"],
|
||||
"topP": input_body["top_p"],
|
||||
}
|
||||
)
|
||||
|
||||
return self._parse_response(response)
|
||||
else:
|
||||
# For other providers and legacy Amazon models (like Titan)
|
||||
if self.provider == "amazon":
|
||||
# Legacy Amazon models need string formatting, not array formatting
|
||||
prompt = self._format_messages_generic(messages)
|
||||
else:
|
||||
prompt = self._format_messages(messages)
|
||||
input_body = self._prepare_input(prompt)
|
||||
|
||||
# Convert to JSON
|
||||
body = json.dumps(input_body)
|
||||
|
||||
# Make API call
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=self.config.model,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
return self._parse_response(response)
|
||||
|
||||
def list_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""List all available models in the current region."""
|
||||
try:
|
||||
bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
|
||||
response = bedrock_client.list_foundation_models()
|
||||
|
||||
models = []
|
||||
for model in response["modelSummaries"]:
|
||||
provider = extract_provider(model["modelId"])
|
||||
models.append(
|
||||
{
|
||||
"model_id": model["modelId"],
|
||||
"provider": provider,
|
||||
"model_name": model["modelId"].split(".", 1)[1]
|
||||
if "." in model["modelId"]
|
||||
else model["modelId"],
|
||||
"modelArn": model.get("modelArn", ""),
|
||||
"providerName": model.get("providerName", ""),
|
||||
"inputModalities": model.get("inputModalities", []),
|
||||
"outputModalities": model.get("outputModalities", []),
|
||||
"responseStreamingSupported": model.get("responseStreamingSupported", False),
|
||||
}
|
||||
)
|
||||
|
||||
return models
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not list models: {e}")
|
||||
return []
|
||||
|
||||
def get_model_capabilities(self) -> Dict[str, Any]:
|
||||
"""Get capabilities of the current model."""
|
||||
return {
|
||||
"model_id": self.config.model,
|
||||
"provider": self.provider,
|
||||
"model_name": self.config.model_name,
|
||||
"supports_tools": self.supports_tools,
|
||||
"supports_vision": self.supports_vision,
|
||||
"supports_streaming": self.supports_streaming,
|
||||
"max_tokens": self.model_config.get("max_tokens", 2000),
|
||||
}
|
||||
|
||||
def validate_model_access(self) -> bool:
|
||||
"""Validate if the model is accessible."""
|
||||
try:
|
||||
# Try to invoke the model with a minimal request
|
||||
if self.provider == "amazon" and "nova" in self.config.model.lower():
|
||||
# Test Nova model with converse API
|
||||
test_messages = [{"role": "user", "content": "test"}]
|
||||
self.client.converse(
|
||||
modelId=self.config.model,
|
||||
messages=test_messages,
|
||||
inferenceConfig={"maxTokens": 10}
|
||||
)
|
||||
else:
|
||||
# Test other models with invoke_model
|
||||
test_body = json.dumps({"prompt": "test"})
|
||||
self.client.invoke_model(
|
||||
body=test_body,
|
||||
modelId=self.config.model,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
Reference in New Issue
Block a user