147 lines
5.4 KiB
Python
147 lines
5.4 KiB
Python
from typing import List, Dict, Any, Union
|
|
import numpy as np
|
|
|
|
from mem0.reranker.base import BaseReranker
|
|
from mem0.configs.rerankers.base import BaseRerankerConfig
|
|
from mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
|
|
|
|
try:
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
import torch
|
|
TRANSFORMERS_AVAILABLE = True
|
|
except ImportError:
|
|
TRANSFORMERS_AVAILABLE = False
|
|
|
|
|
|
class HuggingFaceReranker(BaseReranker):
    """HuggingFace Transformers based reranker implementation.

    Scores ``[query, document]`` pairs with a sequence-classification model
    loaded through ``transformers`` and returns the documents ordered by score.
    """

    # Keys probed, in order, to find the text of a document.
    _TEXT_KEYS = ("memory", "text", "content")

    def __init__(self, config: Union[BaseRerankerConfig, HuggingFaceRerankerConfig, Dict]):
        """
        Initialize HuggingFace reranker.

        Args:
            config: Configuration object with reranker parameters. A plain dict
                is expanded into ``HuggingFaceRerankerConfig``. A
                ``BaseRerankerConfig`` that is not already a
                ``HuggingFaceRerankerConfig`` is converted, with the
                HuggingFace-specific fields filled with defaults.

        Raises:
            ImportError: If ``transformers``/``torch`` could not be imported.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers package is required for HuggingFaceReranker. Install with: pip install transformers torch")

        # Convert to HuggingFaceRerankerConfig if needed
        if isinstance(config, dict):
            config = HuggingFaceRerankerConfig(**config)
        elif isinstance(config, BaseRerankerConfig) and not isinstance(config, HuggingFaceRerankerConfig):
            # Carry over the shared fields; everything HuggingFace-specific
            # falls back to the defaults below.
            config = HuggingFaceRerankerConfig(
                provider=getattr(config, 'provider', 'huggingface'),
                model=getattr(config, 'model', 'BAAI/bge-reranker-base'),
                api_key=getattr(config, 'api_key', None),
                top_k=getattr(config, 'top_k', None),
                device=None,  # Will auto-detect
                batch_size=32,  # Default
                max_length=512,  # Default
                normalize=True,  # Default
            )

        self.config = config

        # Use the configured device, otherwise prefer CUDA when torch reports it.
        if self.config.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = self.config.device

        # Load model and tokenizer, then switch the model to evaluation mode.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.config.model)
        self.model.to(self.device)
        self.model.eval()

    @classmethod
    def _document_text(cls, doc: Dict[str, Any]) -> Any:
        """Return the value under the first of 'memory'/'text'/'content' present, else ``str(doc)``."""
        for key in cls._TEXT_KEYS:
            if key in doc:
                return doc[key]
        return str(doc)

    def _score_pairs(self, query: str, doc_texts: List[Any]) -> List[Any]:
        """Run the model over ``[query, text]`` pairs in batches and return the raw scores in input order."""
        scores = []
        batch_size = self.config.batch_size

        for start in range(0, len(doc_texts), batch_size):
            batch_pairs = [[query, text] for text in doc_texts[start:start + batch_size]]

            # Tokenize batch
            inputs = self.tokenizer(
                batch_pairs,
                padding=True,
                truncation=True,
                max_length=self.config.max_length,
                return_tensors="pt"
            ).to(self.device)

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
                batch_scores = outputs.logits.squeeze(-1).cpu().numpy()

            # A 0-d result cannot be extended from, so wrap it as a single score.
            if batch_scores.ndim == 0:
                scores.append(float(batch_scores))
            else:
                scores.extend(batch_scores.tolist())

        return scores

    @staticmethod
    def _with_score(doc: Dict[str, Any], score: float) -> Dict[str, Any]:
        """Return a shallow copy of ``doc`` carrying ``rerank_score``; the input dict is not modified."""
        scored = doc.copy()
        scored['rerank_score'] = float(score)
        return scored

    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using HuggingFace cross-encoder model.

        Args:
            query: The search query
            documents: List of documents to rerank. The text of each document is
                taken from its 'memory', 'text' or 'content' key (first one
                present), otherwise from ``str(doc)``.
            top_k: Number of top documents to return. A falsy value (``None`` or
                ``0``) falls back to ``self.config.top_k``; if that is falsy too,
                all documents are returned.

        Returns:
            List of reranked documents with rerank_score, highest score first.
            Each returned document is a shallow copy of the input dict. An empty
            ``documents`` input is returned as-is. If scoring fails, the failure
            is logged and the documents are returned in their original order
            with ``rerank_score`` set to 0.0.
        """
        if not documents:
            return documents

        # Extract text content for reranking
        doc_texts = [self._document_text(doc) for doc in documents]
        final_top_k = top_k or self.config.top_k

        try:
            scores = self._score_pairs(query, doc_texts)

            if self.config.normalize:
                # Min-max scaling; the epsilon avoids a division by zero when
                # all scores are equal, in which case every score becomes 0.0.
                raw = np.array(scores)
                scores = ((raw - raw.min()) / (raw.max() - raw.min() + 1e-8)).tolist()

            # Sort by score (descending); the sort is stable, so ties keep input order.
            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

            # Apply top_k limit
            if final_top_k:
                ranked = ranked[:final_top_k]

            return [self._with_score(doc, score) for doc, score in ranked]

        except Exception:
            # Best-effort fallback: keep the original order, but make the failure
            # visible and leave the caller's dicts untouched, as the success path does.
            import logging

            logging.getLogger(__name__).warning(
                "HuggingFace reranking failed; returning documents in original order",
                exc_info=True,
            )
            fallback_docs = [self._with_score(doc, 0.0) for doc in documents]
            return fallback_docs[:final_top_k] if final_top_k else fallback_docs