import traceback

import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline

TROCR_MODELS = {
    "Printed Text": "microsoft/trocr-large-printed",
    "Handwritten": "microsoft/trocr-large-handwritten",
}

DETECTOR_MODEL_ID = "openai-community/roberta-large-openai-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")

print("Loading OCR models...")
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
    try:
        proc = TrOCRProcessor.from_pretrained(model_id)
        mdl = VisionEncoderDecoderModel.from_pretrained(model_id)
        OCR_PIPELINES[name] = (proc, mdl)
        print(f"Loaded {name} OCR model.")
    except Exception as e:
        print(f"Error loading OCR model {name} ({model_id}): {e}")

print(f"Loading AI detector model ({DETECTOR_MODEL_ID})...")
DETECTOR_PIPELINE = None
try:
    DETECTOR_PIPELINE = pipeline(
        "text-classification",
        model=DETECTOR_MODEL_ID,
        top_k=None,
    )
    print("Loaded AI detector model.")
except Exception as e:
    print(f"CRITICAL Error loading AI detector model ({DETECTOR_MODEL_ID}): {e}")
    traceback.print_exc()
    print("Exiting due to critical model loading failure.")
    raise SystemExit(1)
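
# With top_k=None, the text-classification pipeline returns scores for every
# label rather than only the top one. Depending on the transformers version
# and the input shape, the result for a single string is either a flat list
# of dicts, e.g.
#     [{"label": "LABEL_1", "score": 0.97}, {"label": "LABEL_0", "score": 0.03}]
# or that same list nested one level deeper. get_ai_and_human_scores() below
# normalizes both shapes.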


def get_ai_and_human_scores(results):
    """
    Processes detector results to get likelihood scores for both the AI and Human classes.
    Handles various label formats ('LABEL_0'/'LABEL_1', 'FAKE'/'REAL', 'AI'/'REAL').

    Returns:
        tuple: (ai_display_string, human_display_string)
    """
    ai_prob = 0.0
    human_prob = 0.0
    status_message = "Error: No results received"

    if not results:
        print("Warning: Received empty results for AI detection.")
        return status_message, "N/A"

    # Normalize the pipeline output to a flat list of {label, score} dicts.
    score_list = []
    if isinstance(results, list) and len(results) > 0:
        if isinstance(results[0], list) and len(results[0]) > 0:
            score_list = results[0]
        elif isinstance(results[0], dict):
            score_list = results
        else:
            status_message = "Error: Unexpected detector output format (inner)"
            print(f"Warning: {status_message}. Results: {results[0]}")
            return status_message, "N/A"
    else:
        status_message = "Error: Unexpected detector output format (outer)"
        print(f"Warning: {status_message}. Results: {results}")
        return status_message, "N/A"
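
    # Map labels (uppercased) to scores so the convention checks below are
    # case-insensitive.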
    lbl2score = {
        entry["label"].upper(): entry["score"]
        for entry in score_list
        if isinstance(entry, dict) and "label" in entry and "score" in entry
    }

    if not lbl2score:
        status_message = "Error: Could not parse detector scores"
        print(f"Warning: {status_message}. Score list: {score_list}")
        return status_message, "N/A"

    label_keys_found = ", ".join(lbl2score.keys())
    found_pair = False
    inferred = False

    if "LABEL_1" in lbl2score and "LABEL_0" in lbl2score:
        ai_prob = lbl2score["LABEL_1"]
        human_prob = lbl2score["LABEL_0"]
        found_pair = True
        status_message = "OK (Used LABEL_1/LABEL_0)"
    elif "FAKE" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["FAKE"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used FAKE/REAL)"
    elif "AI" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["AI"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used AI/REAL)"
    if not found_pair:
        if "LABEL_1" in lbl2score:
            ai_prob = lbl2score["LABEL_1"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from LABEL_1)"
        elif "LABEL_0" in lbl2score:
            human_prob = lbl2score["LABEL_0"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from LABEL_0)"
        elif "FAKE" in lbl2score:
            ai_prob = lbl2score["FAKE"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from FAKE)"
        elif "AI" in lbl2score:
            ai_prob = lbl2score["AI"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from AI)"
        elif "REAL" in lbl2score:
            human_prob = lbl2score["REAL"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from REAL)"
    if not found_pair and not inferred:
        status_message = f"Error: Unrecognized labels [{label_keys_found}]"
        print(f"Warning: {status_message}")

    ai_display_str = f"{ai_prob*100:.2f}%"
    human_display_str = f"{human_prob*100:.2f}%"

    # Surface parsing problems in the AI field so the UI shows what went wrong.
    if "Error:" in status_message:
        ai_display_str = status_message
        human_display_str = "N/A"

    print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}")
    return ai_display_str, human_display_str
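
# A quick sanity check with hypothetical scores illustrates the expected output:
#   get_ai_and_human_scores([[{"label": "LABEL_1", "score": 0.97},
#                             {"label": "LABEL_0", "score": 0.03}]])
#   -> ("97.00%", "3.00%")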


def analyze_image(image: Image.Image, ocr_choice: str):
    """Performs OCR and AI content detection, returning both AI and Human %."""
    extracted = ""
    ai_result_str = "N/A"
    human_result_str = "N/A"
    status_update = "Awaiting input..."

    if image is None:
        status_update = "Please upload an image first."
        return extracted, ai_result_str, human_result_str, status_update
    if not ocr_choice:
        status_update = "Please select an OCR model."
        return extracted, ai_result_str, human_result_str, status_update

    if ocr_choice not in OCR_PIPELINES:
        return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' not loaded."
    if DETECTOR_PIPELINE is None:
        return "", "N/A", "N/A", "Critical Error: AI Detector model failed to load."

    try:
        status_update = f"Processing with {ocr_choice} OCR..."
        print(status_update)
        proc, mdl = OCR_PIPELINES[ocr_choice]

        # TrOCR expects RGB input; convert other modes before preprocessing.
        if image.mode != "RGB":
            image = image.convert("RGB")
        pix = proc(images=image, return_tensors="pt").pixel_values
        tokens = mdl.generate(pix, max_length=512)
        extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0]

        if not extracted or extracted.isspace():
            status_update = "OCR completed, but no text (or only whitespace) was extracted."
            print(status_update)
            return extracted, "N/A", "N/A", status_update

        status_update = "Detecting AI/Human content..."
        print(status_update)
        # truncation=True keeps long OCR output within the detector's max length.
        results = DETECTOR_PIPELINE(extracted, truncation=True)

        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" in ai_result_str:
            status_update = "Analysis completed with detection errors."
        else:
            status_update = "Analysis complete."
        print(status_update)

        return extracted, ai_result_str, human_result_str, status_update

    except Exception as e:
        print(f"Error during image analysis: {e}")
        traceback.print_exc()
        status_update = f"An error occurred during analysis: {e}"
        return extracted, "Error", "Error", status_update


def classify_text(text: str):
    """Classifies the provided text, returning both AI and Human %."""
    if not text or text.isspace():
        return "Please enter some text.", "N/A"
    if DETECTOR_PIPELINE is None:
        return "Critical Error: AI Detector model failed to load.", "N/A"

    print("Classifying text...")
    try:
        results = DETECTOR_PIPELINE(text, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" not in ai_result_str:
            print("Classification complete.")
        else:
            print("Classification completed with errors.")

        return ai_result_str, human_result_str

    except Exception as e:
        print(f"Error during text classification: {e}")
        traceback.print_exc()
        return f"Error: {e}", "Error"


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## OCR + AI/Human Content Detection
        Upload an image or paste text. The tool extracts text via OCR (for images) and analyzes it
        using an AI content detector (`openai-community/roberta-large-openai-detector`)
        to estimate how likely the text is to be AI-generated vs. human-written.

        **Disclaimer:** AI content detection is hard and far from 100% accurate. These likelihoods
        are estimates shaped by the detector's training data and should not be treated as definitive.
        Performance varies with text type and AI generation method.

        **Label Assumption:** The detector is assumed to output LABEL_1 for AI/Fake and LABEL_0
        for Human/Real; other common label conventions are also handled.
        """
    )

    with gr.Tab("Analyze Image"):
        with gr.Row():
            with gr.Column(scale=2):
                img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            with gr.Column(scale=1):
                ocr_dd = gr.Dropdown(
                    list(TROCR_MODELS.keys()),
                    label="1. Select OCR Model",
                    info="Choose based on the type of text in the image.",
                )
                run_btn = gr.Button("2. Analyze Image", variant="primary")
                status_img = gr.Label(value="Awaiting image analysis...", label="Status")

        with gr.Row():
            text_out_img = gr.Textbox(label="Extracted Text", lines=6, interactive=False)

        # Show the AI and Human likelihoods side by side, as in the Classify Text tab.
        with gr.Row():
            with gr.Column(scale=1):
                ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)

        run_btn.click(
            fn=analyze_image,
            inputs=[img_in, ocr_dd],
            outputs=[text_out_img, ai_out_img, human_out_img, status_img],
            queue=True,
        )

    with gr.Tab("Classify Text"):
        with gr.Column():
            text_in_classify = gr.Textbox(label="Paste or type text here", lines=8)
            classify_btn = gr.Button("Classify Text", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)

        classify_btn.click(
            fn=classify_text,
            inputs=[text_in_classify],
            outputs=[ai_out_classify, human_out_classify],
            queue=True,
        )

    gr.HTML(f"<footer style='text-align:center; margin-top: 20px; color: grey;'>Powered by TrOCR & {DETECTOR_MODEL_ID}</footer>")


if __name__ == "__main__":
    print("Starting Gradio demo...")
    demo.launch(share=False, server_name="0.0.0.0")
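    # server_name="0.0.0.0" binds to all network interfaces so the app is
    # reachable from other machines; pass share=True instead for a temporary
    # public Gradio link.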