# main.py
import torch
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
import io
import librosa
import re
import uuid
from pydantic import BaseModel
import sqlite3
from passlib.context import CryptContext
import os

# --- Password Hashing Setup ---
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- Database Setup ---
DATABASE_NAME = os.getenv("DATABASE_PATH", "lexisynth.db")

def get_db_connection():
    # Yield the connection so FastAPI (via Depends) closes it after each request.
    conn = sqlite3.connect(DATABASE_NAME)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    finally:
        conn.close()

def create_user_table():
    # The 'with' block commits the transaction automatically on success
    # (note that sqlite3's context manager does not close the connection).
    with sqlite3.connect(DATABASE_NAME) as conn:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                username TEXT NOT NULL UNIQUE,
                email TEXT NOT NULL UNIQUE,
                hashed_password TEXT NOT NULL
            )
        ''')
        conn.commit()

# Create the table when the app starts
create_user_table()

# --- Pydantic Models for API ---
class UserCreate(BaseModel):
    name: str
    username: str
    email: str
    password: str

class UserLogin(BaseModel):
    username: str
    password: str

class QuestionRequest(BaseModel):
    question: str
    context: str

class TextAnalysisRequest(BaseModel):
    filename: str
    text: str

# --- Security Functions ---
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)
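
# Illustrative usage (the hash shown is a placeholder; bcrypt output differs on
# every call because of the random salt):
#   hashed = get_password_hash("s3cret")   # -> "$2b$12$..."
#   verify_password("s3cret", hashed)      # -> True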

# --- Initialize the FastAPI app ---
app = FastAPI(
    title="Lexi-Synth API",
    description="An AI-powered legal analysis tool with user authentication."
)

# --- Add CORS Middleware ---
# Note: browsers send only scheme://host[:port] as the Origin header, so the
# path-style huggingface.co/spaces/... entries below will never match; the
# *.hf.space entries are the ones that actually take effect.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # For local development
        "http://localhost:8501",
        # Hugging Face Spaces URLs for your FRONTEND
        "https://yogeshbawankar03-lexiscribe-frontend.hf.space",
        "https://huggingface.co/spaces/yogeshbawankar03/lexiscribe-frontend",
        # Hugging Face Spaces URLs for your BACKEND
        "https://yogeshbawankar03-lexiscribe-backend.hf.space",
        "https://huggingface.co/spaces/yogeshbawankar03/lexiscribe-backend"
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Load AI Models ---
print("Loading all AI models...")
device_id = 0 if torch.cuda.is_available() else -1  # 0 = first GPU, -1 = CPU
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device=device_id)
ner_pipeline = pipeline("ner", model="dslim/bert-base-NER", aggregation_strategy="simple", device=device_id)
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", return_timestamps="word", device=device_id)
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=device_id)
print("All models loaded.")
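# The four models are fetched from the Hugging Face Hub on first run and cached
# locally, so the first startup is slow; later startups load from the cache.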

# --- Helper Functions ---
def find_legal_citations(text: str):
    citation_pattern = re.compile(r'((?:[A-Z][a-zA-Z\s]+)+Act\s+§\s+[\d\(\),\s]+)|(\d+\s+C\.F\.R\.\s+§\s+[\d\.]+)')
    matches = citation_pattern.findall(text)
    # findall() returns one (group1, group2) tuple per match, with exactly one
    # group non-empty, so joining each tuple yields the matched citation.
    return ["".join(match) for match in matches]
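
# Illustrative behavior of the pattern above (sample strings, not from real
# documents): the first alternative matches act-style citations such as
# "Uniform Trade Secrets Act § 1(4)", the second matches CFR citations such as
# "16 C.F.R. § 444.1"; trailing whitespace may be captured along with a match.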

def fetch_citation_text(citation: str):
    # This is a mock lookup; replace it with a real citation service if needed.
    if "16 C.F.R. § 444.1" in citation:
        return "Mock text for 16 C.F.R. § 444.1..."
    if "Uniform Trade Secrets Act" in citation:
        return "Mock text for the Uniform Trade Secrets Act..."
    return f"Full text for '{citation}' could not be retrieved."

def format_transcription(chunks):
    # Each chunk is one word with a (start, end) timestamp from the ASR pipeline.
    full_text = ""
    for chunk in chunks:
        start_str = f"[{chunk['timestamp'][0]:07.3f}]"
        full_text += f"{start_str} {chunk['text']}"
    return full_text.strip()

def analyze_text(text: str):
    # Split into ~450-word chunks so each stays within the models' input limits.
    max_chunk_length = 450
    words = text.split()
    chunks = [" ".join(words[i:i + max_chunk_length]) for i in range(0, len(words), max_chunk_length)]
    if not chunks:
        chunks.append(text)
    summaries = summarizer(chunks, max_length=100, min_length=20, do_sample=False)
    summary_text = " ".join([s['summary_text'] for s in summaries])
    ner_results_list = ner_pipeline(chunks)
    ner_results = [entity for sublist in ner_results_list for entity in sublist]
    # De-duplicate entities on (word, entity_group) while preserving order.
    seen_entities = set()
    cleaned_entities = []
    for entity in ner_results:
        entity_tuple = (entity['word'], entity['entity_group'])
        if entity_tuple not in seen_entities:
            cleaned_entities.append({"entity_group": entity["entity_group"], "score": float(entity["score"]), "word": entity["word"]})
            seen_entities.add(entity_tuple)
    found_citations = find_legal_citations(text)
    citations_with_text = [{"citation": cit, "text": fetch_citation_text(cit)} for cit in found_citations]
    return {"analysis_id": str(uuid.uuid4()), "summary": summary_text, "entities": cleaned_entities, "citations": citations_with_text, "original_text": text}
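
# Illustrative result shape (values are examples):
#   {"analysis_id": "a1b2...", "summary": "...",
#    "entities": [{"entity_group": "ORG", "score": 0.99, "word": "Acme Corp"}],
#    "citations": [{"citation": "...", "text": "..."}],
#    "original_text": "..."}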

# --- API Endpoints ---
# The route decorators were missing from the source; the paths used below are
# reasonable assumptions and may differ from the original deployment.
@app.get("/")
def read_root():
    return {"message": "Welcome to Lexi-Synth. The API is running."}

@app.post("/register")  # path assumed
def register_user(user: UserCreate, conn: sqlite3.Connection = Depends(get_db_connection)):
    hashed_password = get_password_hash(user.password)
    try:
        conn.execute(
            "INSERT INTO users (name, username, email, hashed_password) VALUES (?, ?, ?, ?)",
            (user.name, user.username, user.email, hashed_password),
        )
        conn.commit()
    except sqlite3.IntegrityError:
        # The UNIQUE constraints on username/email surface as IntegrityError.
        raise HTTPException(status_code=400, detail="Username or email already registered")
    return {"message": "User registered successfully"}

@app.post("/login")  # path assumed
def login_user(user: UserLogin, conn: sqlite3.Connection = Depends(get_db_connection)):
    db_user = conn.execute("SELECT * FROM users WHERE username = ?", (user.username,)).fetchone()
    if not db_user or not verify_password(user.password, db_user["hashed_password"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"message": "Login successful", "user": {"username": db_user["username"], "name": db_user["name"]}}

@app.post("/analyze-text")  # path assumed
async def analyze_text_from_json(request: TextAnalysisRequest):
    analysis_results = analyze_text(request.text)
    return {"filename": request.filename, **analysis_results}

@app.post("/analyze-audio")  # path assumed
async def analyze_audio_file(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    # Resample to 16 kHz mono, which is what the Whisper pipeline expects.
    audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)
    transcription_result = asr_pipeline({"sampling_rate": 16000, "raw": audio})
    analysis_results = analyze_text(transcription_result['text'])
    # Replace the plain text with the word-level timestamped transcript.
    analysis_results['original_text'] = format_transcription(transcription_result['chunks'])
    return {"filename": file.filename, **analysis_results}

@app.post("/ask")  # path assumed
async def answer_question(request: QuestionRequest):
    result = qa_pipeline(question=request.question, context=request.context)
    return {"answer": result['answer'], "score": result['score']}
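
# Minimal local entrypoint sketch. Assumes uvicorn is installed; on Hugging Face
# Spaces the server is usually started by the Space's launch command instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the default Spaces port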