Frederick committed on
Commit
202ae33
·
1 Parent(s): 95345e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -127,10 +127,34 @@ def get_pred_classification_labels(clauses, batch_size=32):
127
  logits = classification_model(**model_inputs)[0]
128
  pred_labels = logits.argmax(-1).numpy()
129
  pred_labels = [index2label[l] for l in pred_labels]
130
- clause2labels.extend([(s, str(labels2attrs[l]),) for s,l in zip(batch_examples, pred_labels)])
131
- print(clause2labels)
132
  return clause2labels
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  def run_pipeline(text):
@@ -143,7 +167,10 @@ def run_pipeline(text):
143
 
144
  clause2labels = get_pred_classification_labels(all_clauses)
145
  output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
146
- return output_clauses, clause2labels
 
 
 
147
 
148
  # with open("pipeline_outputs.jsonl", "w") as fw:
149
  # with open("all_text.txt", "r") as f:
@@ -165,7 +192,7 @@ color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon
165
  index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
166
  color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
167
  str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
168
- print(str_attrs, len(str_attrs), len(color_panel_2))
169
  assert len(str_attrs) == len(color_panel_2)
170
  attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
171
  # attr_colormap = {
@@ -203,6 +230,8 @@ demo = gr.Interface(
203
  show_legend=True,
204
  combine_adjacent=False,
205
  ).style(color_map=attr_colormap),
 
 
206
  ]
207
  )
208
 
 
127
  logits = classification_model(**model_inputs)[0]
128
  pred_labels = logits.argmax(-1).numpy()
129
  pred_labels = [index2label[l] for l in pred_labels]
130
+ clause2labels.extend([(s, labels2attrs[l],) for s,l in zip(batch_examples, pred_labels)])
 
131
  return clause2labels
132
 
133
def label_visualization(clause2labels):
    """Draw one pie chart per label dimension for the predicted clauses.

    Args:
        clause2labels: Sequence of ``(clause, labels)`` pairs. ``labels`` must
            be convertible to a tuple whose first three items are, in order,
            the genericity, aspect and boundedness labels. The charts count
            the lowercase strings "generic"/"specific", "dynamic"/"stative"
            and "static"/"episodic"/"habitual"; any other value is ignored.
            NOTE(review): the label vocabulary comes from ``labels2attrs``,
            which is defined outside this block -- confirm the spellings match.

    Returns:
        The matplotlib figure containing the three charts ("Aspect",
        "Genericity", "Boundedness"), laid out side by side.
    """
    genericity_labels, aspect_labels, boundedness_labels = [], [], []
    for _, labels in clause2labels:
        labels = tuple(labels)
        genericity_labels.append(labels[0])
        aspect_labels.append(labels[1])
        boundedness_labels.append(labels[2])

    aspect_dict = {
        "Dynamic": aspect_labels.count("dynamic"),
        "Stative": aspect_labels.count("stative"),
    }
    genericity_dict = {
        "Generic": genericity_labels.count("generic"),
        "Specific": genericity_labels.count("specific"),
    }
    boundedness_dict = {
        "Static": boundedness_labels.count("static"),
        "Episodic": boundedness_labels.count("episodic"),
        # Fixed: this used to count "habitual" in aspect_labels, so the
        # slice never reflected the boundedness predictions.
        "Habitual": boundedness_labels.count("habitual"),
    }

    palette = sns.color_palette('pastel')
    charts = (
        ("Aspect", aspect_dict, palette[0:3]),
        ("Genericity", genericity_dict, palette[3:6]),
        # Three wedges need three colours; the previous [8:10] slice held
        # only two, so the colour cycle repeated and two wedges matched.
        ("Boundedness", boundedness_dict, palette[7:10]),
    )

    fig, axs = plt.subplots(1, 3, figsize=(10, 6))
    for ax, (title, counts, colors) in zip(axs, charts):
        ax.set_title(title)
        if not any(counts.values()):
            # Nothing to plot (e.g. no clauses were found): an all-zero pie
            # cannot be normalised, so show a placeholder instead of failing.
            ax.text(0.5, 0.5, "No data", ha="center", va="center")
            ax.axis("off")
            continue
        # normalize=True rescales the values to sum to 1, so raw counts give
        # the same wedges as count / total without dividing by the (possibly
        # zero) number of clauses.
        ax.pie(
            list(counts.values()),
            colors=colors,
            labels=list(counts.keys()),
            autopct='%.0f%%',
            normalize=True,
        )
    return fig
158
 
159
 
160
  def run_pipeline(text):
 
167
 
168
  clause2labels = get_pred_classification_labels(all_clauses)
169
  output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
170
+ figure = label_visualization(clause2labels)
171
+ clause2labels = [(k,str(v),) for k, v in clause2labels]
172
+ return output_clauses, clause2labels, figure
173
+
174
 
175
  # with open("pipeline_outputs.jsonl", "w") as fw:
176
  # with open("all_text.txt", "r") as f:
 
192
  index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
193
  color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
194
  str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
195
+ # print(str_attrs, len(str_attrs), len(color_panel_2))
196
  assert len(str_attrs) == len(color_panel_2)
197
  attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
198
  # attr_colormap = {
 
230
  show_legend=True,
231
  combine_adjacent=False,
232
  ).style(color_map=attr_colormap),
233
+
234
+ gr.Plot(),
235
  ]
236
  )
237