import json
import logging
import uuid
from typing import Any, Dict, List, Optional

import numpy as np
from pydantic import BaseModel

try:
    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider
except ImportError:
    raise ImportError(
        "Apache Cassandra vector store requires cassandra-driver. "
        "Please install it using 'pip install cassandra-driver'"
    )

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    """Result record returned by the read/search methods of ``CassandraDB``."""

    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class CassandraDB(VectorStoreBase):
    """Vector store that keeps embeddings and JSON payloads in a Cassandra table.

    Each row holds ``id`` (text primary key), ``vector`` (``list<float>``) and
    ``payload`` (JSON-encoded text). Similarity search is computed client-side
    with NumPy over every row of the table.
    """

    def __init__(
        self,
        contact_points: List[str],
        port: int = 9042,
        username: Optional[str] = None,
        password: Optional[str] = None,
        keyspace: str = "mem0",
        collection_name: str = "memories",
        embedding_model_dims: int = 1536,
        secure_connect_bundle: Optional[str] = None,
        protocol_version: int = 4,
        load_balancing_policy: Optional[Any] = None,
    ):
        """
        Initialize the Apache Cassandra vector store.

        Connects to the cluster and creates the keyspace and table if they do
        not exist yet.

        Args:
            contact_points (List[str]): List of contact point addresses (e.g., ['127.0.0.1'])
            port (int): Cassandra port (default: 9042)
            username (str, optional): Database username
            password (str, optional): Database password
            keyspace (str): Keyspace name (default: "mem0")
            collection_name (str): Table name (default: "memories")
            embedding_model_dims (int): Dimension of the embedding vector (default: 1536)
            secure_connect_bundle (str, optional): Path to secure connect bundle for Astra DB
            protocol_version (int): CQL protocol version (default: 4)
            load_balancing_policy (Any, optional): Custom load balancing policy
        """
        self.contact_points = contact_points
        self.port = port
        self.username = username
        self.password = password
        self.keyspace = keyspace
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.secure_connect_bundle = secure_connect_bundle
        self.protocol_version = protocol_version
        self.load_balancing_policy = load_balancing_policy

        # Initialize connection
        self.cluster = None
        self.session = None
        self._setup_connection()

        # Create keyspace and table if they don't exist
        self._create_keyspace()
        self._create_table()

    def _setup_connection(self):
        """Setup Cassandra cluster connection.

        Sets ``self.cluster`` and ``self.session``. Any connection error is
        logged and re-raised.
        """
        try:
            # Authentication is only configured when both credentials are given.
            auth_provider = None
            if self.username and self.password:
                auth_provider = PlainTextAuthProvider(
                    username=self.username,
                    password=self.password
                )

            if self.secure_connect_bundle:
                # Connect to Astra DB using secure connect bundle
                self.cluster = Cluster(
                    cloud={'secure_connect_bundle': self.secure_connect_bundle},
                    auth_provider=auth_provider,
                    protocol_version=self.protocol_version
                )
            else:
                # Connect to standard Cassandra cluster
                cluster_kwargs = {
                    'contact_points': self.contact_points,
                    'port': self.port,
                    'protocol_version': self.protocol_version
                }
                if auth_provider:
                    cluster_kwargs['auth_provider'] = auth_provider
                if self.load_balancing_policy:
                    cluster_kwargs['load_balancing_policy'] = self.load_balancing_policy
                self.cluster = Cluster(**cluster_kwargs)

            self.session = self.cluster.connect()
            logger.info("Successfully connected to Cassandra cluster")
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            raise

    def _create_keyspace(self):
        """Create keyspace if it doesn't exist and make it the session default."""
        try:
            # Use SimpleStrategy for single datacenter, NetworkTopologyStrategy for production
            query = f"""
                CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
                WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
            """
            self.session.execute(query)
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Keyspace '{self.keyspace}' is ready")
        except Exception as e:
            logger.error(f"Failed to create keyspace: {e}")
            raise

    def _table_ddl(self, table_name: str) -> str:
        """Return the CREATE TABLE statement for ``table_name`` in this keyspace."""
        # CQL collection types require an element type: ``list<float>``.
        return f"""
            CREATE TABLE IF NOT EXISTS {self.keyspace}.{table_name} (
                id text PRIMARY KEY,
                vector list<float>,
                payload text
            )
        """

    @staticmethod
    def _parse_payload(raw: Optional[str]) -> dict:
        """Decode a stored JSON payload; empty or missing payloads become ``{}``."""
        return json.loads(raw) if raw else {}

    @staticmethod
    def _matches_filters(raw: Optional[str], filters: Dict) -> bool:
        """Return True when every ``filters`` item equals the payload's value.

        A payload that is not valid JSON never matches.
        """
        try:
            payload = json.loads(raw) if raw else {}
        except json.JSONDecodeError:
            return False
        return all(payload.get(key) == value for key, value in filters.items())

    def _create_table(self):
        """Create table with vector column if it doesn't exist."""
        try:
            # Vector is stored as list<float> and payload as text (JSON)
            self.session.execute(self._table_ddl(self.collection_name))
            logger.info(f"Table '{self.collection_name}' is ready")
        except Exception as e:
            logger.error(f"Failed to create table: {e}")
            raise

    def create_col(
        self,
        name: Optional[str] = None,
        vector_size: Optional[int] = None,
        distance: str = "cosine",
    ):
        """
        Create a new collection (table in Cassandra).

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims
                if not provided). Only used in the log message; the column is an
                unsized list<float>.
            distance (str): Distance metric. Accepted for interface compatibility
                but not used: search always computes cosine distance.
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims

        try:
            self.session.execute(self._table_ddl(table_name))
            logger.info(f"Created collection '{table_name}' with vector dimension {dims}")
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors
                (defaults to an empty payload per vector)
            ids (List[str], optional): List of IDs corresponding to vectors
                (defaults to a fresh UUID4 string per vector)

        Raises:
            ValueError: If payloads or ids do not have one entry per vector.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        # zip() would silently drop trailing vectors on a length mismatch.
        if len(payloads) != len(vectors) or len(ids) != len(vectors):
            raise ValueError(
                "vectors, payloads and ids must have the same length "
                f"(got {len(vectors)}, {len(payloads)}, {len(ids)})"
            )

        try:
            query = f"""
                INSERT INTO {self.keyspace}.{self.collection_name} (id, vector, payload)
                VALUES (?, ?, ?)
            """
            prepared = self.session.prepare(query)

            for vector, payload, vec_id in zip(vectors, payloads, ids):
                self.session.execute(
                    prepared,
                    (vec_id, vector, json.dumps(payload))
                )
        except Exception as e:
            logger.error(f"Failed to insert vectors: {e}")
            raise

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors using cosine similarity.

        Every row of the table is fetched and scored in Python, so the cost
        grows linearly with the table size.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Number of results to return
            filters (Dict, optional): Exact-match filters applied to the payload

        Returns:
            List[OutputData]: Results ordered by ascending cosine distance;
                ``score`` holds the distance (1 - cosine similarity).
        """
        try:
            # Fetch all vectors (in production, you'd want pagination or filtering)
            query_cql = f"""
                SELECT id, vector, payload
                FROM {self.keyspace}.{self.collection_name}
            """
            rows = self.session.execute(query_cql)

            query_vec = np.asarray(vectors, dtype=float)
            # The query norm is loop-invariant, so compute it once.
            query_norm = np.linalg.norm(query_vec)

            scored_results = []
            for row in rows:
                if not row.vector:
                    continue

                # Filter first so rejected rows cost no vector math.
                if filters and not self._matches_filters(row.payload, filters):
                    continue

                vec = np.asarray(row.vector, dtype=float)
                denominator = query_norm * np.linalg.norm(vec)
                if denominator == 0:
                    # Cosine similarity is undefined for a zero vector; treat
                    # it as orthogonal instead of producing NaN, which would
                    # make the sort order below undefined.
                    similarity = 0.0
                else:
                    similarity = float(np.dot(query_vec, vec) / denominator)

                scored_results.append((row.id, 1.0 - similarity, row.payload))

            # Sort by distance and limit
            scored_results.sort(key=lambda item: item[1])

            return [
                OutputData(
                    id=row_id,
                    score=float(distance),
                    payload=self._parse_payload(raw_payload)
                )
                for row_id, distance, raw_payload in scored_results[:limit]
            ]
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete
        """
        try:
            query = f"""
                DELETE FROM {self.keyspace}.{self.collection_name}
                WHERE id = ?
            """
            prepared = self.session.prepare(query)
            self.session.execute(prepared, (vector_id,))
            logger.info(f"Deleted vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to delete vector: {e}")
            raise

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and its payload.

        Both columns are written in a single statement so that a combined
        update cannot be applied only halfway.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload
        """
        assignments = []
        params = []
        if vector is not None:
            assignments.append("vector = ?")
            params.append(vector)
        if payload is not None:
            assignments.append("payload = ?")
            params.append(json.dumps(payload))

        if not assignments:
            # Nothing to change; skip the round trip.
            logger.debug(f"No fields to update for vector with id: {vector_id}")
            return

        try:
            query = f"""
                UPDATE {self.keyspace}.{self.collection_name}
                SET {', '.join(assignments)}
                WHERE id = ?
            """
            prepared = self.session.prepare(query)
            self.session.execute(prepared, (*params, vector_id))
            logger.info(f"Updated vector with id: {vector_id}")
        except Exception as e:
            logger.error(f"Failed to update vector: {e}")
            raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: Retrieved record (``score`` is None), or None if the id
                is not found or the lookup fails (failures are logged).
        """
        try:
            # The vector column is not part of OutputData, so it is not fetched.
            query = f"""
                SELECT id, payload
                FROM {self.keyspace}.{self.collection_name}
                WHERE id = ?
            """
            prepared = self.session.prepare(query)
            row = self.session.execute(prepared, (vector_id,)).one()

            if not row:
                return None

            return OutputData(
                id=row.id,
                score=None,
                payload=self._parse_payload(row.payload)
            )
        except Exception as e:
            logger.error(f"Failed to get vector: {e}")
            return None

    def list_cols(self) -> List[str]:
        """
        List all collections (tables in the keyspace).

        Returns:
            List[str]: List of collection names, or an empty list if the
                lookup fails (failures are logged).
        """
        try:
            # The keyspace is bound as a parameter rather than interpolated
            # into a string literal inside the query.
            query = """
                SELECT table_name
                FROM system_schema.tables
                WHERE keyspace_name = %s
            """
            rows = self.session.execute(query, (self.keyspace,))
            return [row.table_name for row in rows]
        except Exception as e:
            logger.error(f"Failed to list collections: {e}")
            return []

    def delete_col(self):
        """Delete the collection (table)."""
        try:
            query = f"""
                DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name}
            """
            self.session.execute(query)
            logger.info(f"Deleted collection '{self.collection_name}'")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            raise

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Name, keyspace, row count and configured vector
                dimension, or an empty dict if the lookup fails.
        """
        try:
            # COUNT(*) scans the whole table; it is exact but can be slow.
            query = f"""
                SELECT COUNT(*) as count
                FROM {self.keyspace}.{self.collection_name}
            """
            row = self.session.execute(query).one()
            count = row.count if row else 0

            return {
                "name": self.collection_name,
                "keyspace": self.keyspace,
                "count": count,
                "vector_dims": self.embedding_model_dims
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {}

    def list(
        self,
        filters: Optional[Dict] = None,
        limit: int = 100
    ) -> List[List[OutputData]]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Exact-match filters applied to the payload
            limit (int): Maximum number of records to return

        Returns:
            List[List[OutputData]]: A single-element list holding the records;
                ``[[]]`` if the query fails (failures are logged).
        """
        try:
            # Coerce to int so nothing but a number reaches the CQL text.
            limit = int(limit)
            base_query = f"""
                SELECT id, payload
                FROM {self.keyspace}.{self.collection_name}
            """

            if not filters:
                rows = self.session.execute(f"{base_query} LIMIT {limit}")
                return [[
                    OutputData(
                        id=row.id,
                        score=None,
                        payload=self._parse_payload(row.payload)
                    )
                    for row in rows
                ]]

            # Filters are evaluated client-side, so a server-side LIMIT would
            # cut the scan before matching rows are reached. Scan instead and
            # stop as soon as enough matches have been collected.
            results = []
            for row in self.session.execute(base_query):
                if len(results) >= limit:
                    break
                if not self._matches_filters(row.payload, filters):
                    continue
                results.append(
                    OutputData(
                        id=row.id,
                        score=None,
                        payload=self._parse_payload(row.payload)
                    )
                )

            return [results]
        except Exception as e:
            logger.error(f"Failed to list vectors: {e}")
            return [[]]

    def reset(self):
        """Reset the collection by truncating it."""
        try:
            logger.warning(f"Resetting collection {self.collection_name}...")
            query = f"""
                TRUNCATE TABLE {self.keyspace}.{self.collection_name}
            """
            self.session.execute(query)
            logger.info(f"Collection '{self.collection_name}' has been reset")
        except Exception as e:
            logger.error(f"Failed to reset collection: {e}")
            raise

    def __del__(self):
        """Close the cluster connection when the object is deleted."""
        # __init__ may have failed before ``cluster`` was assigned.
        cluster = getattr(self, "cluster", None)
        if cluster is None:
            return
        try:
            cluster.shutdown()
            logger.info("Cassandra cluster connection closed")
        except Exception:
            # Best effort: finalizers must not raise, and logging may already
            # be torn down at interpreter shutdown.
            pass