import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
from PIL import Image
import traceback
# --- Model IDs ---
TROCR_MODELS = {
    "Printed Text": "microsoft/trocr-large-printed",
    "Handwritten": "microsoft/trocr-large-handwritten",
}
DETECTOR_MODEL_ID = "openai-community/roberta-large-openai-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")
# --- Load pipelines once at startup ---
print("Loading OCR models...")
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
    try:
        proc = TrOCRProcessor.from_pretrained(model_id)
        mdl = VisionEncoderDecoderModel.from_pretrained(model_id)
        OCR_PIPELINES[name] = (proc, mdl)
        print(f"Loaded {name} OCR model.")
    except Exception as e:
        print(f"Error loading OCR model {name} ({model_id}): {e}")
print(f"Loading AI detector model ({DETECTOR_MODEL_ID})...")
try:
    DETECTOR_PIPELINE = pipeline(
        "text-classification",
        model=DETECTOR_MODEL_ID,
        top_k=None,  # Return scores for all classes, not just the top one
    )
    print("Loaded AI detector model.")
except Exception as e:
    print(f"CRITICAL Error loading AI detector model ({DETECTOR_MODEL_ID}): {e}")
    traceback.print_exc()
    print("Exiting due to critical model loading failure.")
    raise SystemExit(1)
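# Label names differ between detector checkpoints, so it is worth inspecting the
# mapping once at startup rather than assuming it. id2label is a standard
# transformers config attribute:
print(f"Detector label mapping: {DETECTOR_PIPELINE.model.config.id2label}")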
# --- Helper: derive both AI and Human scores from detector output ---
def get_ai_and_human_scores(results):
    """
    Processes detector results into likelihood scores for both the AI and Human classes.
    Handles various label formats ('LABEL_0'/'LABEL_1', 'FAKE'/'REAL', 'AI'/'REAL').

    Returns:
        tuple: (ai_display_string, human_display_string)
    """
    ai_prob = 0.0
    human_prob = 0.0
    status_message = "Error: No results received"  # Default status

    if not results:
        print("Warning: Received empty results for AI detection.")
        return status_message, "N/A"  # Return error string for both outputs

    # Handle the potentially nested list structure returned by the pipeline
    score_list = []
    if isinstance(results, list) and len(results) > 0:
        if isinstance(results[0], list) and len(results[0]) > 0:
            score_list = results[0]
        elif isinstance(results[0], dict):
            score_list = results
        else:
            status_message = "Error: Unexpected detector output format (inner)"
            print(f"Warning: {status_message}. Results: {results[0]}")
            return status_message, "N/A"
    else:
        status_message = "Error: Unexpected detector output format (outer)"
        print(f"Warning: {status_message}. Results: {results}")
        return status_message, "N/A"

    # Build a label -> score map (uppercased labels)
    lbl2score = {
        entry["label"].upper(): entry["score"]
        for entry in score_list
        if isinstance(entry, dict) and "label" in entry and "score" in entry
    }
    if not lbl2score:
        status_message = "Error: Could not parse detector scores"
        print(f"Warning: {status_message}. Score list: {score_list}")
        return status_message, "N/A"

    label_keys_found = ", ".join(lbl2score.keys())
    found_pair = False
    inferred = False

    # --- Determine AI and Human probabilities from the labels ---
    # ** ASSUMPTION: LABEL_1 = AI, LABEL_0 = Human -- verify against the model card! **
    if "LABEL_1" in lbl2score and "LABEL_0" in lbl2score:
        ai_prob = lbl2score["LABEL_1"]
        human_prob = lbl2score["LABEL_0"]
        found_pair = True
        status_message = "OK (Used LABEL_1/LABEL_0)"
    elif "FAKE" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["FAKE"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used FAKE/REAL)"
    elif "AI" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["AI"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used AI/REAL)"

    # If no pair was found, try inferring from a single known label
    if not found_pair:
        if "LABEL_1" in lbl2score:  # Assume LABEL_1 = AI
            ai_prob = lbl2score["LABEL_1"]
            human_prob = max(0.0, 1.0 - ai_prob)  # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_1)"
        elif "LABEL_0" in lbl2score:  # Assume LABEL_0 = Human
            human_prob = lbl2score["LABEL_0"]
            ai_prob = max(0.0, 1.0 - human_prob)  # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_0)"
        elif "FAKE" in lbl2score:
            ai_prob = lbl2score["FAKE"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from FAKE)"
        elif "AI" in lbl2score:
            ai_prob = lbl2score["AI"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from AI)"
        elif "REAL" in lbl2score:
            human_prob = lbl2score["REAL"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from REAL)"

        if not inferred:
            status_message = f"Error: Unrecognized labels [{label_keys_found}]"
            print(f"Warning: {status_message}")
            # Probabilities stay at 0.0

    # --- Format output strings ---
    ai_display_str = f"{ai_prob * 100:.2f}%"
    human_display_str = f"{human_prob * 100:.2f}%"

    # If an error occurred, surface it in the output strings
    if "Error:" in status_message:
        ai_display_str = status_message  # Show the error instead of a percentage
        human_display_str = "N/A"

    print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}")
    return ai_display_str, human_display_str
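# For reference, with top_k=None the text-classification pipeline typically
# returns one list of {'label', 'score'} dicts per input, e.g. (illustrative
# scores only):
#   [[{'label': 'LABEL_1', 'score': 0.97}, {'label': 'LABEL_0', 'score': 0.03}]]
# which is why the helper above unwraps results[0] when it sees a nested list.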
# --- Analysis functions ---
def analyze_image(image: Image.Image, ocr_choice: str):
    """Performs OCR and AI content detection; returns both AI and Human %."""
    # Default return values in case of early exit
    extracted = ""
    ai_result_str = "N/A"
    human_result_str = "N/A"
    status_update = "Awaiting input..."

    if image is None:
        status_update = "Please upload an image first."
        return extracted, ai_result_str, human_result_str, status_update
    if not ocr_choice:
        status_update = "Please select an OCR model."
        return extracted, ai_result_str, human_result_str, status_update
    if ocr_choice not in OCR_PIPELINES:
        return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' not loaded."
    if DETECTOR_PIPELINE is None:
        return "", "N/A", "N/A", "Critical Error: AI Detector model failed to load."

    try:
        status_update = f"Processing with {ocr_choice} OCR..."
        print(status_update)
        proc, mdl = OCR_PIPELINES[ocr_choice]
        if image.mode != "RGB":
            image = image.convert("RGB")
        pix = proc(images=image, return_tensors="pt").pixel_values
        tokens = mdl.generate(pix, max_length=512)
        extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0]

        if not extracted or extracted.isspace():
            status_update = "OCR completed, but no text or only whitespace was extracted."
            print(status_update)
            return extracted, "N/A", "N/A", status_update

        status_update = "Detecting AI/Human content..."
        print(status_update)
        results = DETECTOR_PIPELINE(extracted, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)

        if "Error:" in ai_result_str:
            status_update = "Analysis completed with detection errors."
        else:
            status_update = "Analysis complete."
        print(status_update)
        # Returns: extracted_text, ai_%, human_%, status_message
        return extracted, ai_result_str, human_result_str, status_update
    except Exception as e:
        print(f"Error during image analysis: {e}")
        traceback.print_exc()
        status_update = f"An error occurred during analysis: {e}"
        # Return whatever was extracted so far, with error markers for the scores
        return extracted, "Error", "Error", status_update
def classify_text(text: str):
    """Classifies provided text, returning both AI and Human %."""
    if not text or text.isspace():
        # Return a prompt in the AI% slot, N/A for Human%
        return "Please enter some text.", "N/A"
    if DETECTOR_PIPELINE is None:
        return "Critical Error: AI Detector model failed to load.", "N/A"

    print("Classifying text...")
    try:
        results = DETECTOR_PIPELINE(text, truncation=True)
        ai_result_str, human_result_str = get_ai_and_human_scores(results)
        if "Error:" not in ai_result_str:
            print("Classification complete.")
        else:
            print("Classification completed with errors.")
        # Returns: ai_%, human_%
        return ai_result_str, human_result_str
    except Exception as e:
        print(f"Error during text classification: {e}")
        traceback.print_exc()
        return f"Error: {e}", "Error"
# --- Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## OCR + AI/Human Content Detection
        Upload an image or paste text. The tool extracts text via OCR (for images) and analyzes it
        with an AI content detector (`openai-community/roberta-large-openai-detector`)
        to estimate how likely the text is to be AI-generated vs. human-written.

        **Disclaimer:** AI content detection is hard and far from 100% accurate. These likelihoods
        are estimates shaped by the detector's training data and should not be treated as
        definitive; performance varies with text type and generation method.

        **Label assumption:** the app assumes the model outputs LABEL_1 for AI/Fake and
        LABEL_0 for Human/Real; verify this against the model card.
        """
    )
    with gr.Tab("Analyze Image"):
        with gr.Row():
            with gr.Column(scale=2):
                img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            with gr.Column(scale=1):
                ocr_dd = gr.Dropdown(
                    list(TROCR_MODELS.keys()),
                    label="1. Select OCR Model",
                    info="Choose based on the type of text in the image.",
                )
                run_btn = gr.Button("2. Analyze Image", variant="primary")
                status_img = gr.Label(value="Awaiting image analysis...", label="Status")
        with gr.Row():
            text_out_img = gr.Textbox(label="Extracted Text", lines=6, interactive=False)
            # Two output boxes for the scores
            with gr.Column(scale=1):
                ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)
        run_btn.click(
            fn=analyze_image,
            inputs=[img_in, ocr_dd],
            outputs=[text_out_img, ai_out_img, human_out_img, status_img],
            queue=True,
        )
    with gr.Tab("Classify Text"):
        with gr.Column():
            text_in_classify = gr.Textbox(label="Paste or type text here", lines=8)
            classify_btn = gr.Button("Classify Text", variant="primary")
            # Two output boxes for the scores
            with gr.Row():
                with gr.Column(scale=1):
                    ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
                with gr.Column(scale=1):
                    human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)
        classify_btn.click(
            fn=classify_text,
            inputs=[text_in_classify],
            outputs=[ai_out_classify, human_out_classify],
            queue=True,
        )
    gr.HTML(
        f"<footer style='text-align:center; margin-top: 20px; color: grey;'>"
        f"Powered by TrOCR &amp; {DETECTOR_MODEL_ID}</footer>"
    )
if __name__ == "__main__":
    print("Starting Gradio demo...")
    demo.launch(share=False, server_name="0.0.0.0")