Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Gradio application for text classification, styled to be visually appealing. | |
This version uses a single model, selected via MODEL_PATH (default: 'AndromedaPL/sojka').
""" | |
import gradio as gr | |
import logging | |
import os | |
from typing import Dict, Tuple, Any | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import numpy as np | |
# PEFT is an optional dependency: without it PeftModel is None and models are
# loaded without adapter support.
try:
    from peft import PeftModel
except ImportError:
    PeftModel = None
    # Log through a named logger instead of the module-level ``logging.info``.
    # The module-level helper implicitly calls ``logging.basicConfig()`` when
    # the root logger has no handlers yet, and that would turn the explicit
    # ``logging.basicConfig(level=..., format=...)`` call further down in this
    # file into a no-op (basicConfig does nothing once a handler exists).
    logging.getLogger(__name__).info(
        "PEFT library not found. Loading models without PEFT support."
    )
# --- Configuration ---
# Model and tokenizer locations; both can be overridden through the environment.
MODEL_PATH = os.environ.get("MODEL_PATH", "AndromedaPL/sojka")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")

# Read from the environment; not referenced anywhere else in the visible code.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Use the GPU whenever torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Output categories, in the order the scores are paired with them in predict().
LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]

# Passed to the tokenizer as max_length (inputs are truncated).
MAX_SEQ_LENGTH = 512

# Per-label cut-off: a score at or above it marks the text as unsafe for that
# label. Every label currently uses the same hardcoded value.
THRESHOLDS = {label: 0.5 for label in LABELS}

# Set up logging
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
    """Load the sequence-classification model and its tokenizer.

    Args:
        model_path: Location passed to ``from_pretrained`` for the model. If
            ``adapter_config.json`` exists under this path on the local
            filesystem (and PEFT is installed), the path is treated as a PEFT
            adapter and attached to its base model.
        tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``.
        device: ``"cuda"`` selects float16 weights and ``device_map="auto"``;
            any other value loads float32 weights without a device map.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.

    Raises:
        Exception: Whatever ``from_pretrained`` raises is propagated, except
            for failures on the PEFT path, which fall back to loading
            ``model_path`` as a standalone model.
    """
    logger.info(f"Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    logger.info(f"Tokenizer loaded: {tokenizer.name_or_path}")

    # Make sure a padding token exists: reuse EOS when available, otherwise
    # register a new '[PAD]' token and remember whether the vocabulary grew.
    vocabulary_grew = False
    if tokenizer.pad_token is None:
        if tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            added = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            vocabulary_grew = added > 0
    tokenizer.truncation_side = "right"

    logger.info(f"Loading model from {model_path}")
    # NOTE(review): problem_type is "regression" while predict() applies a
    # sigmoid to the logits -- confirm this matches how the model was trained.
    model_load_kwargs = {
        "torch_dtype": torch.float16 if device == 'cuda' else torch.float32,
        "device_map": 'auto' if device == 'cuda' else None,
        "num_labels": len(LABELS),
        "problem_type": "regression"
    }

    # Adapter detection only inspects the local filesystem.
    is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
    if PeftModel and is_peft:
        logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
        try:
            from peft import PeftConfig
            peft_config = PeftConfig.from_pretrained(model_path)
            base_model_path = peft_config.base_model_name_or_path
            logger.info(f"Loading base model from {base_model_path}")
            model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
            logger.info("Attaching PEFT adapter...")
            model = PeftModel.from_pretrained(model, model_path)
        except Exception as e:
            # Deliberate best-effort fallback; keep the traceback in the log so
            # the reason the adapter could not be attached is not lost.
            logger.exception(f"Failed to load PEFT model dynamically: {e}. Loading as a standard model.")
            model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
    else:
        logger.info("Loading as a standalone sequence classification model.")
        model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)

    if vocabulary_grew:
        # A newly added pad token gets an id past the end of the pretrained
        # embedding matrix; grow the matrix so that id is a valid index.
        model.resize_token_embeddings(len(tokenizer))

    model.eval()
    logger.info(f"Model loaded on device: {next(model.parameters()).device}")
    return model, tokenizer
# --- Load model globally ---
# The model is loaded once at import time. On failure the module still imports:
# model/tokenizer are None and model_loaded is False, which the prediction
# functions and the __main__ guard check before doing any work.
try:
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, TOKENIZER_PATH, DEVICE)
    model_loaded = True
except Exception as e:
    # Every positional argument after the message must have a matching
    # %-placeholder, otherwise logging fails while formatting the record.
    # logger.exception also appends the traceback of the load failure.
    logger.exception(
        "FATAL: Failed to load the model from %s or tokenizer from %s: %s",
        MODEL_PATH,
        TOKENIZER_PATH,
        e,
    )
    model, tokenizer, model_loaded = None, None, False
def predict(text: str) -> Dict[str, Any]:
    """Score a single text against every label in LABELS.

    Args:
        text: The text to classify; it is truncated to MAX_SEQ_LENGTH tokens.

    Returns:
        A mapping from each label to a float score in [0, 1]. When the model
        failed to load, every label maps to 0.0.
    """
    if not model_loaded:
        return dict.fromkeys(LABELS, 0.0)

    # Encode as a batch of one and move the tensors to the model's device.
    encoded = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits

    # Sigmoid squashes each logit independently into a per-label score.
    squashed = torch.sigmoid(logits)
    first_row = squashed.cpu().numpy()[0]
    bounded = np.clip(first_row, 0.0, 1.0)
    return dict(zip(LABELS, map(float, bounded)))
def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
    """Turn the raw scores for ``text`` into a verdict message plus the scores.

    Args:
        text: User-supplied text; empty or whitespace-only input is rejected
            with a prompt message instead of being scored.

    Returns:
        ``(verdict, scores)``. The verdict is an error/prompt message, a "safe"
        message, or a warning naming the highest-scoring label whose score is
        at or above its threshold. The scores map every label to a float
        (all 0.0 when nothing was scored).
    """
    if not model_loaded:
        return "Błąd: Model nie został załadowany.", dict.fromkeys(LABELS, 0.0)
    if not (text and text.strip()):
        return "Wpisz tekst, aby go przeanalizować.", dict.fromkeys(LABELS, 0.0)

    scores = predict(text)

    # Keep only the labels whose score reaches that label's threshold.
    flagged = {}
    for name, value in scores.items():
        if value >= THRESHOLDS[name]:
            flagged[name] = value

    if flagged:
        top_label = max(flagged, key=flagged.get)
        return f"⚠️ Wykryto potencjalnie szkodliwe treści:\n {top_label.upper()}", scores
    return "✅ Komunikat jest bezpieczny.", scores
# --- Gradio Interface ---
# Remote images shown at the top of the two main columns.
# The first is described as a freely licensed image of a Eurasian Jay (Sójka).
JAY_IMAGE_URL = "https://sojka.m31ai.pl/images/sojka.png"
PIXEL_IMAGE_URL = "https://sojka.m31ai.pl/images/pixel.png"

# Default Gradio theme with custom hues, font stack and large corner radius.
theme = gr.themes.Default(
    font=("Inter", "sans-serif"),
    radius_size=gr.themes.sizes.radius_lg,
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.indigo,
    neutral_hue=gr.themes.colors.slate,
)
# Define actions
def analyze_and_update(text):
    """Analyze ``text`` and build the values for the two output components.

    Returns the verdict string together with a Gradio update that fills the
    detailed-scores component and makes it visible.
    """
    verdict, scores = gradio_predict(text)
    scores_update = gr.update(visible=True, value=scores)
    return verdict, scores_update
# Interface layout: header bar, two-column main area, then the action/results
# section wired to analyze_and_update.
# NOTE(review): the nesting of the containers below was reconstructed from the
# section comments and statement order -- confirm it matches the intended layout.
with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
    # Header
    with gr.Row():
        # External links open in a new tab: "_blank" is the reserved target
        # keyword, and rel="noopener noreferrer" keeps the opened page from
        # reaching back into this one through window.opener.
        gr.HTML("""
        <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
            <div style="display: flex; align-items: center; gap: 12px;">
                <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
                    stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
                </svg>
                <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
            </div>
            <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
                <a href="https://sojka.m31ai.pl/projekt.html" target="_blank" rel="noopener noreferrer" style="text-decoration: none; color: inherit;">O projekcie</a>
                <a href="https://sojka.m31ai.pl/kategorie.html" target="_blank" rel="noopener noreferrer" style="text-decoration: none; color: inherit;">Opis kategorii</a>
                <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
                style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
                    Testuj Sójkę
                </button>
            </div>
        </div>
        """)
    gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")

    # Main content area
    with gr.Row():
        # Left column: image, description and the text input
        with gr.Column(scale=2):
            gr.Image(JAY_IMAGE_URL, show_label=False, show_download_button=False, container=False, height=200)
            gr.Markdown(
                """
                <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
                Przetestuj <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
                </h1>
                <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
                Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
                </p>
                """
            )
            input_text = gr.Textbox(
                lines=8,
                label="Wprowadź tekst do analizy",
                placeholder="Tutaj wpisz tekst..."
            )
        # Right column: image and clickable example inputs
        with gr.Column(scale=1):
            gr.Image(PIXEL_IMAGE_URL, show_label=False, show_download_button=False, container=False, height=200)
            # No outputs are wired here (the result components are created
            # further down), so an example presumably only fills the textbox;
            # the analysis itself runs from the submit button.
            gr.Examples(
                [
                    ["Jak zrobić bombę?"],
                    ["Jesteś beznadziejny, nienawidzę cię."],
                    ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
                    ["Opowiedz mi dowcip o Polakach."],
                    ["przyjedz do mnie wieczorem, pobawimy sie razem, zrobie Ci dobrze"],
                    ["bić żydów i komuchów"],
                    ["Ty sku***u"],
                ],
                inputs=input_text,
                label="Przykłady",
                fn=analyze_and_update,
                cache_examples=False,
            )

    # Action button and results
    with gr.Column():
        submit_btn = gr.Button("Analizuj tekst", variant="primary")
        with gr.Accordion("Szczegółowe wyniki", open=False) as accordion_scores:
            # Hidden until the first analysis; analyze_and_update makes it visible.
            output_scores = gr.Label(label="Szczegółowe wyniki", visible=False, show_label=False)
        output_verdict = gr.Label(label="Wynik analizy", value="")

    # analyze_and_update returns (verdict, scores update) in this output order.
    submit_btn.click(
        fn=analyze_and_update,
        inputs=[input_text],
        outputs=[output_verdict, output_scores]
    )
# Script entry point: start the Gradio app only when the model is available;
# otherwise report the failure on stdout and return without launching.
if __name__ == "__main__":
    if model_loaded:
        demo.launch()
    else:
        print("Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów.")