first commit
This commit is contained in:
0
llms/__init__.py
Normal file
0
llms/__init__.py
Normal file
87
llms/anthropic.py
Normal file
87
llms/anthropic.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
|
||||
from mem0.configs.llms.anthropic import AnthropicConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AnthropicLLM(LLMBase):
    """LLM backend that talks to the Anthropic Messages API."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, AnthropicConfig, Dict]] = None):
        # Normalize whatever the caller handed us into an AnthropicConfig.
        if config is None:
            config = AnthropicConfig()
        elif isinstance(config, dict):
            config = AnthropicConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AnthropicConfig):
            # Re-wrap a generic base config as the Anthropic-specific one.
            config = AnthropicConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        # Fall back to a default model when none was configured.
        if not self.config.model:
            self.config.model = "claude-3-5-sonnet-20240620"

        # An explicitly configured key wins; otherwise read the environment.
        key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=key)

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using Anthropic.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional Anthropic-specific parameters.

        Returns:
            str: The generated response.
        """
        # Anthropic takes the system prompt as a dedicated parameter, so peel
        # it off here; every other message is forwarded as a chat turn.
        system_message = ""
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                conversation.append(msg)

        params = self._get_supported_params(messages=messages, **kwargs)
        params["model"] = self.config.model
        params["messages"] = conversation
        params["system"] = system_message

        # TODO: Remove tools if no issues found with new memory addition logic
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        result = self.client.messages.create(**params)
        return result.content[0].text
|
||||
665
llms/aws_bedrock.py
Normal file
665
llms/aws_bedrock.py
Normal file
@@ -0,0 +1,665 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
except ImportError:
|
||||
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.aws_bedrock import AWSBedrockConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known Bedrock provider tokens, checked in order: vendor names first, then
# model-family aliases that imply a vendor.
PROVIDERS = [
    "ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer",
    "deepseek", "gpt-oss", "perplexity", "snowflake", "titan", "command", "j2", "llama"
]


def extract_provider(model: str) -> str:
    """Return the first known provider token appearing in *model*.

    Raises:
        ValueError: if no known provider token occurs in the identifier.
    """
    found = next(
        (name for name in PROVIDERS if re.search(rf"\b{re.escape(name)}\b", model)),
        None,
    )
    if found is None:
        raise ValueError(f"Unknown provider in model: {model}")
    return found
|
||||
|
||||
|
||||
class AWSBedrockLLM(LLMBase):
    """
    AWS Bedrock LLM integration for Mem0.

    Supports all available Bedrock models with automatic provider detection.
    """

    def __init__(self, config: Optional[Union[AWSBedrockConfig, BaseLlmConfig, Dict]] = None):
        """
        Initialize AWS Bedrock LLM.

        Args:
            config: An AWSBedrockConfig, a generic BaseLlmConfig, a dict of
                AWSBedrockConfig kwargs, or None for defaults.

        Raises:
            ValueError: propagated from _initialize_aws_client when AWS
                credentials are missing or access is denied.
        """
        # Convert to AWSBedrockConfig if needed
        if config is None:
            config = AWSBedrockConfig()
        elif isinstance(config, dict):
            config = AWSBedrockConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AWSBedrockConfig):
            # Convert BaseLlmConfig to AWSBedrockConfig
            config = AWSBedrockConfig(
                model=config.model,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                # Base configs may lack the vision flag; default to off.
                enable_vision=getattr(config, "enable_vision", False),
            )

        super().__init__(config)
        self.config = config

        # Initialize AWS client (validates credentials; may raise ValueError)
        self._initialize_aws_client()

        # Get model configuration
        self.model_config = self.config.get_model_config()
        self.provider = extract_provider(self.config.model)

        # Initialize provider-specific settings
        self._initialize_provider_settings()
|
||||
|
||||
    def _initialize_aws_client(self):
        """Initialize AWS Bedrock client with proper credentials.

        Creates `self.client` (bedrock-runtime) and then probes the service.

        Raises:
            ValueError: when credentials are missing or Bedrock rejects access.
        """
        try:
            aws_config = self.config.get_aws_config()

            # Create Bedrock runtime client
            self.client = boto3.client("bedrock-runtime", **aws_config)

            # Test connection (best effort; only logs on failure)
            self._test_connection()

        except NoCredentialsError:
            raise ValueError(
                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID, "
                "AWS_SECRET_ACCESS_KEY, and AWS_REGION environment variables, "
                "or provide them in the config."
            )
        except ClientError as e:
            # Surface permission problems with an actionable message; wrap
            # everything else in a generic ValueError.
            if e.response["Error"]["Code"] == "UnauthorizedOperation":
                raise ValueError(
                    f"Unauthorized access to Bedrock. Please ensure your AWS credentials "
                    f"have permission to access Bedrock in region {self.config.aws_region}."
                )
            else:
                raise ValueError(f"AWS Bedrock error: {e}")
|
||||
|
||||
    def _test_connection(self):
        """Best-effort check that Bedrock is reachable and the model exists.

        Populates `self.available_models` with the region's model IDs; never
        raises — any failure is logged and leaves the list empty.
        """
        try:
            # List available models to test connection
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()
            self.available_models = [model["modelId"] for model in response["modelSummaries"]]

            # Check if our model is available; warn (don't fail) if not,
            # since availability listings can lag behind actual access.
            if self.config.model not in self.available_models:
                logger.warning(f"Model {self.config.model} may not be available in region {self.config.aws_region}")
                logger.info(f"Available models: {', '.join(self.available_models[:5])}...")

        except Exception as e:
            logger.warning(f"Could not verify model availability: {e}")
            self.available_models = []
