Spaces:
Sleeping
Sleeping
| import re | |
| import gradio as gr | |
| import nltk | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
# Fetch the NLTK Punkt tokenizer data required by nltk.sent_tokenize /
# nltk.word_tokenize(..., language="slovene") in pred_slonspell.
# Both the legacy "punkt" package and "punkt_tab" are requested -- presumably so
# the app works across NLTK versions (newer releases load "punkt_tab"); confirm.
# NOTE: this runs at import time, before the model is loaded.
nltk.download("punkt")
nltk.download("punkt_tab")
def _predict_chunk(words):
    """Run the model once on ``words``; return 0/1 ("ok"/"error") for each surviving mask.

    A ``<mask>`` is placed after every word, so prediction ``i`` belongs to ``words[i]``.
    The encoding is truncated to 512 tokens, so a long word list can yield fewer
    predictions than words -- the result then covers only a prefix of ``words``.
    """
    formatted_text = " <mask> ".join(words) + " <mask>"
    encoded_input = tokenizer(formatted_text, return_tensors="pt", max_length=512, truncation=True)
    mask_positions = encoded_input["input_ids"][0] == tokenizer.mask_token_id  # bool, one per token
    with torch.no_grad():
        logits = model(**{k: v.to(DEVICE) for k, v in encoded_input.items()}).logits
    # Only vocabulary ids 0-3 are used as labels: id 0 is read as "ok" and ids 1-3 are
    # summed into "error" (assumed label layout of the SloNSpell model -- confirm on the model card).
    probas = torch.softmax(logits[0, :, :4].cpu(), dim=-1)
    relevant_probas = probas[mask_positions]  # [num_surviving_masks, 4]
    is_ok_proba = relevant_probas[:, 0]
    is_err_proba = relevant_probas[:, 1:].sum(dim=1)
    # Strict comparison: ties resolve to "ok", same as argmax over (ok, error).
    return (is_err_proba > is_ok_proba).long().tolist()


def _predict_sentence(words):
    """Return one 0/1 prediction per word, chunking when the sentence exceeds the context window."""
    preds = []
    remaining = words
    while remaining:
        # Never trust more predictions than words left (keeps the word/pred alignment safe).
        chunk_preds = _predict_chunk(remaining)[: len(remaining)]
        if not chunk_preds:
            # Not even the first word's mask fit in 512 tokens: leave that word
            # unflagged and move on, so the loop always makes progress.
            chunk_preds = [0]
        preds.extend(chunk_preds)
        # Restart at the first word whose mask was truncated away. Later chunks lose
        # the left context of earlier ones, which is preferable to crashing.
        remaining = remaining[len(chunk_preds):]
    return preds


def pred_slonspell(input_text: str):
    """Flag likely spelling errors in Slovene text, word by word.

    The text is normalised (runs of newlines or of 2+ spaces become a single space),
    split into sentences and words with NLTK's Slovene Punkt model, and each sentence
    is labelled by the SloNSpell masked-LM. Sentences longer than the model's
    512-token window are processed in consecutive chunks instead of raising IndexError.

    Relies on the module-level ``tokenizer``, ``model`` and ``DEVICE`` globals, which
    this file only creates inside its ``if __name__ == "__main__":`` block.

    Args:
        input_text: Raw text entered by the user.

    Returns:
        List of ``(word, label)`` tuples in input order, where ``label`` is ``"error"``
        for words the model flags and ``None`` otherwise (the value fed to
        ``gr.HighlightedText``).
    """
    return_values = []
    input_text = re.sub(r"(\n)+|( ){2,}", " ", input_text)
    for _sent in nltk.sent_tokenize(input_text, language="slovene"):
        input_words = nltk.word_tokenize(_sent, language="slovene")
        word_preds = _predict_sentence(input_words)
        return_values.extend(
            (_word, "error" if _pred else None) for _word, _pred in zip(input_words, word_preds)
        )
    return return_values
# HTML text shown above the demo (passed as ``description=`` to gr.Interface).
# The backslash right after the opening quotes drops the leading newline; the one at
# the end of the model-card line joins it with the next line (no newline inserted).
_description = """\
<h1> SloNSpell demo</h1>
<p>This is a simple demo setup for SloNSpell, a 🇸🇮 Slovene spelling error detection model.
You can find more about the model in the model card <a href='https://huggingface.co/cjvt/SloBERTa-slo-word-spelling-annotator'>\
cjvt/SloBERTa-slo-word-spelling-annotator</a>.</p>
<p>Given an input text: </p>
<p>1. The input is segmented into sentences and tokenized using NLTK to prepare the model input.</p>
<p>2. The model makes predictions on the sentence level. </p>
<b>The model does not work perfectly and can make mistakes, please check the output!</b>
"""
# Gradio UI: a free-text box feeds pred_slonspell, whose (word, label) pairs are
# rendered with the words labelled "error" highlighted in red.
demo = gr.Interface(
    fn=pred_slonspell,
    inputs=gr.Textbox(
        lines=3,
        label="Input text",
        info="The text that you want to run through the SloNSpell spell-checking model.",
        # Sample sentence that appears to contain deliberate misspellings, so the
        # demo shows highlights out of the box.
        value="Model vbesedilu o znači besede, v katerih se najajajo napake.",
    ),
    outputs=gr.HighlightedText(
        color_map={"error": "red"},
        label="Spell-checking prediction",
        show_legend=True,
    ),
    theme=gr.themes.Base(),
    description=_description,
    # Flag button disabled: flagging to a HuggingFace dataset is no longer used.
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`; check the pinned version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # These names are read as globals by pred_slonspell, so they must exist before launch().
    model_name = "cjvt/SloBERTa-slo-word-spelling-annotator"
    # Choose the device first: pred_slonspell moves its inputs to DEVICE, so the model
    # has to live there too, or every request fails on a CUDA host (tensors on two devices).
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForMaskedLM.from_pretrained(model_name).to(DEVICE)
    demo.launch()