# main.py
import io
import os
import re
import sqlite3
import uuid

import librosa
import torch
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from passlib.context import CryptContext
from pydantic import BaseModel
from transformers import pipeline

# --- Password Hashing Setup ---
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- Database Setup ---
DATABASE_NAME = os.getenv("DATABASE_PATH", "lexisynth.db")


def get_db_connection():
    # Yield the connection so FastAPI closes it after each request finishes.
    conn = sqlite3.connect(DATABASE_NAME)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    finally:
        conn.close()


def create_user_table():
    # sqlite3's context manager commits the transaction on success
    # (note: it does not close the connection).
    with sqlite3.connect(DATABASE_NAME) as conn:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                username TEXT NOT NULL UNIQUE,
                email TEXT NOT NULL UNIQUE,
                hashed_password TEXT NOT NULL
            )
        ''')
        conn.commit()


# Create the table when the app starts
create_user_table()


# --- Pydantic Models for API ---
class UserCreate(BaseModel):
    name: str
    username: str
    email: str
    password: str


class UserLogin(BaseModel):
    username: str
    password: str


class QuestionRequest(BaseModel):
    question: str
    context: str


class TextAnalysisRequest(BaseModel):
    filename: str
    text: str


# --- Security Functions ---
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    return pwd_context.hash(password)


# --- Initialize the FastAPI app ---
app = FastAPI(
    title="Lexi-Synth API",
    description="An AI-powered legal analysis tool with user authentication.",
)

# --- Add CORS Middleware ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # For local development
        "http://localhost:8501",
        # Hugging Face Spaces URLs for your FRONTEND
        "https://yogeshbawankar03-lexiscribe-frontend.hf.space",
        "https://huggingface.co/spaces/yogeshbawankar03/lexiscribe-frontend",
        # Hugging Face Spaces URLs for your BACKEND
        "https://yogeshbawankar03-lexiscribe-backend.hf.space",
        "https://huggingface.co/spaces/yogeshbawankar03/lexiscribe-backend",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Load AI Models ---
print("Loading all AI models...")
device_id = 0 if torch.cuda.is_available() else -1
summarizer = pipeline(
    "summarization", model="sshleifer/distilbart-cnn-6-6", device=device_id
)
ner_pipeline = pipeline(
    "ner", model="dslim/bert-base-NER", aggregation_strategy="simple", device=device_id
)
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    return_timestamps="word",
    device=device_id,
)
qa_pipeline = pipeline(
    "question-answering",
    model="distilbert-base-cased-distilled-squad",
    device=device_id,
)
print("All models loaded.")


# --- Helper Functions ---
def find_legal_citations(text: str):
    # Match "<Name> Act § <n>" or "<nn> C.F.R. § <nn.nn>" style citations.
    citation_pattern = re.compile(
        r'((?:[A-Z][a-zA-Z\s]+)+Act\s+§\s+[\d\(\),\s]+)|(\d+\s+C\.F\.R\.\s+§\s+[\d\.]+)'
    )
    matches = citation_pattern.findall(text)
    # findall returns one tuple per match with a slot per alternative;
    # joining keeps whichever alternative actually matched.
    return ["".join(match) for match in matches]


def fetch_citation_text(citation: str):
    # This is a mock function; replace it with a real lookup if needed.
    if "16 C.F.R. § 444.1" in citation:
        return "Mock text for 16 C.F.R. § 444.1..."
    if "Uniform Trade Secrets Act" in citation:
        return "Mock text for the Uniform Trade Secrets Act..."
    return f"Full text for '{citation}' could not be retrieved."
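

# Illustrative sanity check for the two citation helpers above; the sample
# sentence is hypothetical, not from any real filing:
# >>> find_legal_citations("The rule in 16 C.F.R. § 444.1 controls here.")
# ['16 C.F.R. § 444.1']
# >>> fetch_citation_text("16 C.F.R. § 444.1")
# 'Mock text for 16 C.F.R. § 444.1...'
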
def format_transcription(chunks):
    # Render each word-level chunk as "[start_seconds] word"; stripping and
    # joining with a space keeps entries from running together.
    parts = []
    for chunk in chunks:
        start_str = f"[{chunk['timestamp'][0]:07.3f}]"
        parts.append(f"{start_str} {chunk['text'].strip()}")
    return " ".join(parts)


def analyze_text(text: str):
    # Split the text into ~450-word chunks so each fits the models' input limits.
    max_chunk_length = 450
    words = text.split()
    chunks = [
        " ".join(words[i:i + max_chunk_length])
        for i in range(0, len(words), max_chunk_length)
    ]
    if not chunks:
        chunks.append(text)

    summaries = summarizer(chunks, max_length=100, min_length=20, do_sample=False)
    summary_text = " ".join([s['summary_text'] for s in summaries])

    # Run NER per chunk, then flatten and de-duplicate the entities.
    ner_results_list = ner_pipeline(chunks)
    ner_results = [entity for sublist in ner_results_list for entity in sublist]
    seen_entities = set()
    cleaned_entities = []
    for entity in ner_results:
        entity_tuple = (entity['word'], entity['entity_group'])
        if entity_tuple not in seen_entities:
            cleaned_entities.append({
                "entity_group": entity["entity_group"],
                # Cast to float so the score is JSON-serializable.
                "score": float(entity["score"]),
                "word": entity["word"],
            })
            seen_entities.add(entity_tuple)

    found_citations = find_legal_citations(text)
    citations_with_text = [
        {"citation": cit, "text": fetch_citation_text(cit)} for cit in found_citations
    ]

    return {
        "analysis_id": str(uuid.uuid4()),
        "summary": summary_text,
        "entities": cleaned_entities,
        "citations": citations_with_text,
        "original_text": text,
    }


# --- API Endpoints ---
@app.get("/")
def read_root():
    return {"message": "Welcome to Lexi-Synth. The API is running."}


@app.post("/register/")
def register_user(user: UserCreate, conn: sqlite3.Connection = Depends(get_db_connection)):
    hashed_password = get_password_hash(user.password)
    try:
        conn.execute(
            "INSERT INTO users (name, username, email, hashed_password) VALUES (?, ?, ?, ?)",
            (user.name, user.username, user.email, hashed_password),
        )
        conn.commit()
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="Username or email already registered")
    return {"message": "User registered successfully"}


@app.post("/login/")
def login_user(user: UserLogin, conn: sqlite3.Connection = Depends(get_db_connection)):
    db_user = conn.execute(
        "SELECT * FROM users WHERE username = ?", (user.username,)
    ).fetchone()
    if not db_user or not verify_password(user.password, db_user["hashed_password"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {
        "message": "Login successful",
        "user": {"username": db_user["username"], "name": db_user["name"]},
    }


@app.post("/analyze-text/")
async def analyze_text_from_json(request: TextAnalysisRequest):
    analysis_results = analyze_text(request.text)
    return {"filename": request.filename, **analysis_results}


@app.post("/analyze-audio/")
async def analyze_audio_file(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    # Decode and resample to 16 kHz mono, the rate Whisper expects.
    audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)
    transcription_result = asr_pipeline({"sampling_rate": 16000, "raw": audio})
    analysis_results = analyze_text(transcription_result['text'])
    # Replace the plain transcript with the word-timestamped version.
    analysis_results['original_text'] = format_transcription(transcription_result['chunks'])
    return {"filename": file.filename, **analysis_results}


@app.post("/answer-question/")
async def answer_question(request: QuestionRequest):
    result = qa_pipeline(question=request.question, context=request.context)
    return {"answer": result['answer'], "score": result['score']}
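

# Local dev entry point: a minimal sketch, assuming uvicorn is installed and
# port 7860 (the Hugging Face Spaces default) is free. A Spaces deployment may
# launch the app through its own start command instead; adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)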