| import gradio as gr |
| from typing import Dict, List, Optional, TypedDict |
| from nlp4web_codebase.ir.data_loaders.sciq import load_sciq |
| from bm25 import BM25Index, BM25Retriever |
|
|
# Directory the BM25 index is written to and then reloaded from by the retriever.
BM25_INDEX_DIR = "output/bm25_sciq_index"

sciq = load_sciq()
# Build a BM25 index over the SciQ corpus. ``ndocs`` is derived from the corpus
# instead of being hard-coded, so it cannot drift out of sync with the data.
# NOTE(review): assumes ``sciq.corpus`` is a sized, re-iterable sequence (it is
# iterated more than once in this file) — confirm against load_sciq().
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
    k1=0.8,  # BM25 hyperparameter k1
    b=0.6,  # BM25 hyperparameter b
)
bm25_index.save(BM25_INDEX_DIR)
# The retriever loads the saved index from disk rather than reusing ``bm25_index``.
bm25_retriever = BM25Retriever(index_dir=BM25_INDEX_DIR)
|
|
|
|
# A single search result: the document's collection id (``cid``), its
# retrieval ``score``, and the document ``text``.
Hit = TypedDict("Hit", {"cid": str, "score": float, "text": str})
|
|
|
|
# Placeholder for the Gradio app; the real Interface is assigned further below.
demo: Optional[gr.Interface] = None
# Alias for the list-of-hits type produced by ``search``.
return_type = List[Hit]

# Lookup from each corpus document's collection id to its text, so ranked ids
# coming back from the retriever can be displayed together with their content.
cid2doc = dict((doc.collection_id, doc.text) for doc in sciq.corpus)
|
|
|
|
def search(query: str, topk: Optional[int] = 10) -> List[Hit]:
    """Run a BM25 query against the SciQ corpus and return the best-scoring hits.

    Args:
        query: Free-text query string.
        topk: Maximum number of hits to return. Defaults to 10, which matches
            the "top-10" promise in the demo's description. ``None`` returns
            every document the retriever scored.

    Returns:
        Hits ordered by descending score, each carrying the document id,
        its score and the document text.

    Raises:
        ValueError: If ``topk`` is negative.
        KeyError: If the retriever returns an id missing from ``cid2doc``
            (not expected, since the index is built from the same corpus).
    """
    if topk is not None and topk < 0:
        raise ValueError(f"topk must be non-negative or None, got {topk}")

    ranking: Dict[str, float] = bm25_retriever.retrieve(query)

    # Highest score first. ``sorted`` stays stable even with reverse=True, so
    # tied documents keep the order in which the retriever returned them.
    sorted_ranking = sorted(ranking.items(), key=lambda item: item[1], reverse=True)

    # Slicing with ``None`` keeps the whole list, so topk=None means "no cap".
    return [
        Hit(cid=cid, score=score, text=cid2doc[cid])
        for cid, score in sorted_ranking[:topk]
    ]
|
|
|
|
# Two-line free-text input; its value is passed to ``search`` as the query.
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...")

# Gradio UI wiring: query textbox -> search() -> text output component.
demo = gr.Interface(
    fn=search,
    inputs=query_input,
    outputs="text",
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever on [SciQ](https://huggingface.co/datasets/allenai/sciq) and return top-10 ranked documents with scores.",
)

# Start serving the app; this runs when the module executes.
demo.launch()
|
|