import streamlit as st import os import tempfile import uuid from langchain_groq import ChatGroq from langchain.prompts import ChatPromptTemplate from langchain.schema import HumanMessage, AIMessage from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import Chroma from langchain.chains import RetrievalQA import re # Custom CSS Injection def inject_custom_css(): st.markdown(""" """, unsafe_allow_html=True) # Page Configuration st.set_page_config( page_title="AI Law Agent", page_icon="⚖️", layout="centered", initial_sidebar_state="expanded" ) # Constants DEFAULT_GROQ_API_KEY = os.getenv("GROQ_API_KEY") MODEL_NAME = "llama3-70b-8192" DEFAULT_DOCUMENT_PATH = "lawbook.pdf" DEFAULT_COLLECTION_NAME = "pakistan_laws_default" CHROMA_PERSIST_DIR = "./chroma_db" # Session state initialization if "messages" not in st.session_state: st.session_state.messages = [] if "user_id" not in st.session_state: st.session_state.user_id = str(uuid.uuid4()) if "vectordb" not in st.session_state: st.session_state.vectordb = None if "llm" not in st.session_state: st.session_state.llm = None if "qa_chain" not in st.session_state: st.session_state.qa_chain = None if "similar_questions" not in st.session_state: st.session_state.similar_questions = [] if "using_custom_docs" not in st.session_state: st.session_state.using_custom_docs = False if "custom_collection_name" not in st.session_state: st.session_state.custom_collection_name = f"custom_laws_{st.session_state.user_id}" def setup_embeddings(): """Sets up embeddings model""" return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") def setup_llm(): """Setup the language model""" if st.session_state.llm is None: st.session_state.llm = ChatGroq( model_name=MODEL_NAME, groq_api_key=DEFAULT_GROQ_API_KEY, temperature=0.2 ) return st.session_state.llm def check_default_db_exists(): """Check if the default document database already exists""" return os.path.exists(os.path.join(CHROMA_PERSIST_DIR, DEFAULT_COLLECTION_NAME)) def load_existing_vectordb(collection_name): """Load an existing vector database from disk""" embeddings = setup_embeddings() try: return Chroma( persist_directory=CHROMA_PERSIST_DIR, embedding_function=embeddings, collection_name=collection_name ) except Exception as e: st.error(f"Error loading existing database: {str(e)}") return None def process_default_document(force_rebuild=False): """Process the default Pakistan laws document""" if check_default_db_exists() and not force_rebuild: st.info("Loading existing Pakistan law database...") db = load_existing_vectordb(DEFAULT_COLLECTION_NAME) if db: st.session_state.vectordb = db setup_qa_chain() st.session_state.using_custom_docs = False return True if not os.path.exists(DEFAULT_DOCUMENT_PATH): st.error(f"Default document {DEFAULT_DOCUMENT_PATH} not found.") return False try: with st.spinner("Building Pakistan law database..."): loader = PyPDFLoader(DEFAULT_DOCUMENT_PATH) documents = loader.load() for doc in documents: doc.metadata["source"] = "Pakistan Laws (Official)" text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=200 ) chunks = text_splitter.split_documents(documents) db = Chroma.from_documents( documents=chunks, embedding=setup_embeddings(), collection_name=DEFAULT_COLLECTION_NAME, persist_directory=CHROMA_PERSIST_DIR ) db.persist() st.session_state.vectordb = db setup_qa_chain() st.session_state.using_custom_docs = False return True except Exception as e: st.error(f"Error processing default document: {str(e)}") return False def process_custom_documents(uploaded_files): """Process user-uploaded PDF documents""" embeddings = setup_embeddings() collection_name = st.session_state.custom_collection_name documents = [] for uploaded_file in uploaded_files: with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file: tmp_file.write(uploaded_file.getvalue()) tmp_path = tmp_file.name try: loader = PyPDFLoader(tmp_path) file_docs = loader.load() for doc in file_docs: doc.metadata["source"] = uploaded_file.name documents.extend(file_docs) os.unlink(tmp_path) except Exception as e: st.error(f"Error processing {uploaded_file.name}: {str(e)}") if documents: text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=200 ) chunks = text_splitter.split_documents(documents) with st.spinner("Building custom document database..."): if Chroma( persist_directory=CHROMA_PERSIST_DIR, embedding_function=embeddings, collection_name=collection_name ).get(): Chroma( persist_directory=CHROMA_PERSIST_DIR, embedding_function=embeddings, collection_name=collection_name ).delete_collection() db = Chroma.from_documents( documents=chunks, embedding=embeddings, collection_name=collection_name, persist_directory=CHROMA_PERSIST_DIR ) db.persist() st.session_state.vectordb = db setup_qa_chain() st.session_state.using_custom_docs = True return True return False def setup_qa_chain(): """Set up the QA chain with the RAG system""" if st.session_state.vectordb: template = """You are a helpful legal assistant specializing in Pakistani law. Use the context to answer. If unsure, say so but provide general info. Context: {context} Question: {question} Answer:""" st.session_state.qa_chain = RetrievalQA.from_chain_type( llm=setup_llm(), chain_type="stuff", retriever=st.session_state.vectordb.as_retriever(search_kwargs={"k": 3}), chain_type_kwargs={"prompt": ChatPromptTemplate.from_template(template)}, return_source_documents=True ) def generate_similar_questions(question, docs): """Generate similar questions based on retrieved documents""" llm = setup_llm() context = "\n".join([doc.page_content for doc in docs[:2]]) prompt = f"""Generate 3 similar questions based on: Original Question: {question} Legal Context: {context} Generate exactly 3 similar questions:""" try: response = llm.invoke(prompt) questions = re.findall(r"\d+\.\s+(.*?)(?=\d+\.|$)", response.content, re.DOTALL) return [q.strip() for q in questions[:3] if "?" in q] except Exception as e: return [] def get_answer(question): """Get answer from QA chain""" if not st.session_state.vectordb: with st.spinner("Loading Pakistan law database..."): process_default_document() if st.session_state.qa_chain: result = st.session_state.qa_chain({"query": question}) answer = result["result"] sources = set() for doc in result.get("source_documents", []): if "source" in doc.metadata: sources.add(doc.metadata["source"]) if sources: answer += f"\n\nSources: {', '.join(sources)}" st.session_state.similar_questions = generate_similar_questions( question, result.get("source_documents", []) ) return answer return "Initializing knowledge base..." def main(): inject_custom_css() # CSS injection added here st.title("Pakistan Law AI Agent ⚖️") if st.session_state.using_custom_docs: st.subheader("Training on your personal resources") else: st.subheader("Powered by Pakistan law database") with st.sidebar: st.header("Resource Management") if st.session_state.using_custom_docs: if st.button("Return to Official Database"): with st.spinner("Loading official database..."): process_default_document() st.session_state.messages.append(AIMessage(content="Switched to official database!")) st.rerun() if not st.session_state.using_custom_docs: if st.button("Rebuild Official Database"): with st.spinner("Rebuilding..."): process_default_document(force_rebuild=True) st.rerun() st.header("Upload Custom Documents") uploaded_files = st.file_uploader( "Upload PDFs", type=["pdf"], accept_multiple_files=True) if st.button("Train on Uploaded Documents") and uploaded_files: with st.spinner("Processing..."): if process_custom_documents(uploaded_files): st.session_state.messages.append(AIMessage(content="Custom documents loaded!")) st.rerun() for message in st.session_state.messages: if isinstance(message, HumanMessage): with st.chat_message("user"): st.write(message.content) else: with st.chat_message("assistant", avatar="⚖️"): st.write(message.content) if st.session_state.similar_questions: st.markdown("#### Related Questions:") cols = st.columns(len(st.session_state.similar_questions)) for i, q in enumerate(st.session_state.similar_questions): if cols[i].button(q, key=f"similar_q_{i}"): st.session_state.messages.extend([ HumanMessage(content=q), AIMessage(content=get_answer(q)) ]) st.rerun() if user_input := st.chat_input("Ask a legal question..."): st.session_state.messages.append(HumanMessage(content=user_input)) with st.chat_message("user"): st.write(user_input) with st.chat_message("assistant", avatar="⚖️"): with st.spinner("Thinking..."): response = get_answer(user_input) st.write(response) st.session_state.messages.append(AIMessage(content=response)) st.rerun() if __name__ == "__main__": main()