# StealthWriter — main.py (FastAPI backend, deployed as a Hugging Face Space)
# NOTE(review): the original lines here were Hugging Face web-page scrape
# residue ("AlyanAkram's picture / Update main.py / 0874acd verified") and
# were not valid Python; converted to this comment header.
from fastapi import FastAPI, UploadFile, File, Request, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from io import BytesIO
import re
import docx
from pathlib import Path
from docx.enum.text import WD_COLOR_INDEX
from PyPDF2 import PdfReader
# NOTE(review): the next two lines re-import names already brought in above
# (FastAPI, UploadFile, File, JSONResponse); harmless duplicates — safe to
# remove along with the apparently unused Depends, Form and FileResponse.
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
from Ai_rewriter.rewriter_fixed import rewrite_text
import uuid
import stripe
from pydantic import BaseModel
from supabase import create_client, Client
from dotenv import load_dotenv
import subprocess
import tempfile
import os
import shlex
# Populate os.environ from a local .env file BEFORE any config is read below.
load_dotenv()
# === CONFIG ===
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET_REPO = "AlyanAkram/StealthReports"
# CORS origins are matched exactly against the browser's Origin header, which
# never carries a trailing slash — "https://stealth-writer.vercel.app/" would
# silently block the deployed frontend. (BUG FIX: trailing slash removed.)
CORS_ORIGINS = ["http://localhost:5173", "https://stealth-writer.vercel.app"]
stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
# BUG FIX: the old log line printed the raw secret key (credential leak) and
# crashed with TypeError (len(None)) whenever STRIPE_SECRET_KEY was unset.
if stripe.api_key:
    print(f"πŸ”‘ Stripe key loaded (…{stripe.api_key[-4:]}, {len(stripe.api_key)} chars)")
else:
    print("⚠️ STRIPE_SECRET_KEY is not set — payment endpoints will fail")
# Supabase client is created lazily in the startup hook.
supabase: Client | None = None
# Maps internal plan names to Stripe Price IDs used at checkout.
PRICE_MAP = {
    "basic": "price_1RyxK4KiaPeHFPzzwBG5C5Rf",
    "premium": "price_1RyxKBKiaPeHFPzz5oDy6m2c",
}
# === FastAPI app setup ===
app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")

