Spaces:
Sleeping
Sleeping
import gradio as gr | |
import chromadb | |
import os | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import LLMChain | |
from langchain.llms import OpenAI | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain.load import dumps, loads | |
import openai | |
# On-disk ChromaDB store; the storage directory and the collection share this name.
_CHROMA_NAME = "indian_law_bge_work_1"

# Open the persistent client rooted at that directory.
client = chromadb.PersistentClient(path=_CHROMA_NAME)

# Fetch the collection, creating an empty one if it does not exist yet.
# NOTE(review): no embedding_function is passed here, so query_texts are embedded
# with Chroma's default; if the stored vectors came from a BGE model (as the name
# suggests), confirm the query-time embedding matches the one used at index time.
collection = client.get_or_create_collection(_CHROMA_NAME)
# Vector Search Function
def vector_search(query, top_k=5):
    """Run a similarity search for a single query text against ``collection``.

    Args:
        query: The text to search for.
        top_k: Number of nearest documents to request.

    Returns:
        The ``documents`` field of Chroma's query result (per Chroma's query
        API, one ranked list of documents per query text, so a single inner
        list here), or an error-message string if the query or the lookup
        raised. Callers tell the two apart with ``isinstance(..., str)``.
    """
    try:
        response = collection.query(query_texts=[query], n_results=top_k)
        docs = response['documents']
    except Exception as err:
        return f"Error during vector search: {err}"
    return docs
# Generate Query Function | |
def generate_query(query, query_length): | |
try: | |
prompt = PromptTemplate( | |
input_variables=["query", "query_length"], | |
template=""" | |
You are a helpful assistant that can answer questions about Indian law. | |
You are given a query: "{query}" and you need to generate {query_length} reformulated queries for vector search. | |
""" | |
) | |
llm = OpenAI(temperature=0.7) | |
chain = LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser()) | |
result = chain.run({"query": query, "query_length": query_length}) | |
result = [i.strip() for i in result.split("\n") if i.strip()] | |
result = [i for i in result if i != ""] | |
return result | |
except Exception as e: | |
return f"Error during query generation: {e}" | |
# Reciprocal Rank Fusion Function
def reciprocal_rank_fusion(results_list, k=60):
    """Merge several ranked document lists with Reciprocal Rank Fusion.

    Each document earns ``1 / (position + k)`` from every list it appears in,
    where ``position`` is 1-based. Documents are keyed by their ``dumps``
    serialization, so unhashable items can be fused too; ``loads`` restores
    them on output.

    Args:
        results_list: Iterable of ranked lists of documents.
        k: Smoothing constant that damps the influence of top ranks.

    Returns:
        A list of ``(document, score)`` pairs sorted by descending score, or
        an error-message string if anything raised during fusion.
    """
    scores = {}
    try:
        for ranking in results_list:
            for position, item in enumerate(ranking, start=1):
                key = dumps(item)
                scores[key] = scores.get(key, 0) + 1 / (position + k)
        ordered = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        return [(loads(key), total) for key, total in ordered]
    except Exception as err:
        return f"Error during RRF: {err}"
