Files
mem0/vector_stores/azure_mysql.py
2026-03-06 21:11:10 +08:00

464 lines
16 KiB
Python

import json
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
try:
import pymysql
from pymysql.cursors import DictCursor
from dbutils.pooled_db import PooledDB
except ImportError:
raise ImportError(
"Azure MySQL vector store requires PyMySQL and DBUtils. "
"Please install them using 'pip install pymysql dbutils'"
)
try:
from azure.identity import DefaultAzureCredential
AZURE_IDENTITY_AVAILABLE = True
except ImportError:
AZURE_IDENTITY_AVAILABLE = False
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """Standard result record returned by AzureMySQL query methods."""

    # Row identifier (VARCHAR primary key of the collection table).
    id: Optional[str]
    # Cosine distance for search() results (lower = more similar);
    # None for get()/list(), which do not score.
    score: Optional[float]
    # Deserialized JSON payload stored alongside the vector.
    payload: Optional[dict]
class AzureMySQL(VectorStoreBase):
    """Vector store backed by Azure Database for MySQL.

    Vectors are stored as JSON arrays in a plain table; similarity is
    computed client-side in search().  Connections come from a DBUtils
    PooledDB pool of PyMySQL connections.
    """

    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: Optional[str],
        database: str,
        collection_name: str,
        embedding_model_dims: int,
        use_azure_credential: bool = False,
        ssl_ca: Optional[str] = None,
        ssl_disabled: bool = False,
        minconn: int = 1,
        maxconn: int = 5,
        connection_pool: Optional[Any] = None,
    ):
        """
        Initialize the Azure MySQL vector store.

        Args:
            host (str): MySQL server host
            port (int): MySQL server port
            user (str): Database user
            password (str, optional): Database password (not required if using Azure credential)
            database (str): Database name
            collection_name (str): Collection/table name
            embedding_model_dims (int): Dimension of the embedding vector
            use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
            ssl_ca (str, optional): Path to SSL CA certificate
            ssl_disabled (bool): Disable SSL connection
            minconn (int): Minimum number of connections in the pool
            maxconn (int): Maximum number of connections in the pool
            connection_pool (Any, optional): Pre-configured connection pool
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.use_azure_credential = use_azure_credential
        self.ssl_ca = ssl_ca
        self.ssl_disabled = ssl_disabled
        self.connection_pool = connection_pool
        # Handle Azure authentication: replaces self.password with an AAD token.
        # NOTE(review): AAD access tokens expire (typically ~60 min); pooled
        # connections opened after expiry will reuse the stale token — confirm
        # a re-authentication strategy for long-lived processes.
        if use_azure_credential:
            if not AZURE_IDENTITY_AVAILABLE:
                raise ImportError(
                    "Azure Identity is required for Azure credential authentication. "
                    "Please install it using 'pip install azure-identity'"
                )
            self._setup_azure_auth()
        # Setup connection pool unless a pre-configured one was injected.
        if self.connection_pool is None:
            self._setup_connection_pool(minconn, maxconn)
        # Create collection if it doesn't exist (requires a working pool).
        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")
def _setup_azure_auth(self):
    """Fetch an Azure AD access token and use it as the MySQL password.

    Azure Database for MySQL accepts an AAD token in place of a password
    when AAD authentication is configured on the server.
    """
    # OAuth scope for Azure Database for MySQL AAD authentication.
    scope = "https://ossrdbms-aad.database.windows.net/.default"
    try:
        self.password = DefaultAzureCredential().get_token(scope).token
        logger.info("Successfully authenticated using Azure DefaultAzureCredential")
    except Exception as e:
        logger.error(f"Failed to authenticate with Azure: {e}")
        raise
def _setup_connection_pool(self, minconn: int, maxconn: int):
    """Create the DBUtils PooledDB pool of PyMySQL connections.

    Args:
        minconn (int): Idle connections opened up-front (PooledDB ``mincached``).
        maxconn (int): Cap on cached and total connections.

    Raises:
        Exception: Propagates any pool/connection creation failure after logging.
    """
    connect_kwargs = {
        "host": self.host,
        "port": self.port,
        "user": self.user,
        "password": self.password,
        "database": self.database,
        "charset": "utf8mb4",
        "cursorclass": DictCursor,
        "autocommit": False,
    }
    # SSL configuration.
    # Fix: PyMySQL takes these as top-level connect() keyword arguments
    # (ssl_verify_cert / ssl_ca).  The previous code nested them inside an
    # "ssl" dict under keys PyMySQL does not recognize, so certificate
    # verification was not actually enabled.
    if not self.ssl_disabled:
        connect_kwargs["ssl_verify_cert"] = True
        if self.ssl_ca:
            connect_kwargs["ssl_ca"] = self.ssl_ca
    try:
        self.connection_pool = PooledDB(
            creator=pymysql,
            mincached=minconn,
            maxcached=maxconn,
            maxconnections=maxconn,
            blocking=True,  # wait for a free connection instead of raising
            **connect_kwargs
        )
        logger.info("Successfully created MySQL connection pool")
    except Exception as e:
        logger.error(f"Failed to create connection pool: {e}")
        raise
@contextmanager
def _get_cursor(self, commit: bool = False):
    """Yield a cursor from a pooled connection.

    Commits on clean exit when ``commit=True``; any exception triggers a
    rollback, is logged with traceback, and re-raised.  Cursor and
    connection are always returned to the pool.
    """
    connection = self.connection_pool.connection()
    cursor = connection.cursor()
    try:
        yield cursor
        # Commit inside the try so a failing commit is also rolled back/logged.
        if commit:
            connection.commit()
    except Exception as exc:
        connection.rollback()
        logger.error(f"Database error: {exc}", exc_info=True)
        raise
    finally:
        cursor.close()
        connection.close()
def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
    """Create the collection table if it does not exist.

    Vectors are stored as JSON arrays and similarity is computed
    client-side (see search()), so no vector index is created.

    Args:
        name (str, optional): Collection name (uses self.collection_name if not provided)
        vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if
            not provided; informational only — JSON storage does not enforce it)
        distance (str): Distance metric label (only cosine is implemented by search())
    """
    table_name = name or self.collection_name
    dims = vector_size or self.embedding_model_dims
    with self._get_cursor(commit=True) as cur:
        # Fix: the previous DDL declared
        #   INDEX idx_payload_keys ((CAST(payload AS CHAR(255)) ARRAY))
        # which is invalid MySQL syntax — in a multi-valued index the ARRAY
        # keyword belongs inside CAST(... AS CHAR(255) ARRAY) over a JSON
        # path — so CREATE TABLE failed with a syntax error.  The index is
        # dropped; filters scan via JSON_EXTRACT instead.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS `{table_name}` (
                id VARCHAR(255) PRIMARY KEY,
                vector JSON,
                payload JSON
            )
        """)
    logger.info(f"Created collection '{table_name}' with vector dimension {dims}")
