File size: 12,799 Bytes
27d711f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f750907
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
from PIL import Image
import traceback
import warnings
import json # For debugging output

# --- Model IDs ---
# OCR model choices shown in the UI dropdown, keyed by display name.
TROCR_MODELS = {
    "Printed Text": "microsoft/trocr-large-printed",
    "Handwritten":  "microsoft/trocr-large-handwritten",
}
# Text classifier used to score extracted text as AI- vs human-written.
DETECTOR_MODEL_ID = "openai-community/roberta-large-openai-detector"
print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}")

# --- Load pipelines once at startup ---
# Models are loaded eagerly at import time so each request only runs inference.
print("Loading OCR models...")
# Maps display name -> (processor, model). A model that fails to load is simply
# absent from this dict; analyze_image() checks membership before use.
OCR_PIPELINES = {}
for name, model_id in TROCR_MODELS.items():
    try:
        proc = TrOCRProcessor.from_pretrained(model_id)
        mdl  = VisionEncoderDecoderModel.from_pretrained(model_id)
        OCR_PIPELINES[name] = (proc, mdl)
        print(f"Loaded {name} OCR model.")
    except Exception as e:
        # Non-fatal: the app still starts with whatever OCR models did load.
        print(f"Error loading OCR model {name} ({model_id}): {e}")

print(f"Loading AI detector model ({DETECTOR_MODEL_ID})...")
try:
    DETECTOR_PIPELINE = pipeline(
        "text-classification",
        model=DETECTOR_MODEL_ID,
        top_k=None # Get scores for all classes
    )
    print("Loaded AI detector model.")
except Exception as e:
    # Fatal: without the detector the app's core feature is unusable.
    # NOTE(review): this uses bare exit(); sys.exit(1) would be the conventional
    # choice for a non-zero exit status — confirm before changing.
    # NOTE(review): callers elsewhere test `DETECTOR_PIPELINE is None`, but on
    # failure this name is never bound (the process exits first), so those
    # checks would NameError if exit() were ever removed.
    print(f"CRITICAL Error loading AI detector model ({DETECTOR_MODEL_ID}): {e}")
    traceback.print_exc()
    print("Exiting due to critical model loading failure.")
    exit()

