cjc0013 committed
Commit 5e57bf1 · verified · 1 Parent(s): dff3605

Update app.py

Files changed (1):
  1. app.py +31 -23
app.py CHANGED
@@ -15,7 +15,7 @@ def text_vector(text):
 def centroid(docs):
     C = Counter()
     for d in docs:
-        C.update(text_vector(d["text"]))
+        C.update(text_vector(d.get("text", "")))
     return C
 
 def cosine(a, b):
@@ -60,12 +60,11 @@ def initialize_state(records):
         "Ndocs": Ndocs,
         "avg_len": avg_len,
         "centroids": centroids
-    }, sorted(cluster_map.keys()), f"Loaded {len(records)} records."
-
+    }
 
 def load_jsonl(user_file):
     if user_file is None:
-        return None, None, "⚠ No file uploaded."
+        return None, "⚠ No file uploaded."
 
     records = []
     with open(user_file.name, "r", encoding="utf8") as f:
@@ -74,13 +73,12 @@ def load_jsonl(user_file):
                 records.append(json.loads(line))
             except:
                 pass
-    return initialize_state(records)
-
+    return initialize_state(records), f"Loaded {len(records)} records."
 
 def load_default():
     path = "epstein_semantic.jsonl"
     if not os.path.exists(path):
-        return None, None, "⚠ Upload a dataset to begin."
+        return None, "⚠ No default dataset found."
 
     records = []
     with open(path, "r", encoding="utf8") as f:
@@ -89,7 +87,7 @@ def load_default():
                 records.append(json.loads(line))
             except:
                 pass
-    return initialize_state(records)
+    return initialize_state(records), f"Loaded {len(records)} records."
 
 # =====================================================================
 # BM25
@@ -108,7 +106,6 @@ def bm25_score(query, doc_toks, doc_freq, Ndocs, avg_len):
         idf = math.log((Ndocs - df + 0.5) / (df + 0.5) + 1)
         tf = doc_toks.count(q)
         denom = tf + k * (1 - b + b * (len(doc_toks) / avg_len))
-
         score += idf * (tf * (k + 1)) / denom
 
     return score
@@ -124,7 +121,7 @@ def do_view_cluster(state, cid):
     try:
         cid = int(cid)
     except:
-        return "Enter a valid number."
+        return "Enter a valid cluster number."
 
     cluster_map = state["cluster_map"]
 
@@ -133,7 +130,6 @@ def do_view_cluster(state, cid):
 
     out = [f"=== Cluster {cid} ({len(cluster_map[cid])} docs) ===\n"]
 
-    # show all docs, untruncated
     for d in cluster_map[cid]:
         rid = d.get("id", "unknown")
         out.append(f"\n--- id={rid} ---\n{d.get('text','')}\n")
@@ -152,7 +148,6 @@ def do_search(state, query):
         if score > 0:
             results.append((score, r))
 
-    # FIX: sort by score, not dict
     results.sort(key=lambda x: x[0], reverse=True)
 
     out = [f"=== Results for '{query}' ==="]
@@ -177,11 +172,12 @@ subject re fw message thereof all may any doc email said
     out = ["=== Cluster Topics ==="]
 
     for cid, cent in state["centroids"].items():
-        filtered = {w: c for w, c in cent.items()
-                    if w not in STOPWORDS and len(w) > 2 and c > 1}
+        filtered = {
+            w: c for w, c in cent.items()
+            if w not in STOPWORDS and len(w) > 2 and c > 1
+        }
 
         top = [w for w, _ in Counter(filtered).most_common(12)]
-
         out.append(f"Cluster {cid:<4} | {' '.join(top)}")
 
     return "\n".join(out)
@@ -207,16 +203,22 @@ def do_entity_search(state, name):
     return "\n".join(out)
 
 # =====================================================================
-# UI Layout
+# Startup
 # =====================================================================
 
-startup_state, startup_clusters, startup_msg = load_default()
+startup_state_raw, startup_msg = load_default()
+startup_state = gr.State(startup_state_raw)
 
-with gr.Blocks(title="Epstein Semantic Explorer", css="#output {white-space: pre-wrap;}") as demo:
+# =====================================================================
+# UI
+# =====================================================================
+
+with gr.Blocks(title="Epstein Semantic Explorer") as demo:
 
     gr.Markdown("# **Epstein Semantic Explorer**")
     gr.Markdown(startup_msg)
 
+    # Tabs
     with gr.Tab("View Cluster"):
         cluster_num = gr.Number(label="Cluster #", value=96)
         out_cluster = gr.Textbox(label="Cluster Output", lines=40)
@@ -236,11 +238,17 @@ with gr.Blocks(title="Epstein Semantic Explorer", css="#output {white-space: pre
         out_topics = gr.Textbox(label="Topics", lines=40)
         gr.Button("Show Topics").click(do_show_topics, [startup_state], out_topics)
 
-    # File Upload (override default)
+    # Upload override
    with gr.Tab("Upload Different Dataset"):
-        jsonl_file = gr.File(label="Upload JSONL")
-        load_btn = gr.Button("Load Dataset")
-        load_out = gr.Textbox(label="Status", lines=2)
-        load_btn.click(load_jsonl, [jsonl_file], [startup_state, cluster_num, load_out])
+        file_up = gr.File(label="Upload JSONL")
+        load_btn = gr.Button("Load")
+        load_msg = gr.Textbox(label="Status", lines=2)
+
+        def apply_upload(file):
+            new_state, msg = load_jsonl(file)
+            startup_state.value = new_state
+            return msg
+
+        load_btn.click(apply_upload, [file_up], load_msg)
 
 demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
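
For context (not part of the commit), a minimal sketch of the return contract this diff introduces: initialize_state now returns only the state dict, while load_default and load_jsonl return a (state, message) pair. The check_dataset helper below is hypothetical, purely for illustration.

# Hypothetical usage sketch (not in app.py): exercises the new
# (state, message) return shape of load_default() / load_jsonl().
def check_dataset():
    state, msg = load_default()   # state is None when no dataset is found
    print(msg)                    # "Loaded N records." or a warning string
    if state is not None:
        # keys built by initialize_state(), per the diff above
        print(f"{state['Ndocs']} docs, {len(state['centroids'])} clusters")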