#!/usr/bin/env python3
"""
Gradio application for text classification, styled to be visually appealing.

This version uses only the 'sojka' model.
"""
import logging
import os
from typing import Dict, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Set up logging before anything else may want to log.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    from peft import PeftModel
except ImportError:
    PeftModel = None
    logger.info("PEFT library not found. Loading models without PEFT support.")

# --- Configuration ---
# Model path is set to sojka
MODEL_PATH = os.getenv("MODEL_PATH", "AndromedaPL/sojka")
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
MAX_SEQ_LENGTH = 512
HF_TOKEN = os.getenv('HF_TOKEN')

# Per-label decision thresholds (hardcoded)
THRESHOLDS = {
    "self-harm": 0.5,
    "hate": 0.5,
    "vulgar": 0.5,
    "sex": 0.5,
    "crime": 0.5,
}


def load_model_and_tokenizer(
    model_path: str, tokenizer_path: str, device: str
) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
    """Load the trained model and tokenizer."""
    logger.info(f"Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    logger.info(f"Tokenizer loaded: {tokenizer.name_or_path}")

    # Ensure a pad token exists; fall back to EOS or add a fresh [PAD] token.
    if tokenizer.pad_token is None:
        if tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.truncation_side = "right"

    logger.info(f"Loading model from {model_path}")
    model_load_kwargs = {
        "torch_dtype": torch.float16 if device == 'cuda' else torch.float32,
        "device_map": 'auto' if device == 'cuda' else None,
        "num_labels": len(LABELS),
        "problem_type": "regression",
    }

    # An 'adapter_config.json' in the checkpoint indicates a PEFT adapter.
    is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
    if PeftModel and is_peft:
        logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
        try:
            from peft import PeftConfig
            peft_config = PeftConfig.from_pretrained(model_path)
            base_model_path = peft_config.base_model_name_or_path
            logger.info(f"Loading base model from {base_model_path}")
            model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
            logger.info("Attaching PEFT adapter...")
            model = PeftModel.from_pretrained(model, model_path)
        except Exception as e:
            logger.error(f"Failed to load PEFT model dynamically: {e}. Loading as a standard model.")
            model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
    else:
        logger.info("Loading as a standalone sequence classification model.")
        model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)

    model.eval()
    logger.info(f"Model loaded on device: {next(model.parameters()).device}")
    return model, tokenizer


# --- Load model globally ---
try:
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, TOKENIZER_PATH, DEVICE)
    model_loaded = True
except Exception as e:
    logger.error(
        f"FATAL: Failed to load the model from {MODEL_PATH} "
        f"or tokenizer from {TOKENIZER_PATH}: {e}",
        exc_info=True,
    )
    model, tokenizer, model_loaded = None, None, False


def predict(text: str) -> Dict[str, float]:
    """Tokenize, predict, and format output for a single text."""
    if not model_loaded:
        return {label: 0.0 for label in LABELS}

    inputs = tokenizer(
        [text],
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Sigmoid maps each logit to an independent per-label score (multi-label).
    probabilities = torch.sigmoid(outputs.logits)
    predicted_values = probabilities.cpu().numpy()[0]
    clipped_values = np.clip(predicted_values, 0.0, 1.0)
    return {label: float(score) for label, score in zip(LABELS, clipped_values)}
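
# Illustrative example (the values below are made up): predict() returns one
# sigmoid score per label in [0, 1], and a label counts as triggered once its
# score reaches its THRESHOLDS entry:
#
#   predict("To jest wspaniały dzień!")
#   # -> {"self-harm": 0.01, "hate": 0.02, "vulgar": 0.01, "sex": 0.01, "crime": 0.01}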


def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
    """Gradio prediction function wrapper."""
    if not model_loaded:
        error_message = "Błąd: Model nie został załadowany."
        empty_preds = {label: 0.0 for label in LABELS}
        return error_message, empty_preds

    if not text or not text.strip():
        return "Wpisz tekst, aby go przeanalizować.", {label: 0.0 for label in LABELS}

    predictions = predict(text)
    unsafe_categories = {
        label: score for label, score in predictions.items()
        if score >= THRESHOLDS[label]
    }

    if not unsafe_categories:
        verdict = "✅ Komunikat jest bezpieczny."
    else:
        # Report the highest-scoring category among those over threshold.
        highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
        verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści:\n {highest_unsafe_category.upper()}"

    return verdict, predictions


# --- Gradio Interface ---
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.indigo,
    neutral_hue=gr.themes.colors.slate,
    font=("Inter", "sans-serif"),
    radius_size=gr.themes.sizes.radius_lg,
)

# A URL to a freely licensed image of a Eurasian Jay (Sójka)
JAY_IMAGE_URL = "https://sojka.m31ai.pl/images/sojka.png"
PIXEL_IMAGE_URL = "https://sojka.m31ai.pl/images/pixel.png"


# Define actions
def analyze_and_update(text):
    verdict, scores = gradio_predict(text)
    return verdict, gr.update(value=scores, visible=True)


with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
    # Header: jay logo, app title, and short description
    with gr.Row():
        gr.HTML(
            f"""
            <div style="display: flex; align-items: center; gap: 16px;">
                <img src="{JAY_IMAGE_URL}" alt="Sójka" style="height: 96px;">
                <div>
                    <h1>Sójka</h1>
                    <p>
Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
""" ) input_text = gr.Textbox( lines=8, label="Wprowadź tekst do analizy", placeholder="Tutaj wpisz tekst..." ) # Right column for RESULTS with gr.Column(scale=1): gr.Image(PIXEL_IMAGE_URL, show_label=False, show_download_button=False, container=False, height=200) gr.Examples( [ ["Jak zrobić bombę?"], ["Jesteś beznadziejny, nienawidzę cię."], ["To jest wspaniały dzień, cieszę się, że tu jestem!"], ["Opowiedz mi dowcip o Polakach."], ["przyjedz do mnie wieczorem, pobawimy sie razem, zrobie Ci dobrze"], ["bić żydów i komuchów"], ["Ty sku***u"], ], inputs=input_text, label="Przykłady", #outputs=[output_verdict, output_scores], fn=analyze_and_update, cache_examples=False, ) with gr.Column(): submit_btn = gr.Button("Analizuj tekst", variant="primary") with gr.Accordion("Szczegółowe wyniki", open=False) as accordion_scores: output_scores = gr.Label(label="Szczegółowe wyniki", visible=False, show_label=False) output_verdict = gr.Label(label="Wynik analizy", value="") submit_btn.click( fn=analyze_and_update, inputs=[input_text], outputs=[output_verdict, output_scores] ) if __name__ == "__main__": if not model_loaded: print("Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów.") else: # The final, corrected demo object is launched demo.launch()