import os

import numpy as np
import spacy
import tokenizations
import torch
import gradio as gr
import seaborn as sns
from copy import deepcopy
from matplotlib import pyplot as plt
from transformers import (
    RobertaTokenizer,
    AutoModelForTokenClassification,
    RobertaForSequenceClassification,
)

import json  # used only by the commented-out batch script near the bottom
from tqdm import tqdm  # likewise

# from sty import fg, bg, ef, rs, RgbBg, Style  # only needed by the commented-out pretty-printer

# Make sure the spaCy sentence splitter is available before loading it.
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
clause_model = AutoModelForTokenClassification.from_pretrained("./clause_model_512", num_labels=3)
classification_model = RobertaForSequenceClassification.from_pretrained("./classfication_model", num_labels=18)

# Each situation-entity label maps to a (genericity, lexical aspect, boundedness) triple.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),  # should be (static or habitual)
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),  # habitual counts as unbounded
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),  # "The car is red now" (static)
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),  # "I go to the gym regularly" (habitual)
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}
label2index = {l: i for i, l in enumerate(labels2attrs)}
index2label = {i: l for l, i in label2index.items()}


def auto_split(text):
    """Split a document into snippets of roughly 200 words along sentence boundaries."""
    doc = nlp(text)
    current_len = 0
    snippets = []
    current_snippet = ""
    for sent in doc.sents:
        sent_text = sent.text
        words = sent_text.split()
        if current_len + len(words) > 200:
            snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = len(words)
        else:
            current_snippet += " " + sent_text
            current_len += len(words)
    snippets.append(current_snippet)  # the leftover part
    return snippets
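# Usage sketch for auto_split (illustrative only: the file name and the word
# counts below are made-up assumptions, not real output):
#
#   >>> snippets = auto_split(open("long_document.txt").read())
#   >>> [len(s.split()) for s in snippets]
#   [197, 184, 61]  # each snippet stays near the 200-word budget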
def majority_vote(array):
    """Return the most frequent element of `array`."""
    unique, counts = np.unique(np.array(array), return_counts=True)
    return unique[np.argmax(counts)]


def get_pred_clause_labels(text, words):
    """Predict a per-word clause tag (0/1 appear to mark clause-internal tokens, 2 a boundary)."""
    model_inputs = tokenizer(text, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs["input_ids"][0])
    # Align whitespace-split words with RoBERTa subword tokens.
    a2b, b2a = tokenizations.get_alignments(words, roberta_tokens)
    with torch.no_grad():
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()
    pred_labels = []
    for alignment in a2b:  # word index -> RoBERTa token indices
        if len(alignment) == 0:
            pred_labels.append(1)  # unaligned word: default to "inside clause"
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:
            # A word spanning several subwords takes the majority tag.
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))
    assert len(pred_labels) == len(words)
    return pred_labels


def seg_clause(text):
    """Segment `text` into clauses using the predicted per-word tags."""
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)
    segmented_clauses = []
    prev_label = 2
    current_clause = None
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:
            current_clause = []  # a boundary just closed, so start a fresh clause
        if current_clause is not None:
            current_clause.append(cur_token)
        if cur_label == 2:
            if prev_label in [0, 1]:
                # e.g. tag sequence 0 1 1 1 1 2 0 1 1 -> the 2 closes the open clause
                segmented_clauses.append(deepcopy(current_clause))
            current_clause = None
        prev_label = cur_label
    if current_clause is not None and len(current_clause) != 0:  # append leftover
        segmented_clauses.append(deepcopy(current_clause))
    return [" ".join(clause) for clause in segmented_clauses if clause is not None]


# def pretty_print_segmented_clause(segmented_clauses):
#     np.random.seed(42)
#     bg.orange = Style(RgbBg(255, 150, 50))
#     bg.purple = Style(RgbBg(180, 130, 225))
#     colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
#     prev_color = 0
#     to_print = []
#     for cl in segmented_clauses:
#         color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
#         prev_color = color_choice
#         colored_cl = colors[color_choice] + cl + bg.rs
#         to_print.append(colored_cl)
#     print(*to_print, sep=" ")


def get_pred_classification_labels(clauses, batch_size=32):
    """Classify each clause and map its predicted label to an attribute triple."""
    clause2labels = []
    for i in range(0, len(clauses), batch_size):
        batch_examples = clauses[i : i + batch_size]
        model_inputs = tokenizer(batch_examples, padding="max_length", max_length=128, truncation=True, return_tensors="pt")
        with torch.no_grad():
            logits = classification_model(**model_inputs)[0]
        pred_labels = logits.argmax(-1).numpy()
        pred_labels = [index2label[l] for l in pred_labels]
        clause2labels.extend([(s, labels2attrs[l]) for s, l in zip(batch_examples, pred_labels)])
    return clause2labels
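# Output-shape sketch (hypothetical clauses and labels; the actual triples
# depend on the fine-tuned checkpoint, not on this comment):
#
#   >>> get_pred_classification_labels(["He opened the door .", "Cats are mammals ."])
#   [("He opened the door .", ("specific", "dynamic", "episodic")),
#    ("Cats are mammals .", ("generic", "stative", "static"))]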
def label_visualization(clause2labels):
    """Draw three pie charts summarizing aspect, genericity, and boundedness."""
    total_clauses = len(clause2labels)
    aspect_labels, genericity_labels, boundedness_labels = [], [], []
    for _, labels in clause2labels:
        genericity_label, aspect_label, boundedness_label = tuple(labels)
        aspect_labels.append(aspect_label)
        genericity_labels.append(genericity_label)
        boundedness_labels.append(boundedness_label)
    aspect_dict = {"Dynamic": aspect_labels.count("dynamic"),
                   "Stative": aspect_labels.count("stative"),
                   "NA": aspect_labels.count("NA")}
    genericity_dict = {"Generic": genericity_labels.count("generic"),
                       "Specific": genericity_labels.count("specific"),
                       "NA": genericity_labels.count("NA")}
    boundedness_dict = {"Static": boundedness_labels.count("static"),
                        "Episodic": boundedness_labels.count("episodic"),
                        "Habitual": boundedness_labels.count("habitual"),
                        "NA": boundedness_labels.count("NA")}

    fig, axs = plt.subplots(1, 3, figsize=(10, 6))
    fig.tight_layout(pad=5.0)
    # Drop zero-count categories so empty wedges do not clutter the charts.
    dict_aspect = {k: float(v / total_clauses) for k, v in aspect_dict.items() if v != 0}
    dict_genericity = {k: float(v / total_clauses) for k, v in genericity_dict.items() if v != 0}
    dict_boundedness = {k: float(v / total_clauses) for k, v in boundedness_dict.items() if v != 0}
    axs[0].pie(list(dict_aspect.values()),
               colors=sns.color_palette("pastel")[0:len(dict_aspect)],
               labels=dict_aspect.keys(), autopct="%.0f%%", normalize=True)
    axs[0].set_title("Aspect")
    axs[1].pie(list(dict_genericity.values()),
               colors=sns.color_palette("pastel")[3:3 + len(dict_genericity)],
               labels=dict_genericity.keys(), autopct="%.0f%%", normalize=True)
    axs[1].set_title("Genericity")
    axs[2].pie(list(dict_boundedness.values()),
               colors=sns.color_palette("pastel")[6:6 + len(dict_boundedness)],
               labels=dict_boundedness.keys(), autopct="%.0f%%", normalize=True)
    axs[2].set_title("Boundedness")
    return fig


def run_pipeline(text):
    """Full pipeline: split into snippets -> segment clauses -> classify -> visualize."""
    snippets = auto_split(text)
    all_clauses = []
    for s in snippets:
        all_clauses.extend(seg_clause(s))
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
    figure = label_visualization(clause2labels)
    clause2labels = [(k, str(v)) for k, v in clause2labels]
    return output_clauses, clause2labels, figure


# Offline batch mode, kept for reference:
# with open("pipeline_outputs.jsonl", "w") as fw:
#     with open("all_text.txt", "r") as f:
#         lines = f.readlines()
#         print(f"Totally detected {len(lines)} documents.")
#         for text in tqdm(lines):
#             snippets = auto_split(text)
#             all_clauses = []
#             for s in snippets:
#                 segmented_clauses = seg_clause(s)
#                 all_clauses.extend(segmented_clauses)
#                 # pretty_print_segmented_clause(segmented_clauses)
#             clause2labels = get_pred_classification_labels(all_clauses)
#             json.dump(clause2labels, fw)
#             fw.write("\n")

# Colors for the clause-index highlights; indices cycle through the panel.
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon",
                 "pink", "cyan", "gold", "aqua", "violet"]
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}

# One color per distinct attribute triple for the classification highlights.
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon",
                 "Orange", "Gold", "Aqua", "Tomato", "Violet"]
str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
assert len(str_attrs) == len(color_panel_2)
attr_colormap = {a: c for a, c in zip(str_attrs, color_panel_2)}

demo = gr.Interface(
    fn=run_pipeline,
    inputs=["text"],
    outputs=[
        gr.HighlightedText(
            label="Clause Segmentation",
            show_label=True,
            combine_adjacent=False,
        ).style(color_map=index_colormap),  # .style() is the Gradio 3.x styling API
        gr.HighlightedText(
            label="Attribute Classification",
            show_label=True,
            show_legend=True,
            combine_adjacent=False,
        ).style(color_map=attr_colormap),
        gr.Plot(),
    ],
)

demo.launch()
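# To serve the demo locally (assumes the ./clause_model_512 and
# ./classfication_model checkpoints sit next to this script):
#
#   python app.py   # the script name here is an assumption
#
# Gradio prints a local URL (http://127.0.0.1:7860 by default) once the models load.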