|
||||
|
||||
def _initialize_provider_settings(self):
|
||||
"""Initialize provider-specific settings and capabilities."""
|
||||
# Determine capabilities based on provider and model
|
||||
self.supports_tools = self.provider in ["anthropic", "cohere", "amazon"]
|
||||
self.supports_vision = self.provider in ["anthropic", "amazon", "meta", "mistral"]
|
||||
self.supports_streaming = self.provider in ["anthropic", "cohere", "mistral", "amazon", "meta"]
|
||||
|
||||
# Set message formatting method
|
||||
if self.provider == "anthropic":
|
||||
self._format_messages = self._format_messages_anthropic
|
||||
elif self.provider == "cohere":
|
||||
self._format_messages = self._format_messages_cohere
|
||||
elif self.provider == "amazon":
|
||||
self._format_messages = self._format_messages_amazon
|
||||
elif self.provider == "meta":
|
||||
self._format_messages = self._format_messages_meta
|
||||
elif self.provider == "mistral":
|
||||
self._format_messages = self._format_messages_mistral
|
||||
else:
|
||||
self._format_messages = self._format_messages_generic
|
||||
|
||||
def _format_messages_anthropic(self, messages: List[Dict[str, str]]) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""Format messages for Anthropic models."""
|
||||
formatted_messages = []
|
||||
system_message = None
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Anthropic supports system messages as a separate parameter
|
||||
# see: https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
|
||||
system_message = content
|
||||
elif role == "user":
|
||||
# Use Converse API format
|
||||
formatted_messages.append({"role": "user", "content": [{"text": content}]})
|
||||
elif role == "assistant":
|
||||
# Use Converse API format
|
||||
formatted_messages.append({"role": "assistant", "content": [{"text": content}]})
|
||||
|
||||
return formatted_messages, system_message
|
||||
|
||||
def _format_messages_cohere(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format messages for Cohere models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(formatted_messages)
|
||||
|
||||
def _format_messages_amazon(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Amazon models (including Nova)."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Amazon models support system messages
|
||||
formatted_messages.append({"role": "system", "content": content})
|
||||
elif role == "user":
|
||||
formatted_messages.append({"role": "user", "content": content})
|
||||
elif role == "assistant":
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_messages_meta(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format messages for Meta models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"{role}: {content}")
|
||||
|
||||
return "\n".join(formatted_messages)
|
||||
|
||||
def _format_messages_mistral(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Mistral models."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
# Mistral supports system messages
|
||||
formatted_messages.append({"role": "system", "content": content})
|
||||
elif role == "user":
|
||||
formatted_messages.append({"role": "user", "content": content})
|
||||
elif role == "assistant":
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_messages_generic(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Generic message formatting for other providers."""
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
role = message["role"].capitalize()
|
||||
content = message["content"]
|
||||
formatted_messages.append(f"\n\n{role}: {content}")
|
||||
|
||||
return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
|
||||
|
||||
    def _prepare_input(self, prompt: str) -> Dict[str, Any]:
        """
        Prepare the invoke_model request body for the current provider.

        Args:
            prompt: Text prompt to process

        Returns:
            Prepared input dictionary

        Note: the generic mapping pass below is completely overwritten for the
        amazon/anthropic/meta/mistral providers by the explicit bodies that
        follow; it only survives for the remaining providers (ai21, cohere, …).
        """
        # Base configuration
        input_body = {"prompt": prompt}

        # Provider-specific parameter name mappings (config key -> API key)
        provider_mappings = {
            "meta": {"max_tokens": "max_gen_len"},
            "ai21": {"max_tokens": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens": "max_tokens"},
            "cohere": {"max_tokens": "max_tokens", "top_p": "p"},
            "amazon": {"max_tokens": "maxTokenCount", "top_p": "topP"},
            "anthropic": {"max_tokens": "max_tokens", "top_p": "top_p"},
        }

        # Apply provider mappings
        if self.provider in provider_mappings:
            for old_key, new_key in provider_mappings[self.provider].items():
                if old_key in self.model_config:
                    input_body[new_key] = self.model_config[old_key]

        # Special handling for specific providers
        if self.provider == "cohere" and "cohere.command" in self.config.model:
            # Command models take the prompt under "message" instead.
            input_body["message"] = input_body.pop("prompt")
        elif self.provider == "amazon":
            # Amazon Nova and other Amazon models
            if "nova" in self.config.model.lower():
                # Nova models use the converse API format
                input_body = {
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": self.model_config.get("max_tokens", 5000),
                    "temperature": self.model_config.get("temperature", 0.1),
                    "top_p": self.model_config.get("top_p", 0.9),
                }
            else:
                # Legacy Amazon models (e.g. Titan) use inputText + config block
                input_body = {
                    "inputText": prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": self.model_config.get("max_tokens", 5000),
                        "topP": self.model_config.get("top_p", 0.9),
                        "temperature": self.model_config.get("temperature", 0.1),
                    },
                }
                # Remove None values
                input_body["textGenerationConfig"] = {
                    k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
                }
        elif self.provider == "anthropic":
            input_body = {
                "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                "max_tokens": self.model_config.get("max_tokens", 2000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
                "anthropic_version": "bedrock-2023-05-31",
            }
        elif self.provider == "meta":
            input_body = {
                "prompt": prompt,
                "max_gen_len": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        elif self.provider == "mistral":
            input_body = {
                "prompt": prompt,
                "max_tokens": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }
        else:
            # Generic case - add all model config parameters
            input_body.update(self.model_config)

        return input_body
|
||||
|
||||
def _convert_tool_format(self, original_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert tools to Bedrock-compatible format.
|
||||
|
||||
Args:
|
||||
original_tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Converted tools in Bedrock format
|
||||
"""
|
||||
new_tools = []
|
||||
|
||||
for tool in original_tools:
|
||||
if tool["type"] == "function":
|
||||
function = tool["function"]
|
||||
new_tool = {
|
||||
"toolSpec": {
|
||||
"name": function["name"],
|
||||
"description": function.get("description", ""),
|
||||
"inputSchema": {
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": function["parameters"].get("required", []),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# Add properties
|
||||
for prop, details in function["parameters"].get("properties", {}).items():
|
||||
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details
|
||||
|
||||
new_tools.append(new_tool)
|
||||
|
||||
return new_tools
|
||||
|
||||
def _parse_response(
|
||||
self, response: Dict[str, Any], tools: Optional[List[Dict]] = None
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""
|
||||
Parse response from Bedrock API.
|
||||
|
||||
Args:
|
||||
response: Raw API response
|
||||
tools: List of tools if used
|
||||
|
||||
Returns:
|
||||
Parsed response
|
||||
"""
|
||||
if tools:
|
||||
# Handle tool-enabled responses
|
||||
processed_response = {"tool_calls": []}
|
||||
|
||||
if response.get("output", {}).get("message", {}).get("content"):
|
||||
for item in response["output"]["message"]["content"]:
|
||||
if "toolUse" in item:
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": item["toolUse"]["name"],
|
||||
"arguments": json.loads(extract_json(json.dumps(item["toolUse"]["input"]))),
|
||||
}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
|
||||
# Handle regular text responses
|
||||
try:
|
||||
response_body = response.get("body").read().decode()
|
||||
response_json = json.loads(response_body)
|
||||
|
||||
# Provider-specific response parsing
|
||||
if self.provider == "anthropic":
|
||||
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "amazon":
|
||||
# Handle both Nova and legacy Amazon models
|
||||
if "nova" in self.config.model.lower():
|
||||
# Nova models return content in a different format
|
||||
if "content" in response_json:
|
||||
return response_json["content"][0]["text"]
|
||||
elif "completion" in response_json:
|
||||
return response_json["completion"]
|
||||
else:
|
||||
# Legacy Amazon models
|
||||
return response_json.get("completion", "")
|
||||
elif self.provider == "meta":
|
||||
return response_json.get("generation", "")
|
||||
elif self.provider == "mistral":
|
||||
return response_json.get("outputs", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "cohere":
|
||||
return response_json.get("generations", [{"text": ""}])[0].get("text", "")
|
||||
elif self.provider == "ai21":
|
||||
return response_json.get("completions", [{"data", {"text": ""}}])[0].get("data", {}).get("text", "")
|
||||
else:
|
||||
# Generic parsing - try common response fields
|
||||
for field in ["content", "text", "completion", "generation"]:
|
||||
if field in response_json:
|
||||
if isinstance(response_json[field], list) and response_json[field]:
|
||||
return response_json[field][0].get("text", "")
|
||||
elif isinstance(response_json[field], str):
|
||||
return response_json[field]
|
||||
|
||||
# Fallback
|
||||
return str(response_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse response: {e}")
|
||||
return "Error parsing response"
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_format: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: str = "auto",
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""
|
||||
Generate response using AWS Bedrock.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
response_format: Response format specification
|
||||
tools: List of tools for function calling
|
||||
tool_choice: Tool choice method
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Generated response
|
||||
"""
|
||||
try:
|
||||
if tools and self.supports_tools:
|
||||
# Use converse method for tool-enabled models
|
||||
return self._generate_with_tools(messages, tools, stream)
|
||||
else:
|
||||
# Use standard invoke_model method
|
||||
return self._generate_standard(messages, stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate response: {e}")
|
||||
raise RuntimeError(f"Failed to generate response: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _convert_tools_to_converse_format(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI-style tools to Converse API format."""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
converse_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function" and "function" in tool:
|
||||
func = tool["function"]
|
||||
converse_tool = {
|
||||
"toolSpec": {
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"inputSchema": {
|
||||
"json": func.get("parameters", {})
|
||||
}
|
||||
}
|
||||
}
|
||||
converse_tools.append(converse_tool)
|
||||
|
||||
return converse_tools
|
||||
|
||||
    def _generate_with_tools(self, messages: List[Dict[str, str]], tools: List[Dict], stream: bool = False) -> Dict[str, Any]:
        """Generate a response with tool calling via the Converse API.

        Args:
            messages: Chat messages with 'role' and 'content' keys.
            tools: OpenAI-style tool definitions.
            stream: Accepted but not acted on here — TODO confirm streaming intent.

        Returns:
            The parsed tool-call dict from _parse_response.
        """
        # Format messages for tool-enabled models; only anthropic/amazon get
        # full formatting, everything else sends just the last user message.
        system_message = None
        if self.provider == "anthropic":
            formatted_messages, system_message = self._format_messages_anthropic(messages)
        elif self.provider == "amazon":
            formatted_messages = self._format_messages_amazon(messages)
        else:
            formatted_messages = [{"role": "user", "content": [{"text": messages[-1]["content"]}]}]

        # Prepare tool configuration in Converse API format
        tool_config = None
        if tools:
            converse_tools = self._convert_tools_to_converse_format(tools)
            if converse_tools:
                tool_config = {"tools": converse_tools}

        # Prepare converse parameters
        converse_params = {
            "modelId": self.config.model,
            "messages": formatted_messages,
            "inferenceConfig": {
                "maxTokens": self.model_config.get("max_tokens", 2000),
                "temperature": self.model_config.get("temperature", 0.1),
                "topP": self.model_config.get("top_p", 0.9),
            }
        }

        # Add system message if present (for Anthropic)
        if system_message:
            converse_params["system"] = [{"text": system_message}]

        # Add tool config if present
        if tool_config:
            converse_params["toolConfig"] = tool_config

        # Make API call
        response = self.client.converse(**converse_params)

        return self._parse_response(response, tools)
|
||||
|
||||
    def _generate_standard(self, messages: List[Dict[str, str]], stream: bool = False) -> str:
        """Generate a plain text response (no tools).

        Anthropic and Amazon Nova models go through the Converse API; all other
        providers use invoke_model with a provider-specific JSON body.

        Args:
            messages: Chat messages with 'role' and 'content' keys.
            stream: Accepted but not acted on here — TODO confirm streaming intent.

        Returns:
            The generated text (or a stringified raw response as a fallback).
        """
        # For Anthropic models, always use Converse API
        if self.provider == "anthropic":
            formatted_messages, system_message = self._format_messages_anthropic(messages)

            # Prepare converse parameters
            converse_params = {
                "modelId": self.config.model,
                "messages": formatted_messages,
                "inferenceConfig": {
                    "maxTokens": self.model_config.get("max_tokens", 2000),
                    "temperature": self.model_config.get("temperature", 0.1),
                    "topP": self.model_config.get("top_p", 0.9),
                }
            }

            # Add system message if present
            if system_message:
                converse_params["system"] = [{"text": system_message}]

            # Use converse API for Anthropic models
            response = self.client.converse(**converse_params)

            # Parse Converse API response: attribute-style access first, then
            # the dict shape boto3 actually returns, then a string fallback.
            if hasattr(response, 'output') and hasattr(response.output, 'message'):
                return response.output.message.content[0].text
            elif 'output' in response and 'message' in response['output']:
                return response['output']['message']['content'][0]['text']
            else:
                return str(response)

        elif self.provider == "amazon" and "nova" in self.config.model.lower():
            # Nova models use converse API even without tools
            formatted_messages = self._format_messages_amazon(messages)
            input_body = {
                "messages": formatted_messages,
                "max_tokens": self.model_config.get("max_tokens", 5000),
                "temperature": self.model_config.get("temperature", 0.1),
                "top_p": self.model_config.get("top_p", 0.9),
            }

            # Use converse API for Nova models
            response = self.client.converse(
                modelId=self.config.model,
                messages=input_body["messages"],
                inferenceConfig={
                    "maxTokens": input_body["max_tokens"],
                    "temperature": input_body["temperature"],
                    "topP": input_body["top_p"],
                }
            )

            return self._parse_response(response)
        else:
            # For other providers and legacy Amazon models (like Titan)
            if self.provider == "amazon":
                # Legacy Amazon models need string formatting, not array formatting
                prompt = self._format_messages_generic(messages)
            else:
                prompt = self._format_messages(messages)
            input_body = self._prepare_input(prompt)

            # Convert to JSON
            body = json.dumps(input_body)

            # Make API call
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            return self._parse_response(response)
|
||||
|
||||
    def list_available_models(self) -> List[Dict[str, Any]]:
        """List all available foundation models in the current region.

        Returns:
            A list of model summary dicts; empty on any failure (never raises).
            NOTE(review): an unrecognized provider makes extract_provider raise,
            which the broad except turns into an empty result — confirm intent.
        """
        try:
            bedrock_client = boto3.client("bedrock", **self.config.get_aws_config())
            response = bedrock_client.list_foundation_models()

            models = []
            for model in response["modelSummaries"]:
                provider = extract_provider(model["modelId"])
                models.append(
                    {
                        "model_id": model["modelId"],
                        "provider": provider,
                        # Strip the "<provider>." prefix when present.
                        "model_name": model["modelId"].split(".", 1)[1]
                        if "." in model["modelId"]
                        else model["modelId"],
                        "modelArn": model.get("modelArn", ""),
                        "providerName": model.get("providerName", ""),
                        "inputModalities": model.get("inputModalities", []),
                        "outputModalities": model.get("outputModalities", []),
                        "responseStreamingSupported": model.get("responseStreamingSupported", False),
                    }
                )

            return models

        except Exception as e:
            logger.warning(f"Could not list models: {e}")
            return []
|
||||
|
||||
def get_model_capabilities(self) -> Dict[str, Any]:
|
||||
"""Get capabilities of the current model."""
|
||||
return {
|
||||
"model_id": self.config.model,
|
||||
"provider": self.provider,
|
||||
"model_name": self.config.model_name,
|
||||
"supports_tools": self.supports_tools,
|
||||
"supports_vision": self.supports_vision,
|
||||
"supports_streaming": self.supports_streaming,
|
||||
"max_tokens": self.model_config.get("max_tokens", 2000),
|
||||
}
|
||||
|
||||
    def validate_model_access(self) -> bool:
        """Return True if the configured model answers a minimal test request.

        Any exception (auth, missing model, malformed body for the provider's
        schema) is treated as "no access" rather than propagated.
        """
        try:
            # Try to invoke the model with a minimal request
            if self.provider == "amazon" and "nova" in self.config.model.lower():
                # Test Nova model with converse API
                test_messages = [{"role": "user", "content": "test"}]
                self.client.converse(
                    modelId=self.config.model,
                    messages=test_messages,
                    inferenceConfig={"maxTokens": 10}
                )
            else:
                # Test other models with invoke_model
                # NOTE(review): a bare {"prompt": ...} body does not match every
                # provider's schema, so this probe may report False for models
                # that are in fact accessible — confirm.
                test_body = json.dumps({"prompt": "test"})
                self.client.invoke_model(
                    body=test_body,
                    modelId=self.config.model,
                    accept="application/json",
                    contentType="application/json",
                )
            return True
        except Exception:
            return False
|
||||
141
llms/azure_openai.py
Normal file
141
llms/azure_openai.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.configs.llms.azure import AzureOpenAIConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
SCOPE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
|
||||
class AzureOpenAILLM(LLMBase):
    """LLM backend backed by an Azure OpenAI deployment."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, AzureOpenAIConfig, Dict]] = None):
        """Normalize the config and build the AzureOpenAI client.

        Args:
            config: An AzureOpenAIConfig, a generic BaseLlmConfig, a dict of
                AzureOpenAIConfig kwargs, or None for defaults.
        """
        # Convert to AzureOpenAIConfig if needed
        if config is None:
            config = AzureOpenAIConfig()
        elif isinstance(config, dict):
            config = AzureOpenAIConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, AzureOpenAIConfig):
            # Convert BaseLlmConfig to AzureOpenAIConfig
            config = AzureOpenAIConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

        # Azure connection settings: explicit config wins, env vars are the fallback.
        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential
        # (Entra ID / managed identity) via a bearer-token provider instead.
        if api_key is None or api_key == "" or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(
                self.credential,
                SCOPE,
            )
            api_key = None
        else:
            azure_ad_token_provider = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
|
||||
Args:
|
||||
response: The raw response from API.
|
||||
tools: The list of tools provided in the request.
|
||||
|
||||
Returns:
|
||||
str or dict: The processed response.
|
||||
"""
|
||||
if tools:
|
||||
processed_response = {
|
||||
"content": response.choices[0].message.content,
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
else:
|
||||
return response.choices[0].message.content
|
||||
|
||||
def generate_response(
    self,
    messages: List[Dict[str, str]],
    response_format=None,
    tools: Optional[List[Dict]] = None,
    tool_choice: str = "auto",
    **kwargs,
):
    """
    Generate a response based on the given messages using Azure OpenAI.

    Args:
        messages (list): Message dicts containing 'role' and 'content'.
        response_format (str or object, optional): Accepted for interface
            compatibility; only forwarded when supplied via **kwargs.
        tools (list, optional): Tools the model can call. Defaults to None.
        tool_choice (str, optional): Tool choice method. Defaults to "auto".
        **kwargs: Additional Azure OpenAI-specific parameters.

    Returns:
        str or dict: The parsed model response.
    """
    # The last message is rewritten in place: every "assistant" substring
    # becomes "ai".
    # NOTE(review): this replaces the substring anywhere in the text, not
    # just role markers -- confirm that is intentional.
    last = messages[-1]
    last["content"] = last["content"].replace("assistant", "ai")

    # Base params are filtered per-model (reasoning models drop sampling
    # parameters), then model and messages are layered on top.
    params = self._get_supported_params(messages=messages, **kwargs)
    params["model"] = self.config.model
    params["messages"] = messages

    if tools:
        params["tools"] = tools
        params["tool_choice"] = tool_choice

    completion = self.client.chat.completions.create(**params)
    return self._parse_response(completion, tools)
|
||||
91
llms/azure_openai_structured.py
Normal file
91
llms/azure_openai_structured.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
SCOPE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
|
||||
class AzureOpenAIStructuredLLM(LLMBase):
    """
    Azure OpenAI chat LLM that supports structured (e.g. JSON) response formats.

    Falls back to Azure AD authentication (DefaultAzureCredential) when no
    usable API key is configured.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """
        Initialize the Azure OpenAI client from config with env-var fallbacks.

        Args:
            config: LLM configuration; Azure-specific settings are read from
                ``config.azure_kwargs``, falling back to the
                LLM_AZURE_OPENAI_API_KEY / LLM_AZURE_DEPLOYMENT /
                LLM_AZURE_ENDPOINT / LLM_AZURE_API_VERSION environment variables.
        """
        super().__init__(config)

        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        # If the API key is not provided or is a placeholder, use
        # DefaultAzureCredential (managed identity / CLI login / etc.).
        if api_key is None or api_key == "" or api_key == "your-api-key":
            self.credential = DefaultAzureCredential()
            azure_ad_token_provider = get_bearer_token_provider(
                self.credential,
                SCOPE,
            )
            api_key = None
        else:
            azure_ad_token_provider = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ) -> str:
        """
        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries, each
                containing a 'role' and 'content' key.
            response_format (Optional[str]): The desired format of the
                response. Defaults to None.
            tools (list, optional): Tools the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response (a dict when tool calls are parsed).
        """
        # The last message is rewritten in place: "assistant" -> "ai".
        last = messages[-1]
        last["content"] = last["content"].replace("assistant", "ai")

        params = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            params["response_format"] = response_format
        # Fix: the original assigned tools/tool_choice in two identical
        # back-to-back `if tools:` blocks; one assignment suffices.
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        # NOTE(review): _parse_response is not defined on this class or on
        # LLMBase in this file -- confirm it is provided elsewhere.
        return self._parse_response(response, tools)
|
||||
131
llms/base.py
Normal file
131
llms/base.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
class LLMBase(ABC):
    """
    Base class for all LLM providers.

    Normalizes configuration input and filters request parameters for models
    (reasoning / GPT-5 series) that reject the usual sampling options;
    provider-specific logic lives in subclasses.
    """

    def __init__(self, config: Optional[Union[BaseLlmConfig, Dict]] = None):
        """Initialize a base LLM class.

        :param config: LLM configuration option class or dict, defaults to None
        :type config: Optional[Union[BaseLlmConfig, Dict]], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        elif isinstance(config, dict):
            # Handle dict-based configuration (backward compatibility).
            self.config = BaseLlmConfig(**config)
        else:
            self.config = config

        self._validate_config()

    def _validate_config(self):
        """
        Validate the configuration.

        Override in subclasses to add provider-specific validation.

        Raises:
            ValueError: If the config has no 'model' attribute.
        """
        if not hasattr(self.config, "model"):
            raise ValueError("Configuration must have a 'model' attribute")
        # Fix: the original tested hasattr(self.config, "api_key") twice in
        # one condition and then did nothing. A missing api_key is fine here:
        # individual providers fall back to environment variables.

    def _is_reasoning_model(self, model: str) -> bool:
        """
        Check if the model is a reasoning model or GPT-5 series that doesn't
        support certain parameters (temperature, max_tokens, top_p).

        Args:
            model: The model name to check.

        Returns:
            bool: True if the model is a reasoning model or GPT-5 series.
        """
        # Fix: the original also kept an exact-name set ("o1", "o3-mini",
        # "gpt-5o", ...) whose every member already contains one of these
        # markers, so the substring test alone gives identical results.
        # NOTE(review): substring matching can misfire on unrelated names
        # containing "o1"/"o3" -- confirm this is acceptable.
        model_lower = model.lower()
        return any(marker in model_lower for marker in ("gpt-5", "o1", "o3"))

    def _get_supported_params(self, **kwargs) -> Dict:
        """
        Get parameters that are supported by the current model.

        Filters out unsupported parameters for reasoning models and the
        GPT-5 series; regular models get the common sampling parameters too.

        Args:
            **kwargs: Candidate request parameters.

        Returns:
            Dict: Filtered parameters dictionary.
        """
        model = getattr(self.config, "model", "")

        if not self._is_reasoning_model(model):
            # Regular models accept the common sampling parameters as well.
            return self._get_common_params(**kwargs)

        # Reasoning / GPT-5 models: forward only the structural parameters.
        allowed = ("messages", "response_format", "tools", "tool_choice")
        return {key: kwargs[key] for key in allowed if key in kwargs}

    @abstractmethod
    def generate_response(
        self, messages: List[Dict[str, str]], tools: Optional[List[Dict]] = None, tool_choice: str = "auto", **kwargs
    ):
        """
        Generate a response based on the given messages.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional provider-specific parameters.

        Returns:
            str or dict: The generated response.
        """
        pass

    def _get_common_params(self, **kwargs) -> Dict:
        """
        Get common parameters that most providers use.

        Returns:
            Dict: temperature / max_tokens / top_p from the config, plus any
            provider-specific ``kwargs`` layered on top.
        """
        params = {
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        params.update(kwargs)
        return params
|
||||
34
llms/configs.py
Normal file
34
llms/configs.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
    """Top-level configuration selecting and parameterizing an LLM provider."""

    provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
    config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})

    @field_validator("config")
    def validate_config(cls, v, values):
        """Reject configs whose provider is not one of the supported backends."""
        supported_providers = {
            "openai",
            "ollama",
            "anthropic",
            "groq",
            "together",
            "aws_bedrock",
            "litellm",
            "azure_openai",
            "openai_structured",
            "azure_openai_structured",
            "gemini",
            "deepseek",
            "xai",
            "sarvam",
            "lmstudio",
            "vllm",
            "langchain",
        }
        provider = values.data.get("provider")
        if provider not in supported_providers:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        return v
|
||||
107
llms/deepseek.py
Normal file
107
llms/deepseek.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.deepseek import DeepSeekConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class DeepSeekLLM(LLMBase):
    """LLM provider backed by DeepSeek's OpenAI-compatible chat API."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, DeepSeekConfig, Dict]] = None):
        """
        Normalize the config to ``DeepSeekConfig`` and create the client.

        Args:
            config: Provider config as a DeepSeekConfig, generic BaseLlmConfig,
                dict, or None (defaults applied).
        """
        if config is None:
            config = DeepSeekConfig()
        elif isinstance(config, dict):
            config = DeepSeekConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, DeepSeekConfig):
            # Upgrade a generic config to the provider-specific one.
            config = DeepSeekConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        if not self.config.model:
            self.config.model = "deepseek-chat"

        # API key and base URL fall back to environment variables, then to
        # the public DeepSeek endpoint.
        api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
        base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Plain message text when no tools were supplied,
            otherwise a dict with "content" and parsed "tool_calls".
        """
        message = response.choices[0].message

        if not tools:
            return message.content

        parsed = {"content": message.content, "tool_calls": []}
        for call in message.tool_calls or []:
            parsed["tool_calls"].append(
                {
                    "name": call.function.name,
                    "arguments": json.loads(extract_json(call.function.arguments)),
                }
            )
        return parsed

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using DeepSeek.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional DeepSeek-specific parameters.

        Returns:
            str: The generated response.
        """
        params = self._get_supported_params(messages=messages, **kwargs)
        params.update(
            {
                "model": self.config.model,
                "messages": messages,
            }
        )

        # Fix: response_format was accepted but silently dropped; forward it
        # like the other OpenAI-compatible providers (Groq, LiteLLM) do.
        if response_format:
            params["response_format"] = response_format

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|
||||
201
llms/gemini.py
Normal file
201
llms/gemini.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
raise ImportError("The 'google-genai' library is required. Please install it using 'pip install google-genai'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class GeminiLLM(LLMBase):
    """mem0 LLM provider backed by Google's Gemini models via the google-genai SDK."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the Gemini client.

        Defaults the model to "gemini-2.0-flash" and reads the API key from
        the config or the GOOGLE_API_KEY environment variable.
        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "gemini-2.0-flash"

        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
        self.client = genai.Client(api_key=api_key)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Text of the first text part (or "" if none) when no
            tools were supplied; otherwise a dict with "content" and a list
            of "tool_calls" extracted from the first candidate's parts.
        """
        if tools:
            processed_response = {
                "content": None,
                "tool_calls": [],
            }

            # Extract content from the first candidate: first non-empty
            # text part wins.
            if response.candidates and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, "text") and part.text:
                        processed_response["content"] = part.text
                        break

            # Extract function calls (a second pass over the same parts).
            if response.candidates and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        fn = part.function_call
                        processed_response["tool_calls"].append(
                            {
                                "name": fn.name,
                                "arguments": dict(fn.args) if fn.args else {},
                            }
                        )

            return processed_response
        else:
            if response.candidates and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, "text") and part.text:
                        return part.text
            # No candidates or no text part: degrade to an empty string.
            return ""

    def _reformat_messages(self, messages: List[Dict[str, str]]):
        """
        Reformat messages for Gemini.

        System messages are pulled out as the system instruction (the last
        one wins if several are present); all other messages become
        ``types.Content`` entries with their role forwarded unchanged.

        Args:
            messages: The list of messages provided in the request.

        Returns:
            tuple: (system_instruction, contents_list)
        """
        system_instruction = None
        contents = []

        for message in messages:
            if message["role"] == "system":
                system_instruction = message["content"]
            else:
                # NOTE(review): roles are forwarded as-is; Gemini expects
                # "user"/"model" -- confirm callers never pass "assistant".
                content = types.Content(
                    parts=[types.Part(text=message["content"])],
                    role=message["role"],
                )
                contents.append(content)

        return system_instruction, contents

    def _reformat_tools(self, tools: Optional[List[Dict]]):
        """
        Reformat OpenAI-style tool definitions for Gemini.

        Args:
            tools: The list of tools provided in the request.

        Returns:
            list: A single-element list containing one ``types.Tool`` with
            all function declarations, or None when no tools were given.
        """

        def remove_additional_properties(data):
            """Recursively removes 'additionalProperties' from nested dictionaries."""
            # Gemini's schema format does not accept 'additionalProperties',
            # which OpenAI-style JSON schemas commonly include.
            if isinstance(data, dict):
                filtered_dict = {
                    key: remove_additional_properties(value)
                    for key, value in data.items()
                    if not (key == "additionalProperties")
                }
                return filtered_dict
            else:
                return data

        if tools:
            function_declarations = []
            for tool in tools:
                # Copy so the caller's tool dict is not mutated.
                func = tool["function"].copy()
                cleaned_func = remove_additional_properties(func)

                function_declaration = types.FunctionDeclaration(
                    name=cleaned_func["name"],
                    description=cleaned_func.get("description", ""),
                    parameters=cleaned_func.get("parameters", {}),
                )
                function_declarations.append(function_declaration)

            tool_obj = types.Tool(function_declarations=function_declarations)
            return [tool_obj]
        else:
            return None

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Gemini.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (dict, optional): When ``{"type": "json_object"}``
                (optionally with a "schema" key), requests JSON output.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method: "auto", "any",
                or anything else for NONE. Defaults to "auto".

        Returns:
            str: The generated response (a dict when tools are supplied).
        """

        # Extract system instruction and reformat messages.
        system_instruction, contents = self._reformat_messages(messages)

        # Prepare generation config from the common sampling settings.
        config_params = {
            "temperature": self.config.temperature,
            "max_output_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        # Add system instruction to config if present.
        if system_instruction:
            config_params["system_instruction"] = system_instruction

        # NOTE(review): assumes response_format is a dict with a "type" key
        # when provided -- confirm against callers.
        if response_format is not None and response_format["type"] == "json_object":
            config_params["response_mime_type"] = "application/json"
            if "schema" in response_format:
                config_params["response_schema"] = response_format["schema"]

        if tools:
            formatted_tools = self._reformat_tools(tools)
            config_params["tools"] = formatted_tools

            if tool_choice:
                # Map the OpenAI-style tool_choice onto Gemini's
                # function-calling modes.
                if tool_choice == "auto":
                    mode = types.FunctionCallingConfigMode.AUTO
                elif tool_choice == "any":
                    mode = types.FunctionCallingConfigMode.ANY
                else:
                    mode = types.FunctionCallingConfigMode.NONE

                tool_config = types.ToolConfig(
                    function_calling_config=types.FunctionCallingConfig(
                        mode=mode,
                        # In "any" mode, restrict calls to the declared tools.
                        allowed_function_names=(
                            [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
                        ),
                    )
                )
                config_params["tool_config"] = tool_config

        generation_config = types.GenerateContentConfig(**config_params)

        response = self.client.models.generate_content(
            model=self.config.model, contents=contents, config=generation_config
        )

        return self._parse_response(response, tools)
|
||||
88
llms/groq.py
Normal file
88
llms/groq.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from groq import Groq
|
||||
except ImportError:
|
||||
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class GroqLLM(LLMBase):
    """LLM provider backed by the Groq chat-completions API."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Create the Groq client, defaulting the model and reading GROQ_API_KEY."""
        super().__init__(config)

        if not self.config.model:
            self.config.model = "llama3-70b-8192"

        self.client = Groq(api_key=self.config.api_key or os.getenv("GROQ_API_KEY"))

    def _parse_response(self, response, tools):
        """
        Normalize a Groq chat completion into mem0's return format.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Plain message text when no tools were supplied,
            otherwise a dict with "content" and parsed "tool_calls".
        """
        message = response.choices[0].message

        if not tools:
            return message.content

        parsed = {"content": message.content, "tool_calls": []}
        for call in message.tool_calls or []:
            parsed["tool_calls"].append(
                {
                    "name": call.function.name,
                    "arguments": json.loads(extract_json(call.function.arguments)),
                }
            )
        return parsed

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Groq.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        params = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            params["response_format"] = response_format
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        return self._parse_response(self.client.chat.completions.create(**params), tools)
|
||||
94
llms/langchain.py
Normal file
94
llms/langchain.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
try:
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
except ImportError:
|
||||
raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
|
||||
|
||||
|
||||
class LangchainLLM(LLMBase):
    """Adapter that lets mem0 drive any LangChain ``BaseChatModel``."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Validate that ``config.model`` is a LangChain chat model and keep it."""
        super().__init__(config)

        if self.config.model is None:
            raise ValueError("`model` parameter is required")
        if not isinstance(self.config.model, BaseChatModel):
            raise ValueError("`model` must be an instance of BaseChatModel")

        self.langchain_model = self.config.model

    def _parse_response(self, response: AIMessage, tools: Optional[List[Dict]]):
        """
        Normalize a LangChain ``AIMessage`` into mem0's return format.

        Args:
            response: AI Message.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Message text when no tools were supplied, otherwise
            a dict with "content" and the model's "tool_calls".
        """
        if not tools:
            return response.content

        return {
            "content": response.content,
            "tool_calls": [
                {"name": call["name"], "arguments": call["args"]}
                for call in response.tool_calls
            ],
        }

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using a LangChain model.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Not used in Langchain.
            tools (list, optional): List of tools that the model can call.
            tool_choice (str, optional): Tool choice method.

        Returns:
            str: The generated response.
        """
        # Convert to LangChain's (role, content) tuple format; unknown roles
        # are dropped, exactly as the explicit if/elif chain would.
        role_map = {"system": "system", "user": "human", "assistant": "ai"}
        langchain_messages = [
            (role_map[m["role"]], m["content"]) for m in messages if m["role"] in role_map
        ]

        if not langchain_messages:
            raise ValueError("No valid messages found in the messages list")

        model = self.langchain_model
        if tools:
            model = model.bind_tools(tools=tools, tool_choice=tool_choice)

        response: AIMessage = model.invoke(langchain_messages)
        return self._parse_response(response, tools)
|
||||
87
llms/litellm.py
Normal file
87
llms/litellm.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class LiteLLM(LLMBase):
    """Provider that routes chat completions through the litellm multiplexer."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Apply the default model when none is configured."""
        super().__init__(config)

        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

    def _parse_response(self, response, tools):
        """
        Normalize a litellm completion into mem0's return format.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Plain message text when no tools were supplied,
            otherwise a dict with "content" and parsed "tool_calls".
        """
        message = response.choices[0].message

        if not tools:
            return message.content

        parsed = {"content": message.content, "tool_calls": []}
        for call in message.tool_calls or []:
            parsed["tool_calls"].append(
                {
                    "name": call.function.name,
                    "arguments": json.loads(extract_json(call.function.arguments)),
                }
            )
        return parsed

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Litellm.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.

        Raises:
            ValueError: If the configured model cannot do function calling.
        """
        if not litellm.supports_function_calling(self.config.model):
            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")

        params = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            params["response_format"] = response_format
        # TODO: Remove tools if no issues found with new memory addition logic
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        return self._parse_response(litellm.completion(**params), tools)
|
||||
114
llms/lmstudio.py
Normal file
114
llms/lmstudio.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.lmstudio import LMStudioConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class LMStudioLLM(LLMBase):
    """LLM provider that talks to a local LM Studio server through its
    OpenAI-compatible HTTP API."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, LMStudioConfig, Dict]] = None):
        """Normalize *config* into an LMStudioConfig and build the OpenAI client.

        Accepts None (defaults), a plain dict of config kwargs, a generic
        BaseLlmConfig (converted field-by-field), or an LMStudioConfig.
        """
        # Convert to LMStudioConfig if needed
        if config is None:
            config = LMStudioConfig()
        elif isinstance(config, dict):
            config = LMStudioConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, LMStudioConfig):
            # Convert BaseLlmConfig to LMStudioConfig
            config = LMStudioConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        # Default to a known local GGUF model when none is configured.
        self.config.model = (
            self.config.model
            or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
        )
        # LM Studio does not validate the key, but the OpenAI client requires one.
        self.config.api_key = self.config.api_key or "lm-studio"

        self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response. Plain text when no tools were
            passed; otherwise a dict with "content" and decoded "tool_calls".
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            # extract_json tolerates extra text around the JSON payload
                            "name": tool_call.function.name,
                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using LM Studio.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional LM Studio-specific parameters.

        Returns:
            str: The generated response.
        """
        params = self._get_supported_params(messages=messages, **kwargs)
        params.update(
            {
                "model": self.config.model,
                "messages": messages,
            }
        )

        # Precedence: config-level format override > caller-supplied format >
        # JSON-object mode as the default.
        if self.config.lmstudio_response_format:
            params["response_format"] = self.config.lmstudio_response_format
        elif response_format:
            params["response_format"] = response_format
        else:
            params["response_format"] = {"type": "json_object"}

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|
||||
117
llms/ollama.py
Normal file
117
llms/ollama.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
from ollama import Client
|
||||
except ImportError:
|
||||
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.ollama import OllamaConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OllamaLLM(LLMBase):
    """LLM provider backed by a local Ollama server."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, OllamaConfig, Dict]] = None):
        """Normalize *config* into an OllamaConfig and build the Ollama client.

        Accepts None (defaults), a plain dict of config kwargs, a generic
        BaseLlmConfig (converted field-by-field), or an OllamaConfig.
        """
        # Convert to OllamaConfig if needed
        if config is None:
            config = OllamaConfig()
        elif isinstance(config, dict):
            config = OllamaConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, OllamaConfig):
            # Convert BaseLlmConfig to OllamaConfig
            config = OllamaConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        if not self.config.model:
            self.config.model = "llama3.1:70b"

        self.client = Client(host=self.config.ollama_base_url)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        # The ollama client may return either a plain dict or a typed response
        # object depending on the library version; handle both.
        if isinstance(response, dict):
            content = response["message"]["content"]
        else:
            content = response.message.content

        if tools:
            processed_response = {
                "content": content,
                "tool_calls": [],
            }

            # Ollama tool calls are not decoded here; callers receive the raw
            # content with an empty tool_calls list.
            return processed_response
        else:
            return content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using Ollama.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
                The caller's list and dicts are never modified.
            response_format (dict, optional): Pass {"type": "json_object"} to
                request JSON output; any other value is ignored.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional Ollama-specific parameters (currently unused).

        Returns:
            str: The generated response.
        """
        # Work on shallow copies so that the JSON-mode fallback below never
        # mutates the caller's message list or message dicts.
        messages = [dict(message) for message in messages]

        params = {
            "model": self.config.model,
            "messages": messages,
        }

        # Handle JSON response format by using Ollama's native format parameter.
        # Guard with isinstance: some callers pass response_format as a string,
        # and str has no .get(), which previously raised AttributeError.
        if isinstance(response_format, dict) and response_format.get("type") == "json_object":
            params["format"] = "json"
            # Also add a JSON format instruction to the last message as a fallback
            if messages and messages[-1]["role"] == "user":
                messages[-1]["content"] += "\n\nPlease respond with valid JSON only."
            else:
                messages.append({"role": "user", "content": "Please respond with valid JSON only."})

        # Map generic sampling settings onto Ollama option names
        # (max_tokens -> num_predict).
        params["options"] = {
            "temperature": self.config.temperature,
            "num_predict": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        response = self.client.chat(**params)
        return self._parse_response(response, tools)
|
||||
147
llms/openai.py
Normal file
147
llms/openai.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.openai import OpenAIConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class OpenAILLM(LLMBase):
    """LLM provider for OpenAI, with transparent OpenRouter support when the
    OPENROUTER_API_KEY environment variable is set."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, OpenAIConfig, Dict]] = None):
        """Normalize *config* into an OpenAIConfig and build the client.

        If OPENROUTER_API_KEY is present in the environment, the client is
        pointed at OpenRouter instead of the OpenAI endpoint.
        """
        # Convert to OpenAIConfig if needed
        if config is None:
            config = OpenAIConfig()
        elif isinstance(config, dict):
            config = OpenAIConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, OpenAIConfig):
            # Convert BaseLlmConfig to OpenAIConfig
            config = OpenAIConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        if not self.config.model:
            self.config.model = "gpt-4.1-nano-2025-04-14"

        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
            self.client = OpenAI(
                api_key=os.environ.get("OPENROUTER_API_KEY"),
                base_url=self.config.openrouter_base_url
                or os.getenv("OPENROUTER_API_BASE")
                or "https://openrouter.ai/api/v1",
            )
        else:
            api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
            base_url = self.config.openai_base_url or os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1"

            self.client = OpenAI(api_key=api_key, base_url=base_url)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response. Plain text when no tools were
            passed; otherwise a dict with "content" and decoded "tool_calls".
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            # extract_json tolerates extra text around the JSON payload
                            "name": tool_call.function.name,
                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a JSON response based on the given messages using OpenAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional OpenAI-specific parameters.

        Returns:
            json: The generated response.
        """
        params = self._get_supported_params(messages=messages, **kwargs)

        params.update({
            "model": self.config.model,
            "messages": messages,
        })

        if os.getenv("OPENROUTER_API_KEY"):
            openrouter_params = {}
            # When a routed model list is configured, OpenRouter picks the
            # model itself, so the single "model" key is dropped.
            if self.config.models:
                openrouter_params["models"] = self.config.models
                openrouter_params["route"] = self.config.route
                params.pop("model")

            # Optional attribution headers recognized by OpenRouter.
            if self.config.site_url and self.config.app_name:
                extra_headers = {
                    "HTTP-Referer": self.config.site_url,
                    "X-Title": self.config.app_name,
                }
                openrouter_params["extra_headers"] = extra_headers

            params.update(**openrouter_params)

        else:
            # Pass through parameters that only the native OpenAI API supports.
            openai_specific_generation_params = ["store"]
            for param in openai_specific_generation_params:
                if hasattr(self.config, param):
                    params[param] = getattr(self.config, param)

        if response_format:
            params["response_format"] = response_format
        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
            params["tools"] = tools
            params["tool_choice"] = tool_choice
        response = self.client.chat.completions.create(**params)
        parsed_response = self._parse_response(response, tools)
        # Invoke the optional observer callback with the raw response; callback
        # failures are logged but must never break response delivery.
        if self.config.response_callback:
            try:
                self.config.response_callback(self, response, params)
            except Exception as e:
                # Log error but don't propagate
                logging.error(f"Error due to callback: {e}")
                pass
        return parsed_response
|
||||
52
llms/openai_structured.py
Normal file
52
llms/openai_structured.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OpenAIStructuredLLM(LLMBase):
    """LLM provider that uses OpenAI's structured-output (parse) endpoint."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the client from *config*, falling back to env vars."""
        super().__init__(config)

        # Default to a model known to support structured outputs.
        if not self.config.model:
            self.config.model = "gpt-4o-2024-08-06"

        self.client = OpenAI(
            api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1",
        )

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ) -> str:
        """
        Generate a response for *messages* via OpenAI's parse endpoint.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
            response_format (Optional[str]): The desired format of the response. Defaults to None.
            tools (Optional[List[Dict]]): Tools the model may call. Defaults to None.
            tool_choice (str): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response text.
        """
        request_kwargs = dict(
            model=self.config.model,
            messages=messages,
            temperature=self.config.temperature,
        )

        if response_format:
            request_kwargs["response_format"] = response_format
        if tools:
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = tool_choice

        completion = self.client.beta.chat.completions.parse(**request_kwargs)
        return completion.choices[0].message.content
|
||||
89
llms/sarvam.py
Normal file
89
llms/sarvam.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class SarvamLLM(LLMBase):
    """LLM provider for the Sarvam AI chat-completions REST API (via requests)."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize from *config*; requires an API key (config or env).

        Raises:
            ValueError: If no API key is available.
        """
        super().__init__(config)

        # Set default model if not provided
        if not self.config.model:
            self.config.model = "sarvam-m"

        # Get API key from config or environment variable
        self.api_key = self.config.api_key or os.getenv("SARVAM_API_KEY")

        if not self.api_key:
            raise ValueError(
                "Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
            )

        # Set base URL - use config value or environment or default
        self.base_url = (
            getattr(self.config, "sarvam_base_url", None) or os.getenv("SARVAM_API_BASE") or "https://api.sarvam.ai/v1"
        )

    def generate_response(self, messages: List[Dict[str, str]], response_format=None) -> str:
        """
        Generate a response based on the given messages using Sarvam-M.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response.
                Currently not used by Sarvam API.

        Returns:
            str: The generated response.

        Raises:
            RuntimeError: If the HTTP request fails.
            ValueError: If the API responds without usable choices or with an
                unexpected payload shape.
        """
        url = f"{self.base_url}/chat/completions"

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        # Prepare the request payload
        # self.config.model may be either a plain model name (str) or a dict
        # carrying the name plus Sarvam-specific generation parameters.
        params = {
            "messages": messages,
            "model": self.config.model if isinstance(self.config.model, str) else "sarvam-m",
        }

        # Add standard parameters that already exist in BaseLlmConfig
        if self.config.temperature is not None:
            params["temperature"] = self.config.temperature

        if self.config.max_tokens is not None:
            params["max_tokens"] = self.config.max_tokens

        if self.config.top_p is not None:
            params["top_p"] = self.config.top_p

        # Handle Sarvam-specific parameters if model is passed as dict
        if isinstance(self.config.model, dict):
            # Extract model name
            params["model"] = self.config.model.get("name", "sarvam-m")

            # Add Sarvam-specific parameters
            sarvam_specific_params = ["reasoning_effort", "frequency_penalty", "presence_penalty", "seed", "stop", "n"]

            for param in sarvam_specific_params:
                if param in self.config.model:
                    params[param] = self.config.model[param]

        try:
            response = requests.post(url, headers=headers, json=params, timeout=30)
            response.raise_for_status()

            result = response.json()

            if "choices" in result and len(result["choices"]) > 0:
                # KeyError from missing "message"/"content" is mapped below.
                return result["choices"][0]["message"]["content"]
            else:
                raise ValueError("No response choices found in Sarvam API response")

        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Sarvam API request failed: {e}")
        except KeyError as e:
            raise ValueError(f"Unexpected response format from Sarvam API: {e}")
|
||||
88
llms/together.py
Normal file
88
llms/together.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from together import Together
|
||||
except ImportError:
|
||||
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class TogetherLLM(LLMBase):
    """LLM provider for the Together AI chat-completions API."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the Together client from *config* or the environment."""
        super().__init__(config)

        # Fall back to a Mixtral instruct model when none is configured.
        if not self.config.model:
            self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

        self.client = Together(api_key=self.config.api_key or os.getenv("TOGETHER_API_KEY"))

    def _parse_response(self, response, tools):
        """
        Shape the raw API response for the caller.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Plain message text when no tools were passed;
            otherwise a dict with "content" and decoded "tool_calls".
        """
        message = response.choices[0].message
        if not tools:
            return message.content

        tool_invocations = []
        if message.tool_calls:
            tool_invocations = [
                {
                    # extract_json tolerates extra text around the JSON payload
                    "name": call.function.name,
                    "arguments": json.loads(extract_json(call.function.arguments)),
                }
                for call in message.tool_calls
            ]
        return {"content": message.content, "tool_calls": tool_invocations}

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using TogetherAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        request_kwargs = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            request_kwargs["response_format"] = response_format
        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**request_kwargs)
        return self._parse_response(response, tools)
