import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
from PIL import Image
import traceback

# --- Model IDs ---
TROCR_MODELS = {
    "Printed Text": "microsoft/trocr-large-printed",
    "Handwritten": "microsoft/trocr-large-handwritten",
}
DETECTOR_MODEL_ID = "openai-community/roberta-large-openai-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")

# --- Load pipelines once at startup ---
print("Loading OCR models...")
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
    try:
        proc = TrOCRProcessor.from_pretrained(model_id)
        mdl = VisionEncoderDecoderModel.from_pretrained(model_id)
        OCR_PIPELINES[name] = (proc, mdl)
        print(f"Loaded {name} OCR model.")
    except Exception as e:
        print(f"Error loading OCR model {name} ({model_id}): {e}")
print(f"Loading AI detector model ({DETECTOR_MODEL_ID})...")
try:
DETECTOR_PIPELINE = pipeline(
"text-classification",
model=DETECTOR_MODEL_ID,
top_k=None # Get scores for all classes
)
print("Loaded AI detector model.")
except Exception as e:
print(f"CRITICAL Error loading AI detector model ({DETECTOR_MODEL_ID}): {e}")
traceback.print_exc()
print("Exiting due to critical model loading failure.")
exit()
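
# With top_k=None the text-classification pipeline returns per-class scores,
# typically as a nested list with one inner list per input string, e.g.
# (illustrative):
#   [[{'label': 'Real', 'score': 0.93}, {'label': 'Fake', 'score': 0.07}]]
# get_ai_and_human_scores() below tolerates both this nested shape and a flat
# list of dicts.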

# --- Score extraction: AI vs. Human likelihood ---
def get_ai_and_human_scores(results):
    """
    Processes detector results to get likelihood scores for both the AI and Human classes.
    Handles several label formats ('LABEL_0'/'LABEL_1', 'FAKE'/'REAL', 'AI'/'REAL').

    Returns:
        tuple: (ai_display_string, human_display_string)
    """
    ai_prob = 0.0
    human_prob = 0.0
    status_message = "Error: No results received"  # Default status

    if not results:
        print("Warning: Received empty results for AI detection.")
        return status_message, "N/A"  # Return error string for both outputs

    # Handle potential nested list structure
    score_list = []
    if isinstance(results, list) and len(results) > 0:
        if isinstance(results[0], list) and len(results[0]) > 0:
            score_list = results[0]
        elif isinstance(results[0], dict):
            score_list = results
        else:
            status_message = "Error: Unexpected detector output format (inner)"
            print(f"Warning: {status_message}. Results: {results[0]}")
            return status_message, "N/A"
    else:
        status_message = "Error: Unexpected detector output format (outer)"
        print(f"Warning: {status_message}. Results: {results}")
        return status_message, "N/A"

    # Build a label -> score map (uppercased labels)
    lbl2score = {
        entry["label"].upper(): entry["score"]
        for entry in score_list
        if isinstance(entry, dict) and "label" in entry and "score" in entry
    }
    if not lbl2score:
        status_message = "Error: Could not parse detector scores"
        print(f"Warning: {status_message}. Score list: {score_list}")
        return status_message, "N/A"

    label_keys_found = ", ".join(lbl2score.keys())
    found_pair = False
    inferred = False

    # --- Determine AI and Human probabilities based on labels ---
    # ** ASSUMPTION: LABEL_1 = AI, LABEL_0 = Human -- verify against the model config! **
    if "LABEL_1" in lbl2score and "LABEL_0" in lbl2score:
        ai_prob = lbl2score["LABEL_1"]
        human_prob = lbl2score["LABEL_0"]
        found_pair = True
        status_message = "OK (Used LABEL_1/LABEL_0)"
    elif "FAKE" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["FAKE"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used FAKE/REAL)"
    elif "AI" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["AI"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used AI/REAL)"

    # If no pair was found, try inferring from a single known label
    if not found_pair:
        if "LABEL_1" in lbl2score:  # Assume LABEL_1 = AI
            ai_prob = lbl2score["LABEL_1"]
            human_prob = max(0.0, 1.0 - ai_prob)  # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_1)"
        elif "LABEL_0" in lbl2score:  # Assume LABEL_0 = Human
            human_prob = lbl2score["LABEL_0"]
            ai_prob = max(0.0, 1.0 - human_prob)  # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_0)"
        elif "FAKE" in lbl2score:
            ai_prob = lbl2score["FAKE"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from FAKE)"
        elif "AI" in lbl2score:
            ai_prob = lbl2score["AI"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from AI)"
        elif "REAL" in lbl2score:
            human_prob = lbl2score["REAL"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from REAL)"
        if not inferred:
            status_message = f"Error: Unrecognized labels [{label_keys_found}]"
            print(f"Warning: {status_message}")
            # Keep both probabilities at 0.0

    # --- Format output strings ---
    ai_display_str = f"{ai_prob * 100:.2f}%"
    human_display_str = f"{human_prob * 100:.2f}%"

    # If an error occurred, reflect it in the output strings
    if "Error:" in status_message:
        ai_display_str = status_message  # Show the error instead of a percentage
        human_display_str = "N/A"

    print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}")
    return ai_display_str, human_display_str
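
# Illustrative example of the mapping above: given
#   [{'label': 'FAKE', 'score': 0.91}, {'label': 'REAL', 'score': 0.09}]
# the function returns ("91.00%", "9.00%").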

# --- Gradio callback functions ---
def analyze_image(image: Image.Image, ocr_choice: str):
    """Performs OCR and AI content detection; returns both AI and Human %."""
    # Default return values in case of early exit
    extracted = ""
    ai_result_str = "N/A"
    human_result_str = "N/A"
    status_update = "Awaiting input..."

    if image is None:
        status_update = "Please upload an image first."
        return extracted, ai_result_str, human_result_str, status_update
    if not ocr_choice:
        status_update = "Please select an OCR model."
        return extracted, ai_result_str, human_result_str, status_update
    if ocr_choice not in OCR_PIPELINES:
        return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' not loaded."
    if DETECTOR_PIPELINE is None:
        return "", "N/A", "N/A", "Critical Error: AI Detector model failed to load."

    try:
        status_update = f"Processing with {ocr_choice} OCR..."
        print(status_update)
        proc, mdl = OCR_PIPELINES[ocr_choice]
        if image.mode != "RGB":
            image = image.convert("RGB")
        pix = proc(images=image, return_tensors="pt").pixel_values
        tokens = mdl.generate(pix, max_length=512)
        extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0]

        if not extracted or extracted.isspace():
            status_update = "OCR completed, but no text or only whitespace was extracted."
            print(status_update)
            # Return empty extracted text, N/A for scores, and the status
            return extracted, "N/A", "N/A", status_update

        status_update = "Detecting AI/Human content..."
        print(status_update)
        results = DETECTOR_PIPELINE(extracted, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        # Check whether an error message was returned
        if "Error:" in ai_result_str:
            status_update = "Analysis completed with detection errors."
        else:
            status_update = "Analysis complete."
        print(status_update)
        # Return: extracted_text, ai_%, human_%, status_message
        return extracted, ai_result_str, human_result_str, status_update
    except Exception as e:
        print(f"Error during image analysis: {e}")
        traceback.print_exc()
        status_update = f"An error occurred during analysis: {e}"
        # Return whatever was extracted so far, with error markers for the scores
        return extracted, "Error", "Error", status_update
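
# Quick local check (illustrative; "sample.png" is a placeholder path):
#   print(analyze_image(Image.open("sample.png"), "Printed Text"))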

def classify_text(text: str):
    """Classifies the provided text, returning both AI and Human %."""
    if not text or text.isspace():
        # Return an error message for AI%, N/A for Human%
        return "Please enter some text.", "N/A"
    if DETECTOR_PIPELINE is None:
        return "Critical Error: AI Detector model failed to load.", "N/A"

    print("Classifying text...")
    try:
        results = DETECTOR_PIPELINE(text, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" not in ai_result_str:
            print("Classification complete.")
        else:
            print("Classification completed with errors.")
        # Return: ai_%, human_%
        return ai_result_str, human_result_str
    except Exception as e:
        print(f"Error during text classification: {e}")
        traceback.print_exc()
        return f"Error: {e}", "Error"

# --- Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## OCR + AI/Human Content Detection

        Upload an image or paste text. The tool extracts text via OCR (for images) and analyzes it
        with an AI content detector (`openai-community/roberta-large-openai-detector`)
        to estimate the likelihood of it being AI-generated vs. human-written.

        **Disclaimer:** AI content detection is challenging and not 100% accurate. These likelihoods
        are estimates based on the detector's training data and are not definitive.
        Performance varies with text type and AI generation method.

        **Label assumption:** the app assumes the model outputs LABEL_1 for AI/Fake and LABEL_0 for Human/Real.
        """
    )
    with gr.Tab("Analyze Image"):
        with gr.Row():
            with gr.Column(scale=2):
                img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            with gr.Column(scale=1):
                ocr_dd = gr.Dropdown(
                    list(TROCR_MODELS.keys()),
                    label="1. Select OCR Model",
                    info="Choose based on the type of text in the image.",
                )
                run_btn = gr.Button("2. Analyze Image", variant="primary")
        status_img = gr.Label(value="Awaiting image analysis...", label="Status")
        with gr.Row():
            text_out_img = gr.Textbox(label="Extracted Text", lines=6, interactive=False)
            # Two output boxes, one per score
            with gr.Column(scale=1):
                ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)
        run_btn.click(
            fn=analyze_image,
            inputs=[img_in, ocr_dd],
            outputs=[text_out_img, ai_out_img, human_out_img, status_img],  # 4 outputs
            queue=True,
        )
with gr.Tab("Classify Text"):
with gr.Column():
text_in_classify = gr.Textbox(label="Paste or type text here", lines=8)
classify_btn = gr.Button("Classify Text", variant="primary")
# --- Two output boxes for scores ---
with gr.Row():
with gr.Column(scale=1):
ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
with gr.Column(scale=1):
human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)
# ---
# --- Update outputs list ---
classify_btn.click(
fn=classify_text,
inputs=[text_in_classify],
outputs=[ai_out_classify, human_out_classify], # 2 outputs now
queue=True
)
gr.HTML(f"<footer style='text-align:center; margin-top: 20px; color: grey;'>Powered by TrOCR & {DETECTOR_MODEL_ID}</footer>")

if __name__ == "__main__":
    print("Starting Gradio demo...")
    # Bind to all interfaces so the app is reachable from outside the container
    demo.launch(share=False, server_name="0.0.0.0")