# law_chatbot / app.py
# Last commit: "Update app.py" (dcb4a3f, verified) by kp0001
import gradio as gr
import chromadb
import os
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.load import dumps, loads
import openai
# Persistent Chroma store and the collection searched by vector_search().
# NOTE: get_or_create_collection is used, so the call also succeeds when the
# collection does not exist yet.
CHROMA_STORE_PATH = "indian_law_bge_work_1"
CHROMA_COLLECTION_NAME = "indian_law_bge_work_1"

client = chromadb.PersistentClient(path=CHROMA_STORE_PATH)
collection = client.get_or_create_collection(CHROMA_COLLECTION_NAME)
# Vector Search Function
def vector_search(query, top_k=5):
    """Query the module-level collection and return its ``documents`` field.

    Args:
        query: Text to search for; sent as the only entry of ``query_texts``.
        top_k: Number of results requested (``n_results``).

    Returns:
        The ``documents`` entry of the query result, unchanged
        (presumably one list of documents per query text -- confirm against
        the Chroma API), or an error string starting with
        ``"Error during vector search: "`` if anything raises.
    """
    try:
        return collection.query(query_texts=[query], n_results=top_k)['documents']
    except Exception as err:
        # Failures are reported as a string so the caller can surface them.
        return f"Error during vector search: {err}"
# Generate Query Function
def generate_query(query, query_length):
    """Ask the LLM for reformulations of ``query`` to use in vector search.

    Args:
        query: The user's original question.
        query_length: How many reformulated queries the prompt asks for.

    Returns:
        A list with the non-blank lines of the model output, each stripped of
        surrounding whitespace, or an error string starting with
        ``"Error during query generation: "`` if anything raises.
    """
    try:
        reformulation_prompt = PromptTemplate(
            input_variables=["query", "query_length"],
            template=(
                "\n"
                "You are a helpful assistant that can answer questions about Indian law.\n"
                'You are given a query: "{query}" and you need to generate {query_length} reformulated queries for vector search.\n'
            ),
        )
        chain = LLMChain(
            llm=OpenAI(temperature=0.7),
            prompt=reformulation_prompt,
            output_parser=StrOutputParser(),
        )
        raw_output = chain.run({"query": query, "query_length": query_length})
        # One candidate query per output line; blank lines are dropped.
        stripped_lines = (line.strip() for line in raw_output.split("\n"))
        return [line for line in stripped_lines if line]
    except Exception as err:
        return f"Error during query generation: {err}"
# Reciprocal Rank Fusion Function
def reciprocal_rank_fusion(results_list, k=60):
    """Fuse several ranked result lists with Reciprocal Rank Fusion.

    Each document contributes ``1 / (rank + 1 + k)`` for every list it appears
    in, where ``rank`` is its 0-based position in that list. Documents are
    keyed by their ``dumps`` serialization so equal documents accumulate into
    one score.

    Args:
        results_list: Iterable of ranked document lists.
        k: Smoothing constant added to the rank in the denominator.

    Returns:
        A list of ``(document, score)`` tuples sorted by descending score, or
        an error string starting with ``"Error during RRF: "`` if anything
        raises.
    """
    score_by_doc = {}
    try:
        for ranked_docs in results_list:
            for position, document in enumerate(ranked_docs):
                key = dumps(document)
                score_by_doc[key] = score_by_doc.get(key, 0) + 1 / (position + 1 + k)
        ordered = sorted(score_by_doc.items(), key=lambda item: item[1], reverse=True)
        return [(loads(key), total) for key, total in ordered]
    except Exception as err:
        return f"Error during RRF: {err}"