|
||||
107
llms/vllm.py
Normal file
107
llms/vllm.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.llms.vllm import VllmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class VllmLLM(LLMBase):
    """LLM provider for a vLLM server exposed through the OpenAI-compatible API."""

    def __init__(self, config: Optional[Union[BaseLlmConfig, VllmConfig, Dict]] = None):
        """Normalize *config* into a VllmConfig and build the OpenAI client.

        Accepts None (defaults), a plain dict of config kwargs, a generic
        BaseLlmConfig (converted field-by-field), or a VllmConfig.
        """
        # Convert to VllmConfig if needed
        if config is None:
            config = VllmConfig()
        elif isinstance(config, dict):
            config = VllmConfig(**config)
        elif isinstance(config, BaseLlmConfig) and not isinstance(config, VllmConfig):
            # Convert BaseLlmConfig to VllmConfig
            config = VllmConfig(
                model=config.model,
                temperature=config.temperature,
                api_key=config.api_key,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                top_k=config.top_k,
                enable_vision=config.enable_vision,
                vision_details=config.vision_details,
                http_client_proxies=config.http_client,
            )

        super().__init__(config)

        if not self.config.model:
            self.config.model = "Qwen/Qwen2.5-32B-Instruct"

        # vLLM servers typically ignore the key's value, but the OpenAI client
        # requires a non-empty string.
        self.config.api_key = self.config.api_key or os.getenv("VLLM_API_KEY") or "vllm-api-key"
        base_url = self.config.vllm_base_url or os.getenv("VLLM_BASE_URL")
        self.client = OpenAI(api_key=self.config.api_key, base_url=base_url)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response. Plain text when no tools were
            passed; otherwise a dict with "content" and decoded "tool_calls".
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            # extract_json tolerates extra text around the JSON payload
                            "name": tool_call.function.name,
                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """
        Generate a response based on the given messages using vLLM.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
                NOTE(review): accepted for interface parity but never added to
                the request below — confirm whether this is intentional.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".
            **kwargs: Additional vLLM-specific parameters.

        Returns:
            str: The generated response.
        """
        params = self._get_supported_params(messages=messages, **kwargs)
        params.update(
            {
                "model": self.config.model,
                "messages": messages,
            }
        )

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
|
||||
52
llms/xai.py
Normal file
52
llms/xai.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class XAILLM(LLMBase):
    """LLM provider for xAI (Grok) using the OpenAI-compatible chat API."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the xAI client from *config* or the environment."""
        super().__init__(config)

        if not self.config.model:
            self.config.model = "grok-2-latest"

        self.client = OpenAI(
            api_key=self.config.api_key or os.getenv("XAI_API_KEY"),
            base_url=self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1",
        )

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using XAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): Accepted for interface compatibility with the
                other providers but not forwarded to the API by this provider.
            tool_choice (str, optional): Accepted but unused, see `tools`.

        Returns:
            str: The generated response text.
        """
        request_kwargs = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            request_kwargs["response_format"] = response_format

        completion = self.client.chat.completions.create(**request_kwargs)
        return completion.choices[0].message.content
|
||||
Reference in New Issue
Block a user