FrederickPi
committed on
Commit
·
a721802
1
Parent(s):
9b46ef3
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
from transformers import RobertaTokenizer, AutoModelForTokenClassification, RobertaForSequenceClassification
|
6 |
+
import spacy
|
7 |
+
import tokenizations
|
8 |
+
from numpy import asarray
|
9 |
+
from numpy import savetxt, loadtxt
|
10 |
+
import numpy as np
|
11 |
+
import json
|
12 |
+
from copy import deepcopy
|
13 |
+
from sty import fg, bg, ef, rs, RgbBg, Style
|
14 |
+
import re
|
15 |
+
from tqdm import tqdm
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
# Shared NLP resources, loaded once when the module is imported.

# spaCy pipeline; auto_split() relies on it for sentence boundaries.
nlp = spacy.load("en_core_web_sm")
# Tokenizer shared by both fine-tuned models below.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Checkpoint locations. The defaults are the original hard-coded paths; set the
# environment variables to run the app on any other machine.
# NOTE(review): "classfication_model" is kept as spelled — presumably it is the
# actual directory name on disk; confirm before renaming.
CLAUSE_MODEL_PATH = os.environ.get(
    "CLAUSE_MODEL_PATH",
    "C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\clause_model_512",
)
CLASSIFICATION_MODEL_PATH = os.environ.get(
    "CLASSIFICATION_MODEL_PATH",
    "C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\classfication_model",
)

# Token-level tagger with 3 labels; seg_clause() groups words by its tags.
clause_model = AutoModelForTokenClassification.from_pretrained(CLAUSE_MODEL_PATH, num_labels=3)
# Sequence classifier with 18 labels, decoded through index2label.
classification_model = RobertaForSequenceClassification.from_pretrained(CLASSIFICATION_MODEL_PATH, num_labels=18)
|
22 |
+
|
23 |
+
|
24 |
+
# Maps every classifier label to a 3-tuple of attribute strings.
# The insertion order of this dict defines the class indices below, so do not
# reorder entries without retraining / remapping the classifier.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),  # This should be (static, or habitual)
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),  # habitual counts as unbounded
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),  # The car is red now (static)
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),  # I go to the gym regularly (habitual)
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

# Label name -> class index, following the declaration order above. Plain
# Python ints (rather than NumPy scalars) keep the mapping JSON-serialisable;
# lookups with NumPy integer indices still work because they hash and compare
# equal to the corresponding int.
label2index = {label: index for index, label in enumerate(labels2attrs)}
# Class index -> label name; used to decode the classifier's argmax output.
index2label = {index: label for label, index in label2index.items()}
|
47 |
+
|
48 |
+
def auto_split(text, max_words=200):
    """Split *text* into snippets made of whole sentences.

    Sentences come from the module-level spaCy pipeline ``nlp`` and are never
    cut. A sentence that would push the running whitespace-separated word
    count past *max_words* starts a new snippet, so a single over-long
    sentence still ends up as a snippet of its own.

    Args:
        text: Raw input text.
        max_words: Word budget per snippet. Defaults to 200, the limit that
            used to be hard-coded.

    Returns:
        List of non-blank snippet strings in document order; an empty list
        when *text* yields no sentence content.
    """
    snippets = []
    current_snippet = ""
    current_len = 0
    for sent in nlp(text).sents:
        sentence = sent.text
        n_words = len(sentence.split())
        if current_len + n_words > max_words:
            # Flush only when something was accumulated. Without this check an
            # over-long first sentence produced an empty leading snippet.
            if current_snippet.strip():
                snippets.append(current_snippet)
            current_snippet = sentence
            current_len = n_words
        else:
            # The separator is added even before the first sentence, so the
            # first snippet keeps the leading space it has always had and the
            # text handed to the tokenizer downstream stays unchanged.
            current_snippet += " " + sentence
            current_len += n_words
    # The leftover part; skipped when the input had no sentences at all.
    if current_snippet.strip():
        snippets.append(current_snippet)
    return snippets
|
65 |
+
|
66 |
+
|
67 |
+
def majority_vote(array):
    """Return the most frequent value in *array*.

    Ties are resolved in favour of the smallest value: ``np.unique`` returns
    the distinct values sorted and ``np.argmax`` picks the first maximum.

    Args:
        array: Non-empty sequence of comparable values (e.g. tag ids).

    Returns:
        The winning value as a NumPy scalar.

    Raises:
        ValueError: If *array* is empty.
    """
    values, counts = np.unique(np.asarray(array), return_counts=True)
    if values.size == 0:
        # Same exception type np.argmax would raise, with a clearer message.
        raise ValueError("majority_vote() requires at least one element")
    return values[np.argmax(counts)]