def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
    """Insert (or upsert on duplicate ID) vectors into the collection.

    Args:
        vectors (List[List[float]]): Embedding vectors to store.
        payloads (List[Dict], optional): Payloads parallel to ``vectors``;
            defaults to empty dicts.
        ids (List[str], optional): IDs parallel to ``vectors``; random
            UUID4 strings are generated when omitted.
    """
    logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
    if payloads is None:
        payloads = [{}] * len(vectors)
    if ids is None:
        import uuid
        ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
    # Serialize every row as (id, vector-json, payload-json).
    rows = [
        (row_id, json.dumps(vec), json.dumps(meta))
        for vec, meta, row_id in zip(vectors, payloads, ids)
    ]
    with self._get_cursor(commit=True) as cur:
        cur.executemany(
            f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
            f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
            rows
        )
def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
    """Generate SQL for cosine distance calculation.

    NOTE(review): this helper is never called in this class — search()
    computes cosine distance in Python instead.  As written, the SQL only
    enumerates 16 candidate indices (two 4-row UNION derived tables) and
    relies on an uninitialized ``@row`` user variable, so it cannot handle
    arbitrary embedding sizes — treat as dead/placeholder code.
    """
    # For MySQL, we need to calculate cosine similarity manually
    # This is a simplified version - in production, you'd use stored procedures or UDFs
    return """
1 - (
(SELECT SUM(a.val * b.val) /
(SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
FROM (
SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
WHERE idx < JSON_LENGTH(vector)
) a,
(
SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
WHERE idx < JSON_LENGTH(%s)
) b
WHERE a.idx = b.idx
)
"""
def search(
    self,
    query: str,
    vectors: List[float],
    limit: int = 5,
    filters: Optional[Dict] = None,
) -> List[OutputData]:
    """Search for similar vectors using cosine distance.

    Candidate rows are fetched from MySQL (optionally filtered on payload
    keys), then cosine distance is computed client-side with numpy.

    Args:
        query (str): Query string (unused in vector search; kept for
            interface parity with other stores).
        vectors (List[float]): Query embedding.
        limit (int): Number of results to return.
        filters (Dict, optional): Exact-match filters on payload keys.

    Returns:
        List[OutputData]: Results sorted by ascending cosine distance
        (``score`` holds the distance; lower = more similar).
    """
    filter_conditions = []
    filter_params = []
    if filters:
        for k, v in filters.items():
            # NOTE(review): compares JSON_EXTRACT output to a JSON-encoded
            # string parameter — confirm this matches MySQL's JSON
            # comparison semantics (JSON_UNQUOTE may be needed).
            filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
            filter_params.extend([f"$.{k}", json.dumps(v)])
    filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
    with self._get_cursor() as cur:
        query_sql = f"""
            SELECT id, vector, payload
            FROM `{self.collection_name}`
            {filter_clause}
        """
        cur.execute(query_sql, filter_params)
        results = cur.fetchall()
    # Cosine distance computed in Python; in production you'd use MySQL
    # stored procedures or UDFs.
    import numpy as np
    query_vec = np.asarray(vectors, dtype=float)
    # Fix: hoist the query-vector norm out of the loop (it is invariant).
    query_norm = np.linalg.norm(query_vec)
    scored_results = []
    for row in results:
        vec = np.asarray(json.loads(row['vector']), dtype=float)
        denom = query_norm * np.linalg.norm(vec)
        # Fix: guard against zero-norm vectors, which previously produced a
        # division-by-zero / NaN score; treat them as maximally distant.
        similarity = float(np.dot(query_vec, vec) / denom) if denom else -1.0
        scored_results.append((row['id'], 1 - similarity, row['payload']))
    # Sort by distance and keep the top `limit`.
    scored_results.sort(key=lambda x: x[1])
    scored_results = scored_results[:limit]
    return [
        OutputData(id=r[0], score=float(r[1]), payload=json.loads(r[2]) if isinstance(r[2], str) else r[2])
        for r in scored_results
    ]