# Let the local dev server and the deployed frontend call this API from the
# browser, with credentials and any method/header allowed.
_cors_options = {
    "allow_origins": CORS_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# === Load model on startup ===
# Bound at startup so the heavy detector import doesn't block module import.
analyze_text = None
generate_pdf_report = None


@app.on_event("startup")
async def load_model():
    """Import the detector lazily and open the Supabase connection."""
    global analyze_text, generate_pdf_report, supabase
    from detector import custom_model as _detector
    analyze_text = _detector.analyze_text
    generate_pdf_report = _detector.generate_pdf_report
    supabase = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_KEY"))
# === Utils ===
def extract_text(file: UploadFile, ext: str) -> str:
    """Return the plain-text contents of an uploaded .txt/.pdf/.docx file.

    Raises ValueError for any other extension.
    """
    raw = file.file.read()
    if ext == ".txt":
        return raw.decode("utf-8", errors="ignore")
    if ext == ".pdf":
        # extract_text() may return None for image-only pages; treat as empty.
        pages = PdfReader(BytesIO(raw)).pages
        return "".join(page.extract_text() or "" for page in pages)
    if ext == ".docx":
        document = docx.Document(BytesIO(raw))
        return "\n".join(p.text for p in document.paragraphs)
    raise ValueError("Unsupported file type")
def sanitize_filename(name):
    """Make *name* filesystem/URL-safe: anything outside [A-Za-z0-9_.-] becomes '_'."""
    unsafe = re.compile(r"[^\w\-_.]")
    return unsafe.sub("_", name)
def upload_to_dataset(path: str, content: BytesIO, token: str) -> str:
    """Push a file-like object to the HF dataset repo and return its public URL."""
    HfApi().upload_file(
        path_or_fileobj=content,
        path_in_repo=path,
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        token=token,
    )
    # The "resolve/main" URL serves the raw file straight from the dataset.
    return f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{path}"
# === Main endpoint ===
@app.post("/api/detect")
async def detect(file: UploadFile = File(...)):
    """Analyse an uploaded .txt/.pdf/.docx for AI-written sentences.

    Builds a highlighted DOCX summary plus a PDF report, uploads both to the
    Hugging Face dataset repo, and returns the score breakdown with the URLs.
    On any failure, responds 500 with {"success": False, "error": ...}.
    """
    try:
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")
        text = extract_text(file, ext)
        # analyze_text is bound during the startup event; result carries
        # overall stats plus per-paragraph (sentence, is_ai, ai_score) tuples.
        result = analyze_text(text)
        # Unique base name so concurrent requests never collide in the dataset.
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]
        docx_buffer = BytesIO()
        doc = docx.Document()
        doc.add_heading("AI Detection Summary", level=1)
        doc.add_paragraph(f"Overall AI %: {result['overall_ai_percent']}%")
        doc.add_paragraph(f"Total Sentences: {result['total_sentences']}")
        doc.add_paragraph(f"AI Sentences: {result['ai_sentences']}")
        doc.add_paragraph("Sentences detected as AI are highlighted in cyan.\n")
        doc.add_heading("Sentence Analysis", level=2)
        paragraph = doc.add_paragraph()
        for para in result["results"]:
            for sentence, is_ai, _ in para:
                # Skip empty/non-string artefacts from sentence splitting.
                if not isinstance(sentence, str) or not sentence.strip():
                    continue
                run = paragraph.add_run(sentence + " ")
                if is_ai:
                    # Cyan highlight marks AI-flagged sentences in the DOCX.
                    run.font.highlight_color = WD_COLOR_INDEX.TURQUOISE
        doc.save(docx_buffer)
        docx_buffer.seek(0)
        pdf_buffer = generate_pdf_report(result, filename_base)
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)
        return {
            "success": True,
            "score": {
                # Copy every top-level stat, then replace raw tuples with
                # JSON-friendly dicts (ai_score scaled to a 0-100 percentage).
                **{k: v for k, v in result.items() if k != "results"},
                "results": [
                    [{"sentence": s, "is_ai": is_ai, "ai_score": round(ai_score * 100, 2)} for s, is_ai, ai_score in para]
                    for para in result["results"]
                ]
            },
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except Exception as e:
        # Any failure (bad format, model error, upload error) → 500 JSON payload.
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
# Payment routes live on their own router; it is mounted at the bottom of the file.
router = APIRouter(prefix="/api/payments", tags=["payments"])


class CheckoutReq(BaseModel):
    # Request body for POST /api/payments/create-session.
    plan: str         # key into PRICE_MAP ("basic" | "premium")
    success_url: str  # frontend URL Stripe redirects to on success
    cancel_url: str   # frontend URL Stripe redirects to on cancel
    email: str        # pre-fills the Stripe checkout form
    user_id: str      # Supabase user.id, echoed back via client_reference_id
@router.post("/create-session")
async def create_session(data: CheckoutReq):
    """Create a Stripe Checkout session for the requested subscription plan."""
    if data.plan not in PRICE_MAP:
        raise HTTPException(400, "Unknown plan")
    checkout = stripe.checkout.Session.create(
        mode="subscription",
        payment_method_types=["card"],
        line_items=[{"price": PRICE_MAP[data.plan], "quantity": 1}],
        customer_email=data.email,
        client_reference_id=data.user_id,  # Supabase user.id, recovered in the webhook
        # Stripe substitutes the literal {CHECKOUT_SESSION_ID} placeholder itself.
        success_url=f"{data.success_url}?session_id={{CHECKOUT_SESSION_ID}}",
        cancel_url=data.cancel_url,
    )
    return {"url": checkout.url}
@router.post("/webhook")
async def stripe_webhook(request: Request):
    """Handle Stripe webhook events.

    On checkout.session.completed, resolves the purchased plan from the
    subscription's price and persists it to Supabase via the
    update_user_plan RPC. Always returns 200 so Stripe stops retrying.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except (ValueError, stripe.error.SignatureVerificationError):
        # ValueError: malformed payload; SignatureVerificationError: bad/missing
        # signature. Previously a malformed payload escaped as an unhandled 500.
        return JSONResponse(status_code=400, content={"error": "Invalid signature"})
    if event["type"] == "checkout.session.completed":
        session = event["data"]["object"]
        user_id = session.get("client_reference_id")  # Supabase user.id set at checkout
        subscription_id = session.get("subscription")
        try:
            subscription = stripe.Subscription.retrieve(subscription_id)
            price_id = subscription["items"]["data"][0]["price"]["id"]
            # Map the Stripe price back to our internal plan name.
            plan = next((k for k, v in PRICE_MAP.items() if v == price_id), None)
            if plan and user_id and supabase:
                print(f"Updating plan to {plan} for Supabase user {user_id}")
                # BUG FIX: the RPC result was never bound — the old code printed
                # an undefined name `response`, raising NameError and logging
                # every successful update as a webhook error.
                response = supabase.rpc("update_user_plan", {
                    "uid": user_id,
                    "new_plan": plan
                }).execute()
                print("βœ… Supabase update response:", response)
        except Exception as e:
            # Best-effort: still acknowledge the event so Stripe doesn't retry forever.
            print("Webhook error while updating Supabase:", str(e))
    return {"status": "success"}
# Use environment variable if running on Hugging Face Spaces
# (Spaces only allow writes under specific paths). The directory is created
# eagerly so later writes never fail on a missing parent.
output_dir_env = os.environ.get("REWRITTEN_OUTPUTS_DIR", "rewritten_outputs")
OUTPUT_DIR = Path(output_dir_env)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Define the TextInput model
class TextInput(BaseModel):
    # Generic single-text request body.
    # NOTE(review): not referenced by any route visible in this file —
    # confirm against other modules before removing.
    text: str
async def extract_text_from_file(file: UploadFile) -> str:
    """Async-signature wrapper around extract_text with extension validation.

    Raises ValueError for anything other than .txt/.pdf/.docx.
    """
    suffix = os.path.splitext(file.filename)[1].lower()
    if suffix in (".txt", ".pdf", ".docx"):
        return extract_text(file, suffix)
    raise ValueError("Unsupported file format")
class RewriteRequest(BaseModel):
    # Raw text to be rewritten by POST /rewrite-text.
    text: str
@app.post("/rewrite")
async def rewrite_endpoint(file: UploadFile = File(...)):
    """Rewrite an uploaded .txt/.pdf/.docx via the rewriter subprocess.

    The rewritten text is rendered to both DOCX and PDF, uploaded to the
    Hugging Face dataset repo, and the two public URLs are returned.
    Raises HTTPException(500) on any failure.
    """
    import sys  # local import: only needed to locate the current interpreter
    temp_input = None
    temp_output = None
    try:
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")
        # 1. Save the upload to a temp file for the rewriter subprocess.
        temp_input = Path(tempfile.gettempdir()) / f"input_{uuid.uuid4().hex}{ext}"
        temp_input.write_bytes(await file.read())
        # 2. Temp output path (same extension as the input).
        temp_output = Path(tempfile.gettempdir()) / f"rewritten_{uuid.uuid4().hex}{ext}"
        # 3. Run the rewriter with the SAME interpreter serving this app —
        #    a bare "python" may resolve to a different environment on PATH.
        subprocess.run(
            [sys.executable, "-X", "utf8", "Ai_rewriter/rewriter_fixed.py", str(temp_input), str(temp_output)],
            check=True,
        )
        if not temp_output.exists():
            raise HTTPException(status_code=500, detail="Rewriter did not produce an output file.")
        # 4. Pull the rewritten text back out of whichever format was produced.
        #    (docx and PdfReader are already imported at module top; the old
        #    local re-imports were redundant.)
        if ext == ".txt":
            rewritten_text = temp_output.read_text(encoding="utf-8")
        elif ext == ".docx":
            rewritten_text = "\n".join(p.text for p in docx.Document(temp_output).paragraphs)
        else:  # ".pdf"
            reader = PdfReader(str(temp_output))
            rewritten_text = "".join(page.extract_text() or "" for page in reader.pages)
        # Build a DOCX rendition in memory.
        docx_buffer = BytesIO()
        doc = docx.Document()
        doc.add_paragraph(rewritten_text)
        doc.save(docx_buffer)
        docx_buffer.seek(0)
        # Build a PDF rendition in memory (reportlab imported lazily,
        # matching the original's pattern of deferring this dependency).
        from reportlab.lib.pagesizes import A4
        from reportlab.pdfgen import canvas
        from reportlab.lib.units import inch
        pdf_buffer = BytesIO()
        c = canvas.Canvas(pdf_buffer, pagesize=A4)
        _, height = A4
        text_object = c.beginText(0.5 * inch, height - 0.5 * inch)
        text_object.setFont("Times-Roman", 12)
        for line in rewritten_text.split("\n"):
            text_object.textLine(line)
        c.drawText(text_object)
        c.showPage()
        c.save()
        pdf_buffer.seek(0)
        # 5. Upload both renditions, mirroring /api/detect.
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)
        return {
            "success": True,
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=f"Rewriter process failed: {e}")
    except HTTPException:
        raise  # keep the original status/detail instead of re-wrapping
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # BUG FIX: temp files previously leaked on every failure path —
        # cleanup now runs unconditionally.
        if temp_input is not None:
            temp_input.unlink(missing_ok=True)
        if temp_output is not None:
            temp_output.unlink(missing_ok=True)
@app.post("/rewrite-text")
async def rewrite_text(req: RewriteRequest):
    """Rewrite raw text by shelling out to the rewriter script.

    Returns {"rewritten_text": ...}; raises HTTPException(500) on failure.

    NOTE(review): this endpoint shadows the `rewrite_text` imported from
    Ai_rewriter.rewriter_fixed at module top; rename one of them if the
    imported function is ever called directly in this file.
    """
    import sys  # local import: run the rewriter with this same interpreter
    temp_out_path = Path(tempfile.gettempdir()) / f"rewritten_{uuid.uuid4().hex}.txt"
    try:
        # Cleanup of the old flow: the input no longer round-trips through a
        # temp file the subprocess never read — it only existed to derive the
        # output name, which a uuid provides directly.
        cmd = [
            sys.executable, "-X", "utf8",  # sys.executable, not bare "python"
            "Ai_rewriter/rewriter_fixed.py",
            req.text,  # raw text passed directly as an argv element
            str(temp_out_path)
        ]
        print("πŸš€ Running:", shlex.join(cmd))
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            shell=False  # list argv, no shell: immune to injection via req.text
        )
        print("πŸ”Ή STDOUT:\n", result.stdout)
        print("πŸ”Ή STDERR:\n", result.stderr)
        if result.returncode != 0:
            raise HTTPException(
                status_code=500,
                detail=f"Model process failed: {result.stderr or 'Unknown error'}"
            )
        if not temp_out_path.exists():
            raise HTTPException(status_code=500, detail="Output file not found.")
        rewritten_text = temp_out_path.read_text(encoding="utf-8")
        return {"rewritten_text": rewritten_text}
    except HTTPException:
        # BUG FIX: HTTPExceptions raised above were previously swallowed by
        # the broad handler and re-wrapped with a mangled detail string.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the output file, success or failure.
        if temp_out_path.exists():
            temp_out_path.unlink()
# Mount the payments router once all its routes are defined.
app.include_router(router)