import gradio as gr
import chromadb
import os
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.load import dumps, loads
import openai

# Initialize the ChromaDB client
client = chromadb.PersistentClient(path="indian_law_bge_work_1")

# Load the collection
collection = client.get_or_create_collection("indian_law_bge_work_1")


# Vector Search Function
def vector_search(query, top_k=5):
    """Query the Chroma collection for the documents closest to *query*.

    Args:
        query: Text to search for.
        top_k: Number of results requested from the collection.

    Returns:
        The ``'documents'`` entry of the query result. Chroma returns one
        ranked list per query text, so for the single query passed here this
        is a one-element list of lists. On failure an error-message string is
        returned instead of raising; callers detect it with ``isinstance``.
    """
    try:
        results = collection.query(query_texts=[query], n_results=top_k)
        return results['documents']
    except Exception as e:
        return f"Error during vector search: {e}"


# Generate Query Function
def generate_query(query, query_length):
    """Ask the LLM for reformulations of *query* to use in vector search.

    Args:
        query: The user's original question.
        query_length: How many reformulated queries to request. Coerced to
            ``int`` because UI sliders may deliver a float such as ``3.0``.

    Returns:
        A list of non-empty, stripped lines from the model output, or an
        error-message string if anything fails.
    """
    try:
        prompt = PromptTemplate(
            input_variables=["query", "query_length"],
            # Adjacent literals reproduce the original template text exactly.
            template=(
                ' You are a helpful assistant that can answer questions about Indian law.'
                ' You are given a query: "{query}" and you need to generate'
                ' {query_length} reformulated queries for vector search. \n'
            ),
        )
        llm = OpenAI(temperature=0.7)
        chain = LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser())
        result = chain.run({"query": query, "query_length": int(query_length)})
        # One reformulated query per non-blank output line.
        return [line.strip() for line in result.split("\n") if line.strip()]
    except Exception as e:
        return f"Error during query generation: {e}"


# Reciprocal Rank Fusion Function
def reciprocal_rank_fusion(results_list, k=60):
    """Fuse several ranked document lists with Reciprocal Rank Fusion.

    Each document scores ``1 / (rank + 1 + k)`` for every list it appears in
    (``rank`` is zero-based); scores are summed across lists.

    Args:
        results_list: Iterable of ranked lists of documents.
        k: Smoothing constant added to the rank in the denominator.

    Returns:
        A list of ``(document, score)`` tuples sorted by descending score, or
        an error-message string if fusion fails.
    """
    fused_scores = {}
    try:
        for docs in results_list:
            for rank, doc in enumerate(docs):
                # Serialize so any serializable document can be a dict key.
                doc_str = dumps(doc)
                fused_scores[doc_str] = fused_scores.get(doc_str, 0) + 1 / (rank + 1 + k)
        ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
        return [(loads(doc), score) for doc, score in ranked]
    except Exception as e:
        return f"Error during RRF: {e}"


def handle_query(openai_key, query, query_length):
    """Run the full pipeline: reformulate, search, fuse, and answer.

    Args:
        openai_key: OpenAI API key used for both LangChain and the chat call.
        query: The user's question.
        query_length: Number of reformulated queries to generate.

    Returns:
        A tuple ``(answer, fused_results)``. ``fused_results`` is the list of
        ``(document, score)`` tuples from RRF. If any step fails, ``answer``
        is an error message and ``fused_results`` is an empty list.
    """
    if not openai_key or not openai_key.strip():
        return "Please provide an OpenAI API key.", []
    if not query or not query.strip():
        return "Please enter a query.", []

    openai.api_key = openai_key
    # LangChain's OpenAI wrapper reads the key from the environment.
    os.environ["OPENAI_API_KEY"] = openai_key

    # Generate reformulated queries
    generated_queries = generate_query(query, query_length)
    if isinstance(generated_queries, str):
        return generated_queries, []

    all_results = []
    for g_query in generated_queries:
        documents = vector_search(g_query, top_k=5)
        if isinstance(documents, str):  # Error handling
            return documents, []
        # `documents` holds one ranked list per query text; extend (not
        # append) so RRF ranks individual documents, not whole result lists.
        all_results.extend(documents)

    # Fuse results using RRF
    fused_results = reciprocal_rank_fusion(all_results)
    if isinstance(fused_results, str):
        return fused_results, []

    # Prepare fused results for language model input
    fused_results_str = "\n".join(
        [f"Document: {result}, Score: {score}" for result, score in fused_results]
    )

    # Prepare the prompt for the chat completion API
    prompt = PromptTemplate(
        input_variables=["query", "fused_results"],
        # Adjacent literals reproduce the original template text exactly.
        template=(
            ' You are a helpful assistant that can answer questions about Indian law.'
            ' You are given a query: "{query}" and the following fused results'
            ' from a vector search: {fused_results}'
            ' These are the results from the vector search. \n'
            'Take the best result and provide a response. '
        ),
    )
    formatted_prompt = prompt.format(
        query=query,
        fused_results=fused_results_str,
    )

    try:
        # openai>=1.0 exposes chat completions at `openai.chat.completions`
        # and returns an object that is read via attributes, not subscripts.
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a helpful assistant that answers questions about Indian law."},
                {"role": "user", "content": formatted_prompt},
            ],
            max_tokens=300,
            temperature=0.7,
        )
        # `content` may be None, so fall back to an empty string.
        answer = (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"Error during answer generation: {e}", []

    return answer, fused_results


# Gradio Interface
def app(openai_key, query, query_length):
    """Gradio callback: return the answer and a printable list of fused results."""
    answer, fused_results = handle_query(openai_key, query, query_length)
    fused_results_str = "\n".join(
        [f"Document: {result}, Score: {score}" for result, score in fused_results]
    )
    return answer, fused_results_str


with gr.Blocks() as demo:
    gr.Markdown("## Indian Law Assistant")

    openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API key")
    query = gr.Textbox(label="Query", placeholder="Enter your query about Indian law")
    # step=1 keeps the slider on whole numbers of queries.
    query_length = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Reformulated Queries")

    answer_output = gr.Textbox(label="Answer", interactive=False)
    fused_results_output = gr.Textbox(label="Fused Results", interactive=False)

    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=app,
        inputs=[openai_key, query, query_length],
        outputs=[answer_output, fused_results_output],
    )

demo.launch()