Update app.py
app.py CHANGED
@@ -15,7 +15,7 @@ def text_vector(text):
 def centroid(docs):
     C = Counter()
     for d in docs:
-        C.update(text_vector(d
+        C.update(text_vector(d.get("text", "")))
     return C

 def cosine(a, b):
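The fixed line builds each cluster centroid as a bag-of-words Counter, which `cosine(a, b)` then compares. A minimal sketch of that pattern, assuming `text_vector` does simple lowercase tokenization and `cosine` is the usual sparse-vector formula (both bodies are assumptions; only `centroid` appears in this diff):

    from collections import Counter
    import math

    def text_vector(text):
        # assumed: lowercase whitespace tokenization into term counts
        return Counter(text.lower().split())

    def centroid(docs):
        C = Counter()
        for d in docs:
            C.update(text_vector(d.get("text", "")))   # tolerate records with no "text" key
        return C

    def cosine(a, b):
        # cosine similarity between two sparse count vectors
        dot = sum(a[t] * b[t] for t in set(a) & set(b))
        na = math.sqrt(sum(v * v for v in a.values()))
        nb = math.sqrt(sum(v * v for v in b.values()))
        return dot / (na * nb) if na and nb else 0.0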
@@ -60,12 +60,11 @@ def initialize_state(records):
         "Ndocs": Ndocs,
         "avg_len": avg_len,
         "centroids": centroids
-    }
-
+    }

 def load_jsonl(user_file):
     if user_file is None:
-        return None,
+        return None, "⚠ No file uploaded."

     records = []
     with open(user_file.name, "r", encoding="utf8") as f:
@@ -74,13 +73,12 @@ def load_jsonl(user_file):
                 records.append(json.loads(line))
             except:
                 pass
-    return initialize_state(records)
-
+    return initialize_state(records), f"Loaded {len(records)} records."

 def load_default():
     path = "epstein_semantic.jsonl"
     if not os.path.exists(path):
-        return None,
+        return None, "⚠ No default dataset found."

     records = []
     with open(path, "r", encoding="utf8") as f:
@@ -89,7 +87,7 @@ def load_default():
                 records.append(json.loads(line))
             except:
                 pass
-    return initialize_state(records)
+    return initialize_state(records), f"Loaded {len(records)} records."

 # =====================================================================
 # BM25
@@ -108,7 +106,6 @@ def bm25_score(query, doc_toks, doc_freq, Ndocs, avg_len):
         idf = math.log((Ndocs - df + 0.5) / (df + 0.5) + 1)
         tf = doc_toks.count(q)
         denom = tf + k * (1 - b + b * (len(doc_toks) / avg_len))
-
         score += idf * (tf * (k + 1)) / denom

     return score
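The lines above are the core of Okapi BM25: a non-negative idf (the `+ 1` inside the log keeps it from going negative for very common terms) times a saturating term-frequency weight normalized by document length. A self-contained sketch of the whole scorer, assuming `k` and `b` are the usual constants defined elsewhere in the file (made parameters here so the snippet runs on its own) and `doc_freq` maps term to document frequency:

    import math

    def bm25_score(query, doc_toks, doc_freq, Ndocs, avg_len, k=1.5, b=0.75):
        score = 0.0
        for q in query:
            df = doc_freq.get(q, 0)
            idf = math.log((Ndocs - df + 0.5) / (df + 0.5) + 1)
            tf = doc_toks.count(q)
            # length normalization: longer documents need more matches for the same score
            denom = tf + k * (1 - b + b * (len(doc_toks) / avg_len))
            score += idf * (tf * (k + 1)) / denom
        return score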
@@ -124,7 +121,7 @@ def do_view_cluster(state, cid):
     try:
         cid = int(cid)
     except:
-        return "Enter a valid number."
+        return "Enter a valid cluster number."

     cluster_map = state["cluster_map"]
@@ -133,7 +130,6 @@ def do_view_cluster(state, cid):

     out = [f"=== Cluster {cid} ({len(cluster_map[cid])} docs) ===\n"]

-    # show all docs, untruncated
     for d in cluster_map[cid]:
         rid = d.get("id", "unknown")
         out.append(f"\n--- id={rid} ---\n{d.get('text','')}\n")
@@ -152,7 +148,6 @@ def do_search(state, query):
         if score > 0:
             results.append((score, r))

-    # FIX: sort by score, not dict
     results.sort(key=lambda x: x[0], reverse=True)

     out = [f"=== Results for '{query}' ==="]
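A note on that sort: `results` holds `(score, record)` tuples, and without an explicit key Python falls back to comparing the record dicts whenever two scores tie, which raises `TypeError`. The explicit key orders by score only, and since `list.sort` is stable, tied results keep their original order:

    results = [(2.0, {"id": "a"}), (2.0, {"id": "b"})]
    # results.sort(reverse=True)                    # TypeError: '<' not supported between dicts
    results.sort(key=lambda x: x[0], reverse=True)  # compares scores only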
@@ -177,11 +172,12 @@ subject re fw message thereof all may any doc email said
     out = ["=== Cluster Topics ==="]

     for cid, cent in state["centroids"].items():
-        filtered = {
-
+        filtered = {
+            w: c for w, c in cent.items()
+            if w not in STOPWORDS and len(w) > 2 and c > 1
+        }

         top = [w for w, _ in Counter(filtered).most_common(12)]
-
         out.append(f"Cluster {cid:<4} | {' '.join(top)}")

     return "\n".join(out)
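The new comprehension keeps only centroid terms that are not stopwords, are longer than two characters, and occur more than once; `most_common(12)` then picks each cluster's label terms. A toy example (the STOPWORDS set here is an assumed subset of the one defined above):

    from collections import Counter

    STOPWORDS = {"subject", "re", "the", "of"}
    cent = Counter({"the": 40, "of": 30, "re": 22, "flight": 9, "island": 7, "jet": 1})

    filtered = {
        w: c for w, c in cent.items()
        if w not in STOPWORDS and len(w) > 2 and c > 1
    }
    top = [w for w, _ in Counter(filtered).most_common(12)]
    print(top)   # ['flight', 'island']  ("jet" fails c > 1, the rest are stopwords)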
@@ -207,16 +203,22 @@ def do_entity_search(state, name):
     return "\n".join(out)

 # =====================================================================
-#
+# Startup
 # =====================================================================

-
+startup_state_raw, startup_msg = load_default()
+startup_state = gr.State(startup_state_raw)

-
+# =====================================================================
+# UI
+# =====================================================================
+
+with gr.Blocks(title="Epstein Semantic Explorer") as demo:

     gr.Markdown("# **Epstein Semantic Explorer**")
     gr.Markdown(startup_msg)

+    # Tabs
     with gr.Tab("View Cluster"):
         cluster_num = gr.Number(label="Cluster #", value=96)
         out_cluster = gr.Textbox(label="Cluster Output", lines=40)
@@ -236,11 +238,17 @@ with gr.Blocks(title="Epstein Semantic Explorer", css="#output {white-space: pre
         out_topics = gr.Textbox(label="Topics", lines=40)
         gr.Button("Show Topics").click(do_show_topics, [startup_state], out_topics)

-    #
+    # Upload override
     with gr.Tab("Upload Different Dataset"):
-
-        load_btn = gr.Button("Load
-
-
+        file_up = gr.File(label="Upload JSONL")
+        load_btn = gr.Button("Load")
+        load_msg = gr.Textbox(label="Status", lines=2)
+
+        def apply_upload(file):
+            new_state, msg = load_jsonl(file)
+            startup_state.value = new_state
+            return msg
+
+        load_btn.click(apply_upload, [file_up], load_msg)

 demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
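One caution on `apply_upload` as committed: assigning to `startup_state.value` changes the component's default value, which is shared across sessions, rather than the per-session state that the other tabs read as an input, so those tabs may keep serving the original dataset. The pattern the Gradio docs use for updating session state is to list the `gr.State` component among the outputs and return the new value; a minimal sketch of that variant (component names reused from this file, `load_jsonl` as defined above):

    import gradio as gr

    with gr.Blocks() as demo:
        startup_state = gr.State(None)               # per-session dataset state
        file_up = gr.File(label="Upload JSONL")
        load_msg = gr.Textbox(label="Status", lines=2)
        load_btn = gr.Button("Load")

        def apply_upload(file, current):
            new_state, msg = load_jsonl(file)
            # keep the previous state if the upload could not be parsed
            return (new_state if new_state is not None else current), msg

        # listing startup_state as an output is what updates the session copy
        load_btn.click(apply_upload, [file_up, startup_state], [startup_state, load_msg])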