import torch
import numpy as np
from typing import List
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from chunker import chunk_documents


class Retriever:
    def __init__(self, docs: List[str], score: int) -> None:
        # Split the raw documents into chunks; the chunks are the retrieval units.
        self.docs = chunk_documents(docs=docs)
        # Default retrieval mode: 0 = BM25 only, 1 = semantic only, 2 = hybrid.
        self.score = score
        # Sparse index: BM25 over lower-cased, whitespace-tokenized chunks.
        tokenized_docs = [doc.lower().split(" ") for doc in self.docs]
        self.bm25 = BM25Okapi(tokenized_docs)
        # Dense index: pre-computed sentence embeddings for every chunk.
        self.sbert = SentenceTransformer(
            'sentence-transformers/all-distilroberta-v1'
        )
        self.doc_embeddings = self.sbert.encode(
            self.docs, show_progress_bar=True
        )
        # Cross-encoder used to rerank the pre-selected candidates.
        self.cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")

    def get_docs(self, query: str, n: int = 5, score: int | None = None) -> List[str]:
        # Fall back to the retrieval mode chosen at construction time.
        if score is None:
            score = self.score
        match score:
            case 0:
                # Lexical retrieval only.
                bm25_scores = self._get_bm25_scores(query=query)
                sorted_indices = np.argsort(bm25_scores)[::-1].tolist()
            case 1:
                # Semantic retrieval only.
                semantic_scores = self._get_semantic_scores(query=query)
                sorted_indices = torch.argsort(
                    semantic_scores, descending=True
                ).tolist()
            case 2:
                # Hybrid: weighted mix of BM25 and semantic scores.
                bm25_scores = self._get_bm25_scores(query=query)
                semantic_scores = self._get_semantic_scores(query=query)
                scores = 0.3 * torch.from_numpy(bm25_scores).float() + 0.7 * semantic_scores
                sorted_indices = torch.argsort(scores, descending=True).tolist()
            case _:
                raise ValueError(f"score must be 0, 1 or 2, got {score}")
        # Keep the top-n candidates and rerank them with the cross-encoder.
        preselected_docs = [self.docs[i] for i in sorted_indices[:n]]
        result = self.rerank(query=query, docs=preselected_docs)
        return result

    def _get_bm25_scores(self, query: str) -> np.ndarray:
        # Score every chunk against the lower-cased, whitespace-tokenized query.
        tokenized_query = query.lower().split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        return bm25_scores

    def _get_semantic_scores(self, query: str) -> torch.Tensor:
        # Cosine similarity between the query embedding and each chunk embedding.
        query_embeddings = self.sbert.encode(query)
        semantic_scores = self.sbert.similarity(
            query_embeddings, self.doc_embeddings
        )
        return semantic_scores[0]

    def rerank(self, query: str, docs: List[str]) -> List[str]:
        # Score each (query, candidate) pair with the cross-encoder and return
        # the candidates ordered from most to least relevant.
        pairs = [(query, doc) for doc in docs]
        rerank_scores = self.cross_encoder.predict(pairs)
        reranked_docs = [
            doc
            for _, doc in sorted(
                zip(rerank_scores, docs), key=lambda pair: pair[0], reverse=True
            )
        ]
        return reranked_docs
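

# A minimal usage sketch, assuming the local `chunker` module provides
# chunk_documents() to split each input string into passages; the example
# documents and query below are illustrative placeholders, not part of the
# original code.
if __name__ == "__main__":
    corpus = [
        "BM25 is a bag-of-words ranking function built on term frequencies.",
        "Sentence embeddings map text to dense vectors for semantic search.",
        "Cross-encoders rescore query-document pairs for precise reranking.",
    ]
    retriever = Retriever(docs=corpus, score=2)  # 2 = hybrid BM25 + semantic
    top_docs = retriever.get_docs(query="what is semantic search?", n=3)
    for rank, doc in enumerate(top_docs, start=1):
        print(rank, doc)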