def handle_query(openai_key, query, query_length):
    """Answer ``query`` with a RAG-fusion pipeline.

    Steps: generate reformulated queries, run a vector search for each, fuse
    the ranked lists with Reciprocal Rank Fusion, then ask a chat model to
    answer from the fused results.

    Args:
        openai_key: OpenAI API key supplied by the user.
        query: The user's question.
        query_length: Number of reformulated queries to request.

    Returns:
        ``(answer, fused_results)`` where ``fused_results`` is the list of
        ``(document, score)`` tuples from RRF. On any failure the first
        element is an error message and the second is ``[]``.
    """
    # Validate before touching any external service.
    openai_key = (openai_key or "").strip()
    if not openai_key:
        return "Error: OpenAI API key is required.", []
    if not (query or "").strip():
        return "Error: query must not be empty.", []

    # NOTE: both assignments are process-wide, so concurrent users of the app
    # share whichever key was submitted last.
    openai.api_key = openai_key
    os.environ["OPENAI_API_KEY"] = openai_key

    # Generate reformulated queries
    generated_queries = generate_query(query, query_length)
    if isinstance(generated_queries, str):
        return generated_queries, []

    # Fall back to the original query if the model produced no usable lines.
    search_queries = generated_queries or [query]

    all_results = []
    for g_query in search_queries:
        documents = vector_search(g_query, top_k=5)
        if isinstance(documents, str):  # Error handling
            return documents, []
        # RRF needs one ranked list of documents per search; appending the
        # nested result as-is would rank the whole hit list as one document.
        all_results.extend(_ranked_lists(documents))

    # Fuse results using RRF
    fused_results = reciprocal_rank_fusion(all_results)
    if isinstance(fused_results, str):
        return fused_results, []

    # Prepare fused results for language model input
    fused_results_str = "\n".join(
        f"Document: {result}, Score: {score}" for result, score in fused_results
    )
    prompt = PromptTemplate(
        input_variables=["query", "fused_results"],
        template=(
            "\n"
            "You are a helpful assistant that can answer questions about Indian law.\n"
            'You are given a query: "{query}" and the following fused results from a vector search:\n'
            "{fused_results}\n"
            "These are the results from the vector search. Take the best result and provide a response.\n"
        ),
    )
    formatted_prompt = prompt.format(query=query, fused_results=fused_results_str)

    messages = [
        {"role": "system", "content": "You are a helpful assistant that answers questions about Indian law."},
        {"role": "user", "content": formatted_prompt},
    ]
    try:
        answer = _chat_answer(messages)
    except Exception as e:
        # Same convention as the other pipeline steps: report, don't crash the UI.
        return f"Error during answer generation: {e}", []
    return answer, fused_results


def _ranked_lists(documents):
    """Normalize vector-search output into a list of ranked document lists."""
    # Chroma's query() yields one ranked list per query text (a list of
    # lists); a flat list is accepted too so this does not depend on nesting.
    if documents and all(isinstance(item, list) for item in documents):
        return documents
    return [documents]


def _chat_answer(messages):
    """Request a chat completion and return the stripped reply text."""
    request = dict(model="gpt-4", messages=messages, max_tokens=300, temperature=0.7)
    if hasattr(openai, "chat"):
        # openai>=1.0: module-level client, attribute-style response object.
        response = openai.chat.completions.create(**request)
        content = response.choices[0].message.content
    else:
        # openai<1.0: legacy class, dict-style response.
        response = openai.ChatCompletion.create(**request)
        content = response["choices"][0]["message"]["content"]
    # The API may return no content; treat that as an empty answer.
    return (content or "").strip()
# Gradio Interface
def app(openai_key, query, query_length):
    """Gradio callback: run the pipeline and format its outputs as text.

    Args:
        openai_key: OpenAI API key from the UI.
        query: The user's question.
        query_length: Number of reformulated queries to generate.

    Returns:
        ``(answer, fused_results_text)`` where the second element has one
        ``"Document: ..., Score: ..."`` line per fused result.
    """
    answer, fused_results = handle_query(openai_key, query, query_length)
    rendered_lines = [
        f"Document: {result}, Score: {score}" for result, score in fused_results
    ]
    return answer, "\n".join(rendered_lines)
# Build the UI: three inputs, two read-only outputs, and a submit button
# wired to app().
with gr.Blocks() as demo:
    gr.Markdown("## Indian Law Assistant")
    # The key is a secret: mask it instead of showing it in clear text.
    openai_key = gr.Textbox(
        label="OpenAI API Key",
        placeholder="Enter your OpenAI API key",
        type="password",
    )
    query = gr.Textbox(label="Query", placeholder="Enter your query about Indian law")
    # step=1 keeps the count a whole number; without it the slider allows
    # fractional values, which make no sense as a number of queries.
    query_length = gr.Slider(
        minimum=1,
        maximum=10,
        value=3,
        step=1,
        label="Number of Reformulated Queries",
    )
    answer_output = gr.Textbox(label="Answer", interactive=False)
    fused_results_output = gr.Textbox(label="Fused Results", interactive=False)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=app,
        inputs=[openai_key, query, query_length],
        outputs=[answer_output, fused_results_output],
    )

demo.launch()