def _ask_chat_model(openai_key, prompt_text):
    """Send *prompt_text* to gpt-4 and return the stripped reply text.

    Works with both the openai>=1.0 client interface and the legacy
    module-level ``openai.ChatCompletion`` interface. Exceptions from the API
    propagate to the caller.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant that answers questions about Indian law."},
        {"role": "user", "content": prompt_text},
    ]
    if hasattr(openai, "OpenAI"):
        # openai>=1.0: use an explicit per-request client. Responses are typed
        # objects, not dicts.
        chat_client = openai.OpenAI(api_key=openai_key)
        response = chat_client.chat.completions.create(
            model="gpt-4",
            messages=messages,
            max_tokens=300,
            temperature=0.7,
        )
        content = response.choices[0].message.content
    else:
        # openai<1.0: the module-level API reads openai.api_key and returns
        # dict-like objects.
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=messages,
            max_tokens=300,
            temperature=0.7,
        )
        content = response["choices"][0]["message"]["content"]
    return (content or "").strip()


def handle_query(openai_key, query, query_length):
    """Answer *query* with RAG-fusion: reformulate, retrieve, fuse, then ask gpt-4.

    Args:
        openai_key: OpenAI API key entered by the user.
        query: The user's question.
        query_length: Number of reformulated queries to generate.

    Returns:
        A ``(answer, fused_results)`` tuple. ``fused_results`` is the list of
        ``(document, score)`` pairs from RRF. On any failure, ``answer`` is an
        error or validation message. ``fused_results`` is then ``[]``, except
        when only the final chat call fails; in that case the retrieved
        documents are still returned so the user can inspect them.
    """
    if not openai_key or not openai_key.strip():
        return "Please provide an OpenAI API key.", []
    if not query or not query.strip():
        return "Please enter a query.", []
    openai_key = openai_key.strip()

    # NOTE(review): these are process-wide settings, so concurrent users of a
    # shared deployment can race on them. The LangChain OpenAI LLM used in
    # generate_query picks the key up from the environment.
    openai.api_key = openai_key
    os.environ["OPENAI_API_KEY"] = openai_key

    # Generate reformulated queries
    generated_queries = generate_query(query, query_length)
    if isinstance(generated_queries, str):
        return generated_queries, []

    all_results = []
    for g_query in generated_queries:
        documents = vector_search(g_query, top_k=5)
        if isinstance(documents, str):  # Error handling
            return documents, []
        # Chroma returns one ranked list per query text. Only one text was
        # sent, so this query's ranking is the first inner list. RRF must
        # receive that list, not the outer wrapper, or it would score whole
        # result lists instead of individual documents.
        all_results.append(documents[0] if documents else [])

    # Fuse results using RRF
    fused_results = reciprocal_rank_fusion(all_results)
    if isinstance(fused_results, str):
        return fused_results, []

    # Prepare fused results for language model input
    fused_results_str = "\n".join(
        f"Document: {result}, Score: {score}" for result, score in fused_results
    )
    prompt = PromptTemplate(
        input_variables=["query", "fused_results"],
        template="""
You are a helpful assistant that can answer questions about Indian law.
You are given a query: "{query}" and the following fused results from a vector search:
{fused_results}
These are the results from the vector search. Take the best result and provide a response.
""",
    )
    formatted_prompt = prompt.format(query=query, fused_results=fused_results_str)

    try:
        answer = _ask_chat_model(openai_key, formatted_prompt)
    except Exception as e:
        return f"Error during answer generation: {e}", fused_results
    return answer, fused_results
# Gradio Interface
def app(openai_key, query, query_length):
    """Gradio callback: run the pipeline and render the fused results as text.

    Returns:
        ``(answer, rendered)``, where ``rendered`` has one
        ``Document: ..., Score: ...`` line per fused result.
    """
    answer, fused = handle_query(openai_key, query, query_length)
    rendered_lines = []
    for result, score in fused:
        rendered_lines.append(f"Document: {result}, Score: {score}")
    return answer, "\n".join(rendered_lines)
# Build the Gradio UI: inputs for key, query, and reformulation count; outputs
# for the answer and the fused retrieval results.
with gr.Blocks() as demo:
    gr.Markdown("## Indian Law Assistant")
    # type="password" masks the secret key in the browser instead of showing it in clear text.
    openai_key = gr.Textbox(
        label="OpenAI API Key",
        placeholder="Enter your OpenAI API key",
        type="password",
    )
    query = gr.Textbox(label="Query", placeholder="Enter your query about Indian law")
    # step=1 restricts the slider to whole numbers of reformulated queries.
    query_length = gr.Slider(
        minimum=1,
        maximum=10,
        value=3,
        step=1,
        label="Number of Reformulated Queries",
    )
    answer_output = gr.Textbox(label="Answer", interactive=False)
    fused_results_output = gr.Textbox(label="Fused Results", interactive=False)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=app,
        inputs=[openai_key, query, query_length],
        outputs=[answer_output, fused_results_output],
    )

demo.launch()