def delete(self, vector_id: str):
    """Remove the row with the given ID from the collection.

    Args:
        vector_id (str): ID of the vector to delete.
    """
    delete_sql = f"DELETE FROM `{self.collection_name}` WHERE id = %s"
    with self._get_cursor(commit=True) as cur:
        cur.execute(delete_sql, (vector_id,))
def update(
    self,
    vector_id: str,
    vector: Optional[List[float]] = None,
    payload: Optional[Dict] = None,
):
    """Update the vector and/or payload of an existing row.

    Fields passed as None are left untouched.

    Args:
        vector_id (str): ID of the row to update.
        vector (List[float], optional): Replacement embedding.
        payload (Dict, optional): Replacement payload.
    """
    # Collect (column, serialized-value) pairs for the fields provided.
    changes = []
    if vector is not None:
        changes.append(("vector", json.dumps(vector)))
    if payload is not None:
        changes.append(("payload", json.dumps(payload)))
    with self._get_cursor(commit=True) as cur:
        for column, serialized in changes:
            cur.execute(
                f"UPDATE `{self.collection_name}` SET " + column + " = %s WHERE id = %s",
                (serialized, vector_id),
            )
def get(self, vector_id: str) -> Optional[OutputData]:
    """Fetch a single row by ID.

    Args:
        vector_id (str): ID of the vector to retrieve.

    Returns:
        Optional[OutputData]: The row (score is None), or None if absent.
    """
    with self._get_cursor() as cur:
        cur.execute(
            f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
            (vector_id,),
        )
        row = cur.fetchone()
    if not row:
        return None
    raw_payload = row['payload']
    return OutputData(
        id=row['id'],
        score=None,
        payload=json.loads(raw_payload) if isinstance(raw_payload, str) else raw_payload,
    )
