import warnings

import streamlit as st
from datasets import load_dataset
from haystack import Pipeline
from haystack.components.readers import ExtractiveReader
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

from utils import get_unique_docs

# Suppress noisy warnings like "Bad message format"
warnings.filterwarnings("ignore", message="Bad message format")

# NOTE(review): several user-facing strings below contain emoji/dash
# characters that look mis-encoded in this view. They are kept exactly as
# found; confirm the file is saved and served as UTF-8.


# Load the dataset
@st.cache_data(show_spinner=False)
def load_documents():
    """
    Load the documents from the dataset considering only unique documents.

    The validation, train and test splits are processed in that order and
    share a single ``unique_docs`` set, which is handed to
    ``get_unique_docs`` for every split.

    Returns:
    - documents: list with the documents of the three splits concatenated
      (validation + train + test), as produced by ``get_unique_docs``.
    """
    unique_docs = set()
    dataset_name = "PedroCJardim/QASports"
    dataset_split = "basketball"
    st.caption(f'Fetching "{dataset_name}" dataset')

    # build the dataset
    dataset = load_dataset(dataset_name, name=dataset_split)

    # The call order matters: all three calls receive the same set.
    docs_validation = get_unique_docs(dataset["validation"], unique_docs)
    docs_train = get_unique_docs(dataset["train"], unique_docs)
    docs_test = get_unique_docs(dataset["test"], unique_docs)
    documents = docs_validation + docs_train + docs_test
    return documents


@st.cache_resource(show_spinner=False)
def get_document_store(documents):
    """
    Index the documents in an in-memory document store.

    Args:
    - documents: list with the documents to write to the store.

    Returns:
    - document_store: InMemoryDocumentStore holding the documents.
    """
    st.caption("Building the Document Store")
    document_store = InMemoryDocumentStore()
    document_store.write_documents(documents=documents)
    return document_store


@st.cache_resource(show_spinner=False)
def get_question_pipeline(_doc_store):
    """
    Create the pipeline with the retriever and reader components.

    Args:
    - _doc_store: instance of the document store. The leading underscore
      tells Streamlit not to hash this argument for the cache key.

    Returns:
    - pipe: instance of the pipeline.
    """
    st.caption("Building the Question Answering pipeline")
    retriever = InMemoryBM25Retriever(document_store=_doc_store)
    reader = ExtractiveReader(model="deepset/roberta-base-squad2")
    # Load the model once here so the cached pipeline is ready to run.
    reader.warm_up()

    pipe = Pipeline()
    pipe.add_component(instance=retriever, name="retriever")
    pipe.add_component(instance=reader, name="reader")
    # The reader extracts answers from the documents the retriever returns.
    pipe.connect("retriever.documents", "reader.documents")
    return pipe


def search(pipeline, question: str):
    """
    Search for the answer to a question in the documents.

    Candidates without answer text or without a source document (the
    reader's "no answer" candidate) are dropped, because the UI reads
    ``document.meta`` on every returned answer.

    Args:
    - pipeline: instance of the pipeline.
    - question: string with the question.

    Returns:
    - answer: list with at most 3 answers, in the order the reader returned
      them; an empty list if nothing usable was found or the search failed
      (the error is shown with ``st.error``).
    """
    top_k = 3
    try:
        result = pipeline.run(
            data={
                "retriever": {"query": question, "top_k": 10},
                "reader": {"query": question, "top_k": top_k},
            }
        )
        usable = [
            candidate
            for candidate in result["reader"]["answers"]
            if candidate.data is not None and candidate.document is not None
        ]
        # Slicing already clamps to the number of available answers.
        return usable[:top_k]
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"⚠️ Error during search: {e}")
        return []


# Streamlit UI
_, centering_column, _ = st.columns(3)
with centering_column:
    st.image("assets/qasports-logo.png", use_column_width=True)

# Loading pipeline and dataset
with st.status(
    "Downloading dataset...", expanded=st.session_state.get("expanded", True)
) as status:
    documents = load_documents()
    status.update(label="Indexing documents...")
    doc_store = get_document_store(documents)
    status.update(label="Creating pipeline...")
    pipe = get_question_pipeline(doc_store)
    status.update(
        label="Download and indexing complete!", state="complete", expanded=False
    )
    # Keep the status box collapsed on later reruns of the script.
    st.session_state["expanded"] = False

st.subheader("πŸ”Ž HoopMind Basketball QA", divider="rainbow")
st.caption(
    """**HoopMind** πŸ€πŸ€– is an AI-powered basketball assistant that answers open-ended questions about players, teams, stats, and the history of the game. It combines smart retrieval with precise reading to deliver fast, accurate, and insightful responsesβ€”your go-to hub for everything basketball."""
)

# Question input
if user_query := st.text_input(
    label="Ask a question about Basketball! πŸ€",
    placeholder="How many field goals did Kobe Bryant score?",
):
    with st.spinner("Searching for answers..."):
        answers = search(pipe, user_query)

    if answers:
        for idx, ans in enumerate(answers):
            st.info(
                f""" **Answer {idx+1}:** "{ans.data}" πŸ”Ή Score: {ans.score:0.4f} πŸ“„ Document: "{ans.document.meta["title"]}" 🌐 URL: {ans.document.meta["url"]} """
            )
            with st.expander("See details", expanded=False):
                st.write(ans)
            st.divider()
    else:
        st.error("❌ No answer found for your question.")