|
70 |
+
|
71 |
+
def get_pred_clause_labels(text, words):
    """Predict one clause tag per entry of *words*.

    *text* is tokenized with the module-level RoBERTa ``tokenizer`` (padded /
    truncated to 512 tokens) and tagged by ``clause_model``. The sub-word tags
    are then projected back onto *words* through ``tokenizations.get_alignments``.

    Args:
        text: The snippet to tag.
        words: The words of *text* that should each receive a tag.

    Returns:
        List with ``len(words)`` tags. A word aligned to several sub-word
        tokens gets their majority tag; a word aligned to no token gets tag 1
        (presumably this covers words cut off by truncation — TODO confirm).
    """
    model_inputs = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    # a2b[i] lists the RoBERTa token indices aligned with words[i].
    a2b, _ = tokenizations.get_alignments(words, roberta_tokens)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()

    pred_labels = []
    for alignment in a2b:  # word index -> RoBERTa token indices
        if len(alignment) == 0:
            pred_labels.append(1)
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))

    # Internal invariant: the alignment has exactly one entry per word.
    assert len(pred_labels) == len(words)
    return pred_labels
|
85 |
+
|
86 |
+
def seg_clause(text):
    """Segment *text* into clauses using the predicted per-word tags.

    The text is split on whitespace and tagged by ``get_pred_clause_labels``.
    Words are accumulated into the current clause; a word tagged 2 closes it,
    and the word after a 2 opens a new clause.

    NOTE(review): a word tagged 2 that directly follows another 2 (or is the
    very first word) forms a one-word clause that is discarded, so that word
    does not appear in the output. This matches the previous behaviour;
    confirm it is intended.

    Args:
        text: One snippet of text.

    Returns:
        List of clause strings, each the clause's words joined by one space.
    """
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)

    segmented_clauses = []
    current_clause = []
    prev_label = 2  # treat the start of the text like a clause boundary
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:
            current_clause = []
        current_clause.append(cur_token)

        if cur_label == 2:
            if prev_label in (0, 1):
                segmented_clauses.append(current_clause)  # e.g. tags 0 1 1 1 1 2 | 0 1 1
            # Rebind (never mutate) so the list just stored stays intact.
            current_clause = []
        prev_label = cur_label

    if current_clause:  # leftover words after the last boundary
        segmented_clauses.append(current_clause)
    return [" ".join(clause) for clause in segmented_clauses]
|
105 |
+
|
106 |
+
def pretty_print_segmented_clause(segmented_clauses):
    """Print the clauses on one line, each on a coloured terminal background.

    Colours are drawn pseudo-randomly from a fixed palette such that two
    consecutive clauses never share a colour. A private generator seeded with
    42 is used, so the colour sequence is the same on every call and the
    global NumPy random state is left untouched.

    Args:
        segmented_clauses: Iterable of clause strings.
    """
    rng = np.random.RandomState(42)
    # Register the two custom backgrounds on sty's ``bg`` register.
    bg.orange = Style(RgbBg(255, 150, 50))
    bg.purple = Style(RgbBg(180, 130, 225))
    colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]

    prev_color = 0  # index 0 (red) is therefore never used for the first clause
    to_print = []
    for cl in segmented_clauses:
        # Choose among every palette index except the previous one.
        candidates = np.delete(np.arange(len(colors)), prev_color)
        color_choice = rng.choice(candidates)
        prev_color = color_choice
        to_print.append(colors[color_choice] + cl + bg.rs)
    print(*to_print, sep=" ")
|
119 |
+
|
120 |
+
|
121 |
+
def get_pred_classification_labels(clauses, batch_size=32):
    """Classify every clause and pair it with its predicted label name.

    Clauses are tokenized with the module-level ``tokenizer`` (padded /
    truncated to 128 tokens) and scored by ``classification_model`` in batches.

    Args:
        clauses: List of clause strings.
        batch_size: Number of clauses per forward pass.

    Returns:
        List of ``(clause, label_name)`` tuples in input order; empty when
        *clauses* is empty.
    """
    clause2labels = []
    # Stop at len(clauses): the former ``len(clauses) + 1`` bound produced a
    # trailing empty batch whenever the clause count was a multiple of
    # batch_size (including zero clauses).
    for start in range(0, len(clauses), batch_size):
        batch_examples = clauses[start : start + batch_size]
        model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
        # Pure inference: no gradients needed.
        with torch.no_grad():
            logits = classification_model(**model_inputs)[0]
        pred_labels = [index2label[i] for i in logits.argmax(-1).tolist()]

        clause2labels.extend((clause, str(label)) for clause, label in zip(batch_examples, pred_labels))
    return clause2labels
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
def run_pipeline(text):
    """Run the full demo pipeline on *text*.

    The text is split into snippets (``auto_split``), each snippet is
    segmented into clauses (``seg_clause``), and all clauses are classified
    (``get_pred_classification_labels``).

    Args:
        text: Raw user input.

    Returns:
        A pair ``(output_clauses, clause2labels)``:
        ``output_clauses`` is a list of ``(clause, "<1-based position>")``
        tuples and ``clause2labels`` a list of ``(clause, label_name)`` tuples.
        Both are empty when no clause was found.
    """
    all_clauses = []
    for snippet in auto_split(text):
        all_clauses.extend(seg_clause(snippet))

    # Empty or blank input yields no clauses; return early instead of handing
    # the classifier an empty batch.
    if not all_clauses:
        return [], []

    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(clause, str(position)) for position, clause in enumerate(all_clauses, start=1)]
    return output_clauses, clause2labels
