# Clause segmentation + clause-type classification demo, served with Gradio.

import numpy as np
import spacy
import tokenizations
import torch
import gradio as gr

from copy import deepcopy
from sty import bg, RgbBg, Style
from transformers import RobertaTokenizer, AutoModelForTokenClassification, RobertaForSequenceClassification

nlp = spacy.load("en_core_web_sm")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Local fine-tuned checkpoints; paths are kept exactly as in the original script.
clause_model = AutoModelForTokenClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\clause_model_512", num_labels=3)
classification_model = RobertaForSequenceClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\classfication_model", num_labels=18)

# Inference only: switch off dropout.
clause_model.eval()
classification_model.eval()

# Each of the 18 clause-type labels maps to an attribute triple,
# roughly (genericity, eventiveness, aspectual class).
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

label2index = {l: i for i, l in enumerate(labels2attrs)}
index2label = {i: l for l, i in label2index.items()}


def auto_split(text):
    # Split long input into snippets of at most ~200 words, breaking only at
    # spaCy sentence boundaries so each snippet fits the 512-token clause model.
    doc = nlp(text)
    current_len = 0
    snippets = []
    current_snippet = ""
    for sent in doc.sents:
        sent_text = sent.text
        words = sent_text.split()
        if current_len + len(words) > 200:
            snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = len(words)
        else:
            current_snippet += " " + sent_text
            current_len += len(words)
    snippets.append(current_snippet)
    # Drop any empty snippet (e.g. when the very first sentence exceeds the
    # budget) and trim padding spaces.
    return [s.strip() for s in snippets if s.strip()]
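
# Sketch of the expected behavior (hypothetical numbers, not a real run):
# a ~500-word document yields roughly three snippets, each capped near 200
# words; a single sentence longer than 200 words would still become its own
# (oversized) snippet.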


def majority_vote(array):
    # Most frequent element of `array`.
    unique, counts = np.unique(np.array(array), return_counts=True)
    return unique[np.argmax(counts)]
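
# e.g. majority_vote([0, 1, 1]) == 1. Ties resolve to the smallest value, since
# np.unique returns sorted values and argmax takes the first maximum.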


def get_pred_clause_labels(text, words):
    # Tag every whitespace word with a clause label by running the token
    # classifier over RoBERTa subwords and projecting predictions back to words.
    model_inputs = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    a2b, b2a = tokenizations.get_alignments(words, roberta_tokens)
    with torch.no_grad():
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()
    pred_labels = []
    for alignment in a2b:
        if len(alignment) == 0:
            # No aligned subword (e.g. the word was truncated away): fall back to 1.
            pred_labels.append(1)
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:
            # Word split across several subwords: take the majority label.
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))
    assert len(pred_labels) == len(words)
    return pred_labels


def seg_clause(text):
    # Decode the per-word tags into clauses. The scheme used below: label 2
    # marks a clause-final word, labels 0 and 1 are clause-internal.
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)
    segmented_clauses = []
    prev_label = 2
    current_clause = None
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:
            current_clause = []
        if current_clause is not None:
            current_clause.append(cur_token)

        if cur_label == 2 and prev_label in [0, 1]:
            # Close the clause on its final word.
            segmented_clauses.append(deepcopy(current_clause))
            current_clause = None
        prev_label = cur_label

    # Flush a clause left open at the end of the text.
    if current_clause is not None and len(current_clause) != 0:
        segmented_clauses.append(deepcopy(current_clause))
    return [" ".join(clause) for clause in segmented_clauses if clause is not None]


def pretty_print_segmented_clause(segmented_clauses):
    # Console helper (not used by the Gradio app): print each clause on a
    # random background color, never repeating the previous clause's color.
    np.random.seed(42)
    bg.orange = Style(RgbBg(255, 150, 50))
    bg.purple = Style(RgbBg(180, 130, 225))
    colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
    prev_color = 0
    to_print = []
    for cl in segmented_clauses:
        color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
        prev_color = color_choice
        colored_cl = colors[color_choice] + cl + bg.rs
        to_print.append(colored_cl)
    print(*to_print, sep=" ")


def get_pred_classification_labels(clauses, batch_size=32):
    # Classify each clause into one of the 18 types, in batches of `batch_size`.
    clause2labels = []
    for i in range(0, len(clauses), batch_size):
        batch_examples = clauses[i : i + batch_size]
        model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = classification_model(**model_inputs)[0]
        pred_labels = logits.argmax(-1).numpy()
        pred_labels = [index2label[l] for l in pred_labels]
        # Report each clause with its attribute triple so the labels match the
        # keys of attr_colormap below.
        clause2labels.extend([(s, str(labels2attrs[l])) for s, l in zip(batch_examples, pred_labels)])
    return clause2labels


def run_pipeline(text):
    # Full pipeline: split the input into snippets, segment each snippet into
    # clauses, then classify every clause.
    snippets = auto_split(text)
    all_clauses = []
    for s in snippets:
        all_clauses.extend(seg_clause(s))
    clause2labels = get_pred_classification_labels(all_clauses)
    # Number the clauses so the segmentation view colors each one distinctly.
    output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
    return output_clauses, clause2labels
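
# Hypothetical usage (shapes only; actual segmentation and labels depend on the
# trained models):
#   clauses, labels = run_pipeline("I left early. It was raining.")
#   clauses -> [("I left early.", "1"), ("It was raining.", "2")]
#   labels  -> [("I left early.", "('specific', 'dynamic', 'episodic')"), ...]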
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
|
|
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
|
color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Gray"]
|
|
str_attrs = [str(v) for v in set(labels2attrs.values())]
|
|
print(str_attrs, len(str_attrs), len(color_panel_2))
|
|
assert len(str_attrs) == len(color_panel_2)
|
|
attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|


# Written against the Gradio 3.x API, where .style(color_map=...) keys the
# highlight colors by the string labels attached to each span.
demo = gr.Interface(
    fn=run_pipeline,
    inputs=["text"],
    outputs=[
        gr.HighlightedText(
            label="Clause Segmentation",
            show_label=True,
            combine_adjacent=False,
        ).style(color_map=index_colormap),
        gr.HighlightedText(
            label="Attribute Classification",
            show_label=True,
            show_legend=True,
            combine_adjacent=False,
        ).style(color_map=attr_colormap),
    ],
)

demo.launch(share=True)