def list_cols(self) -> List[str]:
    """Return the names of all tables in the current database.

    Returns:
        List[str]: Table (collection) names.
    """
    # SHOW TABLES returns one column named "Tables_in_<database>".
    column_key = f"Tables_in_{self.database}"
    with self._get_cursor() as cur:
        cur.execute("SHOW TABLES")
        rows = cur.fetchall()
    return [row[column_key] for row in rows]
def delete_col(self):
    """Drop the collection's backing table (no-op if it does not exist)."""
    with self._get_cursor(commit=True) as cur:
        cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
    logger.info(f"Deleted collection '{self.collection_name}'")
def col_info(self) -> Dict[str, Any]:
    """Return name, approximate row count, and on-disk size of the collection.

    Returns:
        Dict[str, Any]: Keys ``name``, ``count``, ``size`` — or an empty
        dict when the table is not present in information_schema.
    """
    stats_sql = """
        SELECT
            TABLE_NAME as name,
            TABLE_ROWS as count,
            ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
        FROM information_schema.TABLES
        WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
    """
    with self._get_cursor() as cur:
        cur.execute(stats_sql, (self.database, self.collection_name))
        row = cur.fetchone()
    if not row:
        return {}
    return {
        "name": row['name'],
        "count": row['count'],  # TABLE_ROWS is an estimate for InnoDB
        "size": f"{row['size_mb']} MB",
    }
def list(
    self,
    filters: Optional[Dict] = None,
    limit: int = 100
) -> List[List[OutputData]]:
    """List rows of the collection, optionally filtered on payload keys.

    Args:
        filters (Dict, optional): Exact-match filters on payload keys.
        limit (int): Maximum number of rows to return.

    Returns:
        List[List[OutputData]]: Results wrapped in a single-element outer
        list (mirrors the interface of the other vector stores).
    """
    where_parts = []
    params = []
    for key, value in (filters or {}).items():
        where_parts.append("JSON_EXTRACT(payload, %s) = %s")
        params.extend([f"$.{key}", json.dumps(value)])
    where_sql = "WHERE " + " AND ".join(where_parts) if where_parts else ""
    with self._get_cursor() as cur:
        cur.execute(
            f"""
            SELECT id, vector, payload
            FROM `{self.collection_name}`
            {where_sql}
            LIMIT %s
            """,
            (*params, limit)
        )
        rows = cur.fetchall()
    items = [
        OutputData(
            id=row['id'],
            score=None,
            payload=json.loads(row['payload']) if isinstance(row['payload'], str) else row['payload'],
        )
        for row in rows
    ]
    return [items]
def reset(self):
    """Drop and recreate the collection table, discarding all stored rows."""
    logger.warning(f"Resetting collection {self.collection_name}...")
    # Order matters: drop first, then recreate with the configured dims.
    self.delete_col()
    self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)
def __del__(self):
    """Best-effort close of the connection pool at garbage collection."""
    pool = getattr(self, 'connection_pool', None)
    try:
        if pool:
            pool.close()
    except Exception:
        # Interpreter may already be shutting down; swallow everything.
        pass