|
144 |
+
|
145 |
+
# with open("pipeline_outputs.jsonl", "w") as fw:
|
146 |
+
# with open("all_text.txt", "r") as f:
|
147 |
+
# lines = f.readlines()
|
148 |
+
# print(f"Totally detected {len(lines)} documents.")
|
149 |
+
# for text in tqdm(lines):
|
150 |
+
# snippets = auto_split(text)
|
151 |
+
# all_clauses = []
|
152 |
+
# for s in snippets:
|
153 |
+
# segmented_clauses = seg_clause(s)
|
154 |
+
# all_clauses.extend(segmented_clauses)
|
155 |
+
# # pretty_print_segmented_clause(segmented_clauses)
|
156 |
+
|
157 |
+
# clause2labels = get_pred_classification_labels(all_clauses)
|
158 |
+
# json.dump(clause2labels, fw)
|
159 |
+
# fw.write("\n")
|
160 |
+
|
161 |
+
# Palette for the clause-segmentation view; cycled by clause number.
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
# run_pipeline labels clauses "1", "2", ...; pre-map those keys to colours.
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}

# Palette for the attribute view: one colour per distinct attribute tuple.
color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Gray"]
# dict.fromkeys removes duplicates while keeping declaration order, so every
# attribute tuple gets the same colour on every run (a set gave an order that
# could change between runs).
str_attrs = [str(v) for v in dict.fromkeys(labels2attrs.values())]
print(str_attrs, len(str_attrs), len(color_panel_2))
# Sanity check: zip() below would silently drop entries on a size mismatch.
assert len(str_attrs) == len(color_panel_2)
# NOTE(review): the keys are stringified attribute tuples, whereas
# get_pred_classification_labels emits label names such as "##QUESTION";
# confirm which of the two the attribute view is meant to be keyed on.
attr_colormap = dict(zip(str_attrs, color_panel_2))
|
168 |
+
# attr_colormap = {
|
169 |
+
# ("specific", "dynamic", "episodic"):
|
170 |
+
# ("generic", "dynamic", "episodic"):
|
171 |
+
# ("specific", "dynamic", "static"):
|
172 |
+
# ("generic", "dynamic", "static"):
|
173 |
+
# ("specific", "stative", "static"):
|
174 |
+
# ("specific", "dynamic", "static"):
|
175 |
+
# ("generic", "dynamic", "static"):
|
176 |
+
# ("specific", "dynamic", "episodic"):
|
177 |
+
# ("generic", "dynamic", "episodic"):
|
178 |
+
# ("generic", "dynamic", "habitual"):
|
179 |
+
# ("generic", "stative", "static"):
|
180 |
+
# ("generic", "stative", "habitual"):
|
181 |
+
# ("specific", "dynamic", "habitual"):
|
182 |
+
# ("specific", "stative", "habitual"):
|
183 |
+
# ("NA", "NA", "NA"):
|
184 |
+
# }
|
185 |
+
|
186 |
+
|
187 |
+
# First output: every clause highlighted, coloured by its running number.
clause_segmentation_view = gr.HighlightedText(
    label="Clause Segmentation",
    show_label=True,
    combine_adjacent=False,
).style(color_map=index_colormap)

# Second output: every clause highlighted by predicted category, with legend.
attribute_classification_view = gr.HighlightedText(
    label="Attribute Classification",
    show_label=True,
    show_legend=True,
    combine_adjacent=False,
).style(color_map=attr_colormap)

# Web UI: one text box in, the two highlighted views out.
demo = gr.Interface(
    fn=run_pipeline,
    inputs=["text"],
    outputs=[clause_segmentation_view, attribute_classification_view],
)

# Start the server and request a public share link.
demo.launch(share=True)
|