import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel

try:
    import pymysql
    from pymysql.cursors import DictCursor
    from dbutils.pooled_db import PooledDB
except ImportError as err:
    raise ImportError(
        "Azure MySQL vector store requires PyMySQL and DBUtils. "
        "Please install them using 'pip install pymysql dbutils'"
    ) from err

try:
    from azure.identity import DefaultAzureCredential

    AZURE_IDENTITY_AVAILABLE = True
except ImportError:
    AZURE_IDENTITY_AVAILABLE = False

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    """Result record returned by ``search``, ``get`` and ``list``."""

    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class AzureMySQL(VectorStoreBase):
    """Vector store backed by a MySQL table (e.g. Azure Database for MySQL).

    Each collection is a table with three columns: ``id`` (VARCHAR primary key),
    ``vector`` (JSON array of floats) and ``payload`` (JSON). Similarity search
    fetches the rows matching the payload filters and ranks them by cosine
    distance in Python.

    All SQL-building methods read rows by column name, so they assume the pool
    hands out ``DictCursor`` cursors (the default pool built here does).
    """

    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: Optional[str],
        database: str,
        collection_name: str,
        embedding_model_dims: int,
        use_azure_credential: bool = False,
        ssl_ca: Optional[str] = None,
        ssl_disabled: bool = False,
        minconn: int = 1,
        maxconn: int = 5,
        connection_pool: Optional[Any] = None,
    ):
        """
        Initialize the Azure MySQL vector store.

        Args:
            host (str): MySQL server host
            port (int): MySQL server port
            user (str): Database user
            password (str, optional): Database password (not required if using Azure credential)
            database (str): Database name
            collection_name (str): Collection/table name
            embedding_model_dims (int): Dimension of the embedding vector
            use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
            ssl_ca (str, optional): Path to SSL CA certificate used to verify the server
            ssl_disabled (bool): Disable SSL connection
            minconn (int): Minimum number of connections in the pool
            maxconn (int): Maximum number of connections in the pool
            connection_pool (Any, optional): Pre-configured connection pool. A pool
                supplied here is not closed by this object.

        Raises:
            ImportError: ``use_azure_credential`` is set but azure-identity is missing.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.use_azure_credential = use_azure_credential
        self.ssl_ca = ssl_ca
        self.ssl_disabled = ssl_disabled
        self.connection_pool = connection_pool
        # Only pools created by this instance are closed in __del__; a caller-supplied
        # pool may be shared with other code.
        self._owns_pool = connection_pool is None

        # Handle Azure authentication
        if use_azure_credential:
            if not AZURE_IDENTITY_AVAILABLE:
                raise ImportError(
                    "Azure Identity is required for Azure credential authentication. "
                    "Please install it using 'pip install azure-identity'"
                )
            self._setup_azure_auth()

        # Setup connection pool
        if self.connection_pool is None:
            self._setup_connection_pool(minconn, maxconn)

        # Create collection if it doesn't exist
        if collection_name not in self.list_cols():
            self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")

    def _setup_azure_auth(self):
        """Fetch an Entra ID access token with DefaultAzureCredential and use it as the password.

        NOTE(review): the token is fetched once and reused as a static password for
        every connection the pool opens later. Access tokens are presumably
        short-lived, so connections opened after expiry may fail to authenticate.
        Consider refreshing the token per connection.
        """
        try:
            credential = DefaultAzureCredential()
            # Get access token for Azure Database for MySQL
            token = credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
            # Use token as password
            self.password = token.token
            logger.info("Successfully authenticated using Azure DefaultAzureCredential")
        except Exception as e:
            logger.error("Failed to authenticate with Azure: %s", e)
            raise

    def _setup_connection_pool(self, minconn: int, maxconn: int):
        """Create a DBUtils ``PooledDB`` of PyMySQL connections using ``DictCursor``."""
        connect_kwargs: Dict[str, Any] = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": False,
        }

        # SSL configuration. Use PyMySQL's dedicated ssl_* connect arguments: the
        # ``ssl`` dict form expects keys such as "ca"/"verify_mode", so putting
        # "ssl_ca"/"ssl_verify_cert" inside it left the server certificate unverified.
        if not self.ssl_disabled:
            connect_kwargs["ssl_verify_cert"] = True
            if self.ssl_ca:
                connect_kwargs["ssl_ca"] = self.ssl_ca

        try:
            self.connection_pool = PooledDB(
                creator=pymysql,
                mincached=minconn,
                maxcached=maxconn,
                maxconnections=maxconn,
                blocking=True,
                **connect_kwargs,
            )
            logger.info("Successfully created MySQL connection pool")
        except Exception as e:
            logger.error("Failed to create connection pool: %s", e)
            raise

    @contextmanager
    def _get_cursor(self, commit: bool = False):
        """
        Context manager to get a cursor from the connection pool.

        Commits on success when ``commit`` is true. Rolls back and re-raises on any
        exception. Always closes the cursor and returns the connection to the pool.
        """
        conn = self.connection_pool.connection()
        cur = conn.cursor()
        try:
            yield cur
            if commit:
                conn.commit()
        except Exception as exc:
            try:
                conn.rollback()
            except Exception:
                # Don't let a failed rollback hide the original error.
                logger.warning("Rollback failed after database error", exc_info=True)
            logger.error("Database error: %s", exc, exc_info=True)
            raise
        finally:
            cur.close()
            conn.close()

    @staticmethod
    def _load_json(value: Any) -> Any:
        """Decode a JSON column value that the driver returned as text; pass through anything else."""
        if isinstance(value, (str, bytes, bytearray)):
            return json.loads(value)
        return value

    @staticmethod
    def _build_filter_clause(filters: Optional[Dict]) -> Tuple[str, List[Any]]:
        """Build a ``WHERE`` clause matching payload keys for equality, plus its parameters.

        The JSON path is bound as a parameter, not interpolated into SQL. The value
        is serialized with ``json.dumps`` and compared with ``CAST(... AS JSON)``,
        so JSON is compared to JSON. Comparing against the raw serialized string
        would compare the JSON value with text such as ``'"alice"'`` (quotes
        included) and never match.
        """
        if not filters:
            return "", []
        conditions = []
        params: List[Any] = []
        for key, value in filters.items():
            conditions.append("JSON_EXTRACT(payload, %s) = CAST(%s AS JSON)")
            params.extend([f"$.{key}", json.dumps(value)])
        return "WHERE " + " AND ".join(conditions), params

    def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
        """
        Create a new collection (table in MySQL) if it does not already exist.

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not provided).
                Only logged; the JSON column does not enforce a dimension.
            distance (str): Accepted for interface compatibility; not stored.
                ``search`` always ranks by cosine distance.
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims

        with self._get_cursor(commit=True) as cur:
            # No payload index: the previous key part "(CAST(payload AS CHAR(255)) ARRAY)"
            # placed ARRAY outside the CAST, which is invalid multi-valued index syntax
            # and made CREATE TABLE fail. It also could not serve JSON_EXTRACT filters.
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS `{table_name}` (
                    id VARCHAR(255) PRIMARY KEY,
                    vector JSON,
                    payload JSON
                )
                """
            )
        logger.info("Created collection '%s' with vector dimension %s", table_name, dims)

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Insert vectors into the collection, overwriting rows whose ID already exists.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors (defaults to ``{}`` each)
            ids (List[str], optional): List of IDs corresponding to vectors (defaults to random UUID4 strings)

        Raises:
            ValueError: ``vectors``, ``payloads`` and ``ids`` have different lengths.
        """
        logger.info("Inserting %s vectors into collection %s", len(vectors), self.collection_name)

        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        # zip() would silently drop the tail of the longer lists.
        if not (len(vectors) == len(payloads) == len(ids)):
            raise ValueError(
                f"vectors ({len(vectors)}), payloads ({len(payloads)}) and ids ({len(ids)}) "
                "must have the same length"
            )
        if not vectors:
            return

        data = [
            (vec_id, json.dumps(vector), json.dumps(payload))
            for vector, payload, vec_id in zip(vectors, payloads, ids)
        ]

        with self._get_cursor(commit=True) as cur:
            cur.executemany(
                f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
                f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
                data,
            )

    def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
        """Return a draft SQL fragment for an in-database cosine distance.

        NOTE(review): nothing in this class calls this method; ``search`` computes
        distances in Python. The SQL below is not usable as written. Its
        parentheses close the SELECT before ``FROM``, it references ``a.idx`` /
        ``b.idx`` columns that are never selected, it relies on an uninitialised
        ``@row`` variable, and the index generator covers only 16 positions.
        Kept only for backward compatibility.
        """
        # For MySQL, we need to calculate cosine similarity manually
        # This is a simplified version - in production, you'd use stored procedures or UDFs
        return """
            1 - (
                (SELECT SUM(a.val * b.val) / (SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
                FROM (
                    SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(vector)
                ) a,
                (
                    SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(%s)
                ) b
                WHERE a.idx = b.idx
            )
        """

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors using cosine similarity.

        Every row matching ``filters`` is fetched and scored in Python, so the
        cost grows linearly with the number of matching rows.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Number of results to return
            filters (Dict, optional): Payload key/value pairs that must match exactly

        Returns:
            List[OutputData]: Up to ``limit`` results ordered by ascending ``score``,
            where ``score`` is cosine distance (1 - cosine similarity; lower is closer).
        """
        import numpy as np

        filter_clause, filter_params = self._build_filter_clause(filters)

        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT id, vector, payload FROM `{self.collection_name}` {filter_clause}",
                filter_params,
            )
            rows = cur.fetchall()

        query_vec = np.asarray(vectors, dtype=float)
        query_norm = np.linalg.norm(query_vec)

        scored: List[Tuple[float, Any, Any]] = []
        for row in rows:
            raw_vector = self._load_json(row["vector"])
            if raw_vector is None:
                # Rows without a stored vector cannot be ranked.
                continue
            vec = np.asarray(raw_vector, dtype=float)
            denom = query_norm * np.linalg.norm(vec)
            # A zero-length vector has no direction. Treat it as orthogonal
            # (similarity 0) instead of producing NaN, which would corrupt the sort.
            similarity = float(np.dot(query_vec, vec) / denom) if denom else 0.0
            scored.append((1.0 - similarity, row["id"], row["payload"]))

        # Stable sort on distance only, so ties keep database order.
        scored.sort(key=lambda item: item[0])

        return [
            OutputData(id=vec_id, score=float(distance), payload=self._load_json(payload))
            for distance, vec_id, payload in scored[:limit]
        ]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID. Deleting a missing ID is a no-op.

        Args:
            vector_id (str): ID of the vector to delete
        """
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DELETE FROM `{self.collection_name}` WHERE id = %s", (vector_id,))

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and/or its payload in a single transaction.

        Fields left as ``None`` are not modified. Updating a missing ID changes no rows.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload (replaces the whole payload)
        """
        with self._get_cursor(commit=True) as cur:
            if vector is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET vector = %s WHERE id = %s",
                    (json.dumps(vector), vector_id),
                )
            if payload is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET payload = %s WHERE id = %s",
                    (json.dumps(payload), vector_id),
                )

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector's record by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: Record with ``score=None`` and the decoded payload, or None if not found
        """
        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
                (vector_id,),
            )
            result = cur.fetchone()

        if not result:
            return None

        return OutputData(id=result["id"], score=None, payload=self._load_json(result["payload"]))

    def list_cols(self) -> List[str]:
        """
        List all collections (tables) in the connected database.

        Returns:
            List[str]: List of collection names
        """
        with self._get_cursor() as cur:
            cur.execute("SHOW TABLES")
            rows = cur.fetchall()
        # SHOW TABLES yields exactly one column per row. Read it positionally
        # instead of rebuilding its "Tables_in_<database>" name from self.database.
        return [next(iter(row.values())) for row in rows]

    def delete_col(self):
        """Delete the collection (table) if it exists."""
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
        logger.info("Deleted collection '%s'", self.collection_name)

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection from ``information_schema.TABLES``.

        Returns:
            Dict[str, Any]: ``name``, ``count`` (TABLE_ROWS, an estimate for InnoDB)
            and ``size`` (data + index size in MB), or ``{}`` if the table is absent
        """
        with self._get_cursor() as cur:
            cur.execute(
                """
                SELECT
                    TABLE_NAME as name,
                    TABLE_ROWS as count,
                    ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                """,
                (self.database, self.collection_name),
            )
            result = cur.fetchone()

        if result:
            return {
                "name": result["name"],
                "count": result["count"],
                "size": f"{result['size_mb']} MB",
            }
        return {}

    def list(
        self,
        filters: Optional[Dict] = None,
        limit: int = 100,
    ) -> List[List[OutputData]]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Payload key/value pairs that must match exactly
            limit (int): Number of vectors to return

        Returns:
            List[List[OutputData]]: A single-element list wrapping the matching records
        """
        filter_clause, filter_params = self._build_filter_clause(filters)

        with self._get_cursor() as cur:
            cur.execute(
                f"""
                SELECT id, vector, payload
                FROM `{self.collection_name}`
                {filter_clause}
                LIMIT %s
                """,
                (*filter_params, limit),
            )
            results = cur.fetchall()

        return [
            [
                OutputData(id=r["id"], score=None, payload=self._load_json(r["payload"]))
                for r in results
            ]
        ]

    def reset(self):
        """Reset the collection by deleting and recreating it."""
        logger.warning("Resetting collection %s...", self.collection_name)
        self.delete_col()
        self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)

    def __del__(self):
        """Close the connection pool when the object is deleted, if this object created it."""
        try:
            pool = getattr(self, "connection_pool", None)
            # Default to True for instances that never ran __init__ fully.
            if pool and getattr(self, "_owns_pool", True):
                pool.close()
        except Exception:
            # Best effort: a finalizer must never raise (e.g. during interpreter shutdown).
            pass