Update app.py
Browse files
app.py
CHANGED
@@ -127,10 +127,34 @@ def get_pred_classification_labels(clauses, batch_size=32):
|
|
127 |
logits = classification_model(**model_inputs)[0]
|
128 |
pred_labels = logits.argmax(-1).numpy()
|
129 |
pred_labels = [index2label[l] for l in pred_labels]
|
130 |
-
clause2labels.extend([(s,
|
131 |
-
print(clause2labels)
|
132 |
return clause2labels
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
|
136 |
def run_pipeline(text):
|
@@ -143,7 +167,10 @@ def run_pipeline(text):
|
|
143 |
|
144 |
clause2labels = get_pred_classification_labels(all_clauses)
|
145 |
output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
|
146 |
-
|
|
|
|
|
|
|
147 |
|
148 |
# with open("pipeline_outputs.jsonl", "w") as fw:
|
149 |
# with open("all_text.txt", "r") as f:
|
@@ -165,7 +192,7 @@ color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon
|
|
165 |
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
166 |
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
|
167 |
str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
|
168 |
-
print(str_attrs, len(str_attrs), len(color_panel_2))
|
169 |
assert len(str_attrs) == len(color_panel_2)
|
170 |
attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
|
171 |
# attr_colormap = {
|
@@ -203,6 +230,8 @@ demo = gr.Interface(
|
|
203 |
show_legend=True,
|
204 |
combine_adjacent=False,
|
205 |
).style(color_map=attr_colormap),
|
|
|
|
|
206 |
]
|
207 |
)
|
208 |
|
|
|
127 |
logits = classification_model(**model_inputs)[0]
|
128 |
pred_labels = logits.argmax(-1).numpy()
|
129 |
pred_labels = [index2label[l] for l in pred_labels]
|
130 |
+
clause2labels.extend([(s, labels2attrs[l],) for s,l in zip(batch_examples, pred_labels)])
|
|
|
131 |
return clause2labels
|
132 |
|
133 |
+
def label_visualization(clause2labels):
|
134 |
+
total_clauses = len(clause2labels)
|
135 |
+
aspect_labels, genericity_labels, boundedness_labels = [], [], []
|
136 |
+
for _, labels in clause2labels:
|
137 |
+
labels = tuple(labels)
|
138 |
+
print(labels)
|
139 |
+
|
140 |
+
genericity_label = labels[0]
|
141 |
+
aspect_label = labels[1]
|
142 |
+
boundedness_label = labels[2]
|
143 |
+
aspect_labels.append(aspect_label)
|
144 |
+
genericity_labels.append(genericity_label)
|
145 |
+
boundedness_labels.append(boundedness_label)
|
146 |
+
aspect_dict = {"Dynamic": aspect_labels.count("dynamic"), "Stative": aspect_labels.count("stative")}
|
147 |
+
genericity_dict = {"Generic": genericity_labels.count("generic"), "Specific": genericity_labels.count("specific")}
|
148 |
+
boundedness_dict = {"Static": boundedness_labels.count("static"), "Episodic": boundedness_labels.count("episodic"), "Habitual": aspect_labels.count("habitual")}
|
149 |
+
print(aspect_dict, genericity_dict, boundedness_dict)
|
150 |
+
fig, axs = plt.subplots(1, 3, figsize=(10, 6,))
|
151 |
+
axs[0].pie([float(v / total_clauses) for v in aspect_dict.values()], colors = sns.color_palette('pastel')[0:3], labels=aspect_dict.keys(), autopct='%.0f%%', normalize=True )
|
152 |
+
axs[0].set_title("Aspect")
|
153 |
+
axs[1].pie([float(v / total_clauses) for v in genericity_dict.values()], colors = sns.color_palette('pastel')[3:6], labels=genericity_dict.keys(), autopct='%.0f%%', normalize=True)
|
154 |
+
axs[1].set_title("Genericity")
|
155 |
+
axs[2].pie([float(v / total_clauses) for v in boundedness_dict.values()], colors = sns.color_palette('pastel')[8:10], labels=boundedness_dict.keys(), autopct='%.0f%%', normalize=True)
|
156 |
+
axs[2].set_title("Boundedness")
|
157 |
+
return fig
|
158 |
|
159 |
|
160 |
def run_pipeline(text):
|
|
|
167 |
|
168 |
clause2labels = get_pred_classification_labels(all_clauses)
|
169 |
output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
|
170 |
+
figure = label_visualization(clause2labels)
|
171 |
+
clause2labels = [(k,str(v),) for k, v in clause2labels]
|
172 |
+
return output_clauses, clause2labels, figure
|
173 |
+
|
174 |
|
175 |
# with open("pipeline_outputs.jsonl", "w") as fw:
|
176 |
# with open("all_text.txt", "r") as f:
|
|
|
192 |
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
193 |
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
|
194 |
str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
|
195 |
+
# print(str_attrs, len(str_attrs), len(color_panel_2))
|
196 |
assert len(str_attrs) == len(color_panel_2)
|
197 |
attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
|
198 |
# attr_colormap = {
|
|
|
230 |
show_legend=True,
|
231 |
combine_adjacent=False,
|
232 |
).style(color_map=attr_colormap),
|
233 |
+
|
234 |
+
gr.Plot(),
|
235 |
]
|
236 |
)
|
237 |
|