# --- Updated function to get both AI and Human scores ---
def get_ai_and_human_scores(results):
    """
    Processes detector results to get likelihood scores for both AI and Human classes.
    Handles various label formats ('LABEL_0'/'LABEL_1', 'FAKE'/'REAL', 'AI'/'REAL').
    Returns:
        tuple: (ai_display_string, human_display_string)
    """
    ai_prob = 0.0
    human_prob = 0.0
    status_message = "Error: No results received" # Default status

    if not results:
        print("Warning: Received empty results for AI detection.")
        return status_message, "N/A" # Return error string for both outputs

    # Handle potential nested list structure
    score_list = []
    if isinstance(results, list) and len(results) > 0:
        if isinstance(results[0], list) and len(results[0]) > 0:
            score_list = results[0]
        elif isinstance(results[0], dict):
            score_list = results
        else:
            status_message = "Error: Unexpected detector output format (inner)"
            print(f"Warning: {status_message}. Results: {results[0]}")
            return status_message, "N/A"
    else:
        status_message = "Error: Unexpected detector output format (outer)"
        print(f"Warning: {status_message}. Results: {results}")
        return status_message, "N/A"

    # Build label→score map (uppercase labels)
    lbl2score = {
        entry["label"].upper(): entry["score"]
        for entry in score_list
        if isinstance(entry, dict) and "label" in entry and "score" in entry
    }

    if not lbl2score:
        status_message = "Error: Could not parse detector scores"
        print(f"Warning: {status_message}. Score list: {score_list}")
        return status_message, "N/A"

    label_keys_found = ", ".join(lbl2score.keys())
    found_pair = False
    inferred = False

    # --- Determine AI and Human probabilities based on labels ---
    # ** ASSUMPTION: LABEL_1=AI, LABEL_0=Human - VERIFY THIS! **
    if "LABEL_1" in lbl2score and "LABEL_0" in lbl2score:
        ai_prob = lbl2score["LABEL_1"]
        human_prob = lbl2score["LABEL_0"]
        found_pair = True
        status_message = "OK (Used LABEL_1/LABEL_0)"
    elif "FAKE" in lbl2score and "REAL" in lbl2score:
        ai_prob = lbl2score["FAKE"]
        human_prob = lbl2score["REAL"]
        found_pair = True
        status_message = "OK (Used FAKE/REAL)"
    elif "AI" in lbl2score and "REAL" in lbl2score:
         ai_prob = lbl2score["AI"]
         human_prob = lbl2score["REAL"]
         found_pair = True
         status_message = "OK (Used AI/REAL)"

    # If pair not found, try inferring from single known labels
    if not found_pair:
        if "LABEL_1" in lbl2score: # Assume LABEL_1 = AI
            ai_prob = lbl2score["LABEL_1"]
            human_prob = max(0.0, 1.0 - ai_prob) # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_1)"
        elif "LABEL_0" in lbl2score: # Assume LABEL_0 = Human
            human_prob = lbl2score["LABEL_0"]
            ai_prob = max(0.0, 1.0 - human_prob) # Ensure non-negative
            inferred = True
            status_message = "OK (Inferred from LABEL_0)"
        elif "FAKE" in lbl2score:
            ai_prob = lbl2score["FAKE"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from FAKE)"
        elif "AI" in lbl2score:
            ai_prob = lbl2score["AI"]
            human_prob = max(0.0, 1.0 - ai_prob)
            inferred = True
            status_message = "OK (Inferred from AI)"
        elif "REAL" in lbl2score:
            human_prob = lbl2score["REAL"]
            ai_prob = max(0.0, 1.0 - human_prob)
            inferred = True
            status_message = "OK (Inferred from REAL)"

        if not inferred:
             status_message = f"Error: Unrecognized labels [{label_keys_found}]"
             print(f"Warning: {status_message}")
             # Keep probs at 0.0

    # --- Format output strings ---
    ai_display_str = f"{ai_prob*100:.2f}%"
    human_display_str = f"{human_prob*100:.2f}%"

    # If an error occurred, reflect it in the output strings
    if "Error:" in status_message:
         ai_display_str = status_message # Show error instead of percentage
         human_display_str = "N/A"

    print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}") # Log detail
    return ai_display_str, human_display_str

# --- Update calling functions ---

def analyze_image(image: Image.Image, ocr_choice: str):
    """Run OCR on *image* with the chosen model, then score the text.

    Returns a 4-tuple: (extracted_text, ai_percent_str, human_percent_str,
    status_message), matching the four Gradio output components.
    """
    # Guard clauses: bail out early with a status message for unusable input.
    if image is None:
        return "", "N/A", "N/A", "Please upload an image first."
    if not ocr_choice:
        return "", "N/A", "N/A", "Please select an OCR model."
    if ocr_choice not in OCR_PIPELINES:
        return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' not loaded."
    if DETECTOR_PIPELINE is None:
        return "", "N/A", "N/A", "Critical Error: AI Detector model failed to load."

    extracted = ""
    try:
        status = f"Processing with {ocr_choice} OCR..."
        print(status)
        processor, model = OCR_PIPELINES[ocr_choice]
        # TrOCR expects RGB input; convert only when necessary.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        pixel_values = processor(images=rgb_image, return_tensors="pt").pixel_values
        token_ids = model.generate(pixel_values, max_length=512)
        extracted = processor.batch_decode(token_ids, skip_special_tokens=True)[0]

        # Nothing to classify if OCR produced no usable text.
        if not extracted or extracted.isspace():
            status = "OCR completed, but no text or only whitespace was extracted."
            print(status)
            return extracted, "N/A", "N/A", status

        status = "Detecting AI/Human content..."
        print(status)
        detector_output = DETECTOR_PIPELINE(extracted, truncation=True)
        ai_str, human_str = get_ai_and_human_scores(detector_output)

        status = ("Analysis completed with detection errors."
                  if "Error:" in ai_str
                  else "Analysis complete.")
        print(status)
        return extracted, ai_str, human_str, status

    except Exception as e:
        print(f"Error during image analysis: {e}")
        traceback.print_exc()
        # Preserve whatever text was extracted before the failure.
        return extracted, "Error", "Error", f"An error occurred during analysis: {e}"


