from fastapi import FastAPI, UploadFile, File, Form, Request, Depends, HTTPException, APIRouter
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from io import BytesIO
import re
import docx
from pathlib import Path
from docx.enum.text import WD_COLOR_INDEX
from PyPDF2 import PdfReader
from pydantic import BaseModel
from supabase import create_client, Client
from dotenv import load_dotenv
import uuid
import stripe
import subprocess
import tempfile
import os
import shlex
# NOTE: Ai_rewriter.rewriter_fixed is invoked as a subprocess below, so its
# rewrite_text is not imported here; importing it would also be shadowed by
# the rewrite_text endpoint of the same name.

load_dotenv()
# === CONFIG ===
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET_REPO = "AlyanAkram/StealthReports"
# Browsers send the Origin header without a trailing slash, so the allowed
# origin must not have one either.
CORS_ORIGINS = ["http://localhost:5173", "https://stealth-writer.vercel.app"]

stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
# Never log the secret key itself; only confirm that it was loaded.
print("Stripe key loaded:", bool(stripe.api_key))

supabase: Client | None = None

PRICE_MAP = {
    "basic": "price_1RyxK4KiaPeHFPzzwBG5C5Rf",
    "premium": "price_1RyxKBKiaPeHFPzz5oDy6m2c",
}
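
# Environment variables this app reads (all referenced elsewhere in this file):
#   HF_TOKEN              - Hugging Face token used to upload report files
#   STRIPE_SECRET_KEY     - Stripe API key
#   STRIPE_WEBHOOK_SECRET - used to verify Stripe webhook signatures
#   SUPABASE_URL / SUPABASE_KEY - Supabase project credentials
#   REWRITTEN_OUTPUTS_DIR - optional override for the rewrite output directory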
# === FastAPI app setup ===
app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load model on startup ===
analyze_text = None
generate_pdf_report = None

@app.on_event("startup")  # decorator restored so this actually runs at startup
async def load_model():
    global analyze_text, generate_pdf_report, supabase
    # Imported lazily so the heavy detector model does not load at import time.
    from detector.custom_model import analyze_text as at, generate_pdf_report as gpr
    analyze_text = at
    generate_pdf_report = gpr
    supabase = create_client(
        os.getenv("SUPABASE_URL"),
        os.getenv("SUPABASE_KEY")
    )
# === Utils ===
def extract_text(file: UploadFile, ext: str) -> str:
    content = file.file.read()
    file_bytes = BytesIO(content)
    if ext == ".txt":
        return content.decode("utf-8", errors="ignore")
    elif ext == ".pdf":
        reader = PdfReader(file_bytes)
        return "".join(page.extract_text() or "" for page in reader.pages)
    elif ext == ".docx":
        doc = docx.Document(file_bytes)
        return "\n".join(para.text for para in doc.paragraphs)
    else:
        raise ValueError("Unsupported file type")

def sanitize_filename(name):
    return re.sub(r"[^\w\-_.]", "_", name)

def upload_to_dataset(path: str, content: BytesIO, token: str) -> str:
    api = HfApi()
    api.upload_file(
        path_or_fileobj=content,
        path_in_repo=path,
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        token=token,
    )
    return f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{path}"
# === Main endpoint ===
@app.post("/detect")  # decorator restored; route path assumed
async def detect(file: UploadFile = File(...)):
    try:
        if analyze_text is None or generate_pdf_report is None:
            raise HTTPException(status_code=503, detail="Model not loaded yet")
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")
        text = extract_text(file, ext)
        result = analyze_text(text)
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]

        # Build a DOCX report with AI-flagged sentences highlighted.
        docx_buffer = BytesIO()
        doc = docx.Document()
        doc.add_heading("AI Detection Summary", level=1)
        doc.add_paragraph(f"Overall AI %: {result['overall_ai_percent']}%")
        doc.add_paragraph(f"Total Sentences: {result['total_sentences']}")
        doc.add_paragraph(f"AI Sentences: {result['ai_sentences']}")
        doc.add_paragraph("Sentences detected as AI are highlighted in cyan.\n")
        doc.add_heading("Sentence Analysis", level=2)
        paragraph = doc.add_paragraph()
        for para in result["results"]:
            for sentence, is_ai, _ in para:
                if not isinstance(sentence, str) or not sentence.strip():
                    continue
                run = paragraph.add_run(sentence + " ")
                if is_ai:
                    run.font.highlight_color = WD_COLOR_INDEX.TURQUOISE
        doc.save(docx_buffer)
        docx_buffer.seek(0)

        pdf_buffer = generate_pdf_report(result, filename_base)
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)

        return {
            "success": True,
            "score": {
                **{k: v for k, v in result.items() if k != "results"},
                "results": [
                    [{"sentence": s, "is_ai": is_ai, "ai_score": round(ai_score * 100, 2)} for s, is_ai, ai_score in para]
                    for para in result["results"]
                ]
            },
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
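
# Example call (illustrative; assumes the route path above and a local server
# listening on port 8000):
#   curl -F "file=@essay.docx" http://localhost:8000/detect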
router = APIRouter(prefix="/api/payments", tags=["payments"])

class CheckoutReq(BaseModel):
    plan: str
    success_url: str
    cancel_url: str
    email: str
    user_id: str

@router.post("/create-session")  # decorator restored; route path assumed
async def create_session(data: CheckoutReq):
    if data.plan not in PRICE_MAP:
        raise HTTPException(400, "Unknown plan")
    session = stripe.checkout.Session.create(
        mode="subscription",
        payment_method_types=["card"],
        line_items=[{"price": PRICE_MAP[data.plan], "quantity": 1}],
        customer_email=data.email,
        client_reference_id=data.user_id,  # store the Supabase user.id
        success_url=data.success_url + "?session_id={CHECKOUT_SESSION_ID}",
        cancel_url=data.cancel_url,
    )
    return {"url": session.url}
@router.post("/webhook")  # decorator restored; route path assumed
async def stripe_webhook(request: Request):
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except stripe.error.SignatureVerificationError:
        return JSONResponse(status_code=400, content={"error": "Invalid signature"})

    if event["type"] == "checkout.session.completed":
        session = event["data"]["object"]
        user_id = session.get("client_reference_id")  # Supabase user.id
        subscription_id = session.get("subscription")
        try:
            subscription = stripe.Subscription.retrieve(subscription_id)
            price_id = subscription["items"]["data"][0]["price"]["id"]
            plan = next((k for k, v in PRICE_MAP.items() if v == price_id), None)
            if plan and user_id and supabase:
                print(f"Updating plan to {plan} for Supabase user {user_id}")
                # Capture the RPC result so it can actually be logged (the
                # original printed an undefined `response` variable).
                response = supabase.rpc("update_user_plan", {
                    "uid": user_id,
                    "new_plan": plan
                }).execute()
                print("Supabase update response:", response)
        except Exception as e:
            print("Webhook error while updating Supabase:", str(e))
    return {"status": "success"}
# Use environment variable if running on Hugging Face Spaces
output_dir_env = os.environ.get("REWRITTEN_OUTPUTS_DIR", "rewritten_outputs")
OUTPUT_DIR = Path(output_dir_env)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Request models for the rewrite endpoints
class TextInput(BaseModel):
    text: str

async def extract_text_from_file(file: UploadFile) -> str:
    """Validate the extension, then delegate to extract_text (which reads the file synchronously)."""
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in [".txt", ".pdf", ".docx"]:
        raise ValueError("Unsupported file format")
    return extract_text(file, ext)

class RewriteRequest(BaseModel):
    text: str
@app.post("/rewrite-file")  # decorator restored; route path assumed
async def rewrite_endpoint(file: UploadFile = File(...)):
    temp_input = None
    temp_output = None
    try:
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")

        # 1. Save input file to temp
        temp_input = Path(tempfile.gettempdir()) / f"input_{uuid.uuid4().hex}{ext}"
        with open(temp_input, "wb") as f:
            f.write(await file.read())

        # 2. Temp output file
        temp_output = Path(tempfile.gettempdir()) / f"rewritten_{uuid.uuid4().hex}{ext}"

        # 3. Run rewriter_fixed.py on the file
        subprocess.run(
            ["python", "-X", "utf8", "Ai_rewriter/rewriter_fixed.py", str(temp_input), str(temp_output)],
            check=True
        )
        if not temp_output.exists():
            raise HTTPException(status_code=500, detail="Rewriter did not produce an output file.")

        # 4. Read the rewritten text back (docx and PyPDF2 are imported at the top)
        rewritten_text = ""
        if ext == ".txt":
            rewritten_text = temp_output.read_text(encoding="utf-8")
        elif ext == ".docx":
            doc = docx.Document(temp_output)
            rewritten_text = "\n".join(p.text for p in doc.paragraphs)
        elif ext == ".pdf":
            reader = PdfReader(str(temp_output))
            rewritten_text = "".join(page.extract_text() or "" for page in reader.pages)

        # Create DOCX
        docx_buffer = BytesIO()
        doc = docx.Document()
        doc.add_paragraph(rewritten_text)
        doc.save(docx_buffer)
        docx_buffer.seek(0)

        # Create PDF (note: this simple renderer neither wraps long lines nor
        # starts new pages, so very long texts may overflow the page)
        from reportlab.lib.pagesizes import A4
        from reportlab.pdfgen import canvas
        from reportlab.lib.units import inch
        pdf_buffer = BytesIO()
        c = canvas.Canvas(pdf_buffer, pagesize=A4)
        width, height = A4
        text_object = c.beginText(0.5 * inch, height - 0.5 * inch)
        text_object.setFont("Times-Roman", 12)
        for line in rewritten_text.split("\n"):
            text_object.textLine(line)
        c.drawText(text_object)
        c.showPage()
        c.save()
        pdf_buffer.seek(0)

        # 5. Upload to the Hugging Face dataset, like /detect
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)

        return {
            "success": True,
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=f"Rewriter process failed: {e}")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # 6. Cleanup temp files even when the request fails
        if temp_input is not None:
            temp_input.unlink(missing_ok=True)
        if temp_output is not None:
            temp_output.unlink(missing_ok=True)
@app.post("/rewrite")  # decorator restored; route path assumed
async def rewrite_text(req: RewriteRequest):
    temp_in_path = None
    temp_out_path = None
    try:
        # Always set file paths before subprocess to avoid reference errors
        temp_in = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
        temp_in.write(req.text)
        temp_in.close()
        temp_in_path = temp_in.name
        temp_out_path = Path(tempfile.gettempdir()) / f"rewritten_{os.path.basename(temp_in_path)}"

        # Prepare the CLI invocation, forcing UTF-8. The raw text is passed as
        # an argv argument (the temp input file above is not read by the
        # script); very long texts could exceed OS argument-length limits.
        cmd = [
            "python", "-X", "utf8",
            "Ai_rewriter/rewriter_fixed.py",
            req.text,  # pass raw text directly
            str(temp_out_path)
        ]
        print("Running:", shlex.join(cmd))

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            shell=False
        )
        print("STDOUT:\n", result.stdout)
        print("STDERR:\n", result.stderr)

        if result.returncode != 0:
            raise HTTPException(
                status_code=500,
                detail=f"Model process failed: {result.stderr or 'Unknown error'}"
            )
        if not temp_out_path.exists():
            raise HTTPException(status_code=500, detail="Output file not found.")

        rewritten_text = temp_out_path.read_text(encoding="utf-8")
        return {"rewritten_text": rewritten_text}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Cleanup safely
        if temp_in_path and os.path.exists(temp_in_path):
            os.remove(temp_in_path)
        if temp_out_path and os.path.exists(temp_out_path):
            os.remove(temp_out_path)
app.include_router(router)
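
# Minimal local-run sketch (assumptions: uvicorn is installed; 7860 is the
# conventional Hugging Face Spaces port, but any free port works locally):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)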