def classify_text(text: str):
    """Classify *text* as AI- vs human-written.

    Returns a 2-tuple of display strings: (ai_percent_str, human_percent_str).
    """
    # Guard clauses for unusable input / unavailable detector.
    if not text or text.isspace():
        return "Please enter some text.", "N/A"
    if DETECTOR_PIPELINE is None:
        return "Critical Error: AI Detector model failed to load.", "N/A"

    print("Classifying text...")
    try:
        detector_output = DETECTOR_PIPELINE(text, truncation=True)
        ai_str, human_str = get_ai_and_human_scores(detector_output)

        print("Classification completed with errors."
              if "Error:" in ai_str
              else "Classification complete.")
        return ai_str, human_str

    except Exception as e:
        print(f"Error during text classification: {e}")
        traceback.print_exc()
        return f"Error: {e}", "Error"


# --- Gradio Interface ---
# Two tabs share the same detector: one runs OCR first, one classifies raw text.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## OCR + AI/Human Content Detection
        Upload an image or paste text. The tool extracts text via OCR (if image) and analyzes it
        using an AI content detector (`openai-community/roberta-large-openai-detector`)
        to estimate the likelihood of it being AI-generated vs. Human-written.
        **Disclaimer:** AI content detection is challenging and not 100% accurate. These likelihoods
        are estimates based on the model's training data and may not be definitive.
        Performance varies with text type and AI generation methods.
        **Label Assumption:** Assumes model outputs LABEL_1 for AI/Fake and LABEL_0 for Human/Real.
        """
    )

    with gr.Tab("Analyze Image"):
        with gr.Row():
            with gr.Column(scale=2):
                img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            with gr.Column(scale=1):
                ocr_dd = gr.Dropdown(
                    list(TROCR_MODELS.keys()), label="1. Select OCR Model", info="Choose based on text type in image."
                )
                run_btn = gr.Button("2. Analyze Image", variant="primary")
                status_img = gr.Label(value="Awaiting image analysis...", label="Status")

        with gr.Row():
            text_out_img = gr.Textbox(label="Extracted Text", lines=6, interactive=False)
            # Two side-by-side score boxes (AI% and Human%).
            with gr.Column(scale=1):
                 ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False)
            with gr.Column(scale=1):
                 human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False)

        # Output order must match analyze_image's 4-tuple return:
        # (extracted_text, ai_%, human_%, status).
        run_btn.click(
            fn=analyze_image,
            inputs=[img_in, ocr_dd],
            outputs=[text_out_img, ai_out_img, human_out_img, status_img], # 4 outputs
            queue=True
        )

    with gr.Tab("Classify Text"):
         with gr.Column():
            text_in_classify = gr.Textbox(label="Paste or type text here", lines=8)
            classify_btn = gr.Button("Classify Text", variant="primary")
            # Two side-by-side score boxes (AI% and Human%).
            with gr.Row():
                 with gr.Column(scale=1):
                    ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False)
                 with gr.Column(scale=1):
                    human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False)

            # Output order must match classify_text's 2-tuple return: (ai_%, human_%).
            classify_btn.click(
                fn=classify_text,
                inputs=[text_in_classify],
                outputs=[ai_out_classify, human_out_classify], # 2 outputs
                queue=True
            )

    gr.HTML(f"<footer style='text-align:center; margin-top: 20px; color: grey;'>Powered by TrOCR & {DETECTOR_MODEL_ID}</footer>")


if __name__ == "__main__":
    print("Starting Gradio demo...")
    # NOTE: server_name="0.0.0.0" binds to all interfaces, exposing the app on
    # the local network even with share=False — confirm this is intended.
    demo.launch(share=False, server_name="0.0.0.0")