Harshit494 commited on
Commit
3c24702
·
verified ·
1 Parent(s): dd6864b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -386
app.py CHANGED
@@ -1,10 +1,12 @@
1
  """
2
  BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
3
- Light-first theme, auto dark-mode via prefers-color-scheme,
4
- theme-reactive NGL viewer, cached PDB fetches.
 
 
 
5
  """
6
 
7
- # ── Speed tweaks — must be BEFORE any HF/torch imports ─────────────────────
8
  import os
9
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
10
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -19,7 +21,7 @@ import requests
19
  from Bio import PDB
20
  from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
21
 
22
- st.set_page_config(page_title="BALM-PPI ", page_icon="🧬", layout="wide",
23
  initial_sidebar_state="expanded")
24
 
25
  HF_REPO_ID = "Harshit494/BALM-PPI"
@@ -39,7 +41,6 @@ EXAMPLES = [
39
  },
40
  ]
41
 
42
- # ── Session state ─────────────────────────────────────────────────────────
43
  _DEFAULTS: dict = {
44
  "model": None, "device": None,
45
  "result": None, "ig_a": None, "ig_b": None, "ig_chain_map": None,
@@ -54,13 +55,11 @@ for _k, _v in _DEFAULTS.items():
54
  st.session_state[_k] = _v
55
 
56
  # ══════════════════════════════════════════════════════════════════════════════
57
- # CSS — light-first with @media dark override
58
  # ══════════════════════════════════════════════════════════════════════════════
59
  st.markdown("""
60
  <style>
61
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap');
62
-
63
- /* LIGHT tokens */
64
  :root {
65
  --bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2;
66
  --border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07);
@@ -69,7 +68,6 @@ st.markdown("""
69
  --text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8;
70
  --mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif;
71
  }
72
- /* DARK tokens */
73
  @media (prefers-color-scheme:dark){
74
  :root {
75
  --bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135;
@@ -79,213 +77,111 @@ st.markdown("""
79
  --text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155;
80
  }
81
  }
82
-
83
- /* Global */
84
  html,body,[data-testid="stAppViewContainer"],[data-testid="stMain"]{
85
- background:var(--bg0)!important; font-family:var(--sans)!important;
86
  }
87
  [data-testid="stHeader"]{
88
- background:var(--bg0)!important;
89
- border-bottom:1px solid var(--border)!important;
90
  box-shadow:0 1px 8px var(--shadow)!important;
91
  }
92
- [data-testid="stSidebar"]{
93
- background:var(--bg1)!important;
94
- border-right:1px solid var(--border)!important;
95
- }
96
  [data-testid="stSidebar"] *{font-family:var(--sans)!important;}
97
-
98
- /* Typography */
99
  h1,h2,h3,h4,h5,h6{font-family:var(--sans)!important;color:var(--text0)!important;font-weight:700!important;}
100
  p,label,div,span,li{color:var(--text1);}
101
  code,pre{font-family:var(--mono)!important;font-size:.82rem!important;
102
  background:var(--bg2)!important;border-radius:4px!important;color:var(--accent)!important;}
103
-
104
- /* Buttons */
105
- .stButton>button{
106
- font-family:var(--sans)!important;font-weight:600!important;
107
- letter-spacing:.01em!important;border-radius:8px!important;
108
- transition:all .15s ease!important;height:38px!important;
109
- }
110
  .stButton>button[kind="primary"]{
111
  background:linear-gradient(135deg,var(--accent),var(--accent2))!important;
112
- border:none!important;color:#fff!important;
113
- box-shadow:0 2px 12px var(--shadow)!important;
114
- }
115
- .stButton>button[kind="primary"]:hover{
116
- box-shadow:0 4px 20px rgba(37,99,235,.28)!important;
117
- transform:translateY(-1px)!important;
118
- }
119
- .stButton>button[kind="secondary"]{
120
- background:var(--bg1)!important;
121
- border:1px solid var(--border)!important;
122
- color:var(--text1)!important;
123
- }
124
- .stButton>button[kind="secondary"]:hover{
125
- border-color:var(--accent)!important;color:var(--accent)!important;
126
- background:rgba(37,99,235,.05)!important;
127
- }
128
-
129
- /* Text areas */
130
- .stTextArea textarea{
131
- background:var(--bg0)!important;border:1.5px solid var(--border)!important;
132
- border-radius:8px!important;color:var(--text0)!important;
133
- font-family:var(--mono)!important;font-size:.82rem!important;
134
- line-height:1.65!important;resize:vertical!important;
135
- }
136
- .stTextArea textarea:focus{
137
- border-color:var(--accent)!important;
138
- box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;
139
- }
140
- .stTextArea label{font-weight:600!important;font-size:.76rem!important;
141
- text-transform:uppercase!important;letter-spacing:.09em!important;color:var(--text2)!important;}
142
-
143
- /* Inputs */
144
- .stTextInput input,.stNumberInput input{
145
- background:var(--bg0)!important;border:1.5px solid var(--border)!important;
146
- border-radius:8px!important;color:var(--text0)!important;
147
- }
148
- .stTextInput input:focus,.stNumberInput input:focus{
149
- border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;
150
- }
151
- .stTextInput label,.stNumberInput label{
152
- font-size:.76rem!important;font-weight:600!important;
153
- color:var(--text2)!important;text-transform:uppercase!important;letter-spacing:.08em!important;
154
- }
155
-
156
- /* Radio/checkbox */
157
  .stRadio label,.stCheckbox label{color:var(--text1)!important;font-size:.9rem!important;}
158
-
159
- /* Tabs */
160
- .stTabs [data-baseweb="tab-list"]{
161
- background:var(--bg2)!important;border-radius:10px!important;
162
- padding:4px!important;gap:3px!important;border:1px solid var(--border)!important;
163
- }
164
- .stTabs [data-baseweb="tab"]{
165
- background:transparent!important;border-radius:7px!important;
166
- color:var(--text2)!important;font-family:var(--sans)!important;
167
- font-size:.82rem!important;font-weight:500!important;
168
- transition:all .14s ease!important;padding:6px 13px!important;
169
- }
170
- .stTabs [aria-selected="true"]{
171
- background:var(--bg0)!important;color:var(--accent)!important;
172
- font-weight:600!important;box-shadow:0 1px 6px var(--shadow)!important;
173
- }
174
-
175
- /* Metrics */
176
- [data-testid="stMetric"]{
177
- background:var(--bg1)!important;border:1px solid var(--border)!important;
178
- border-radius:10px!important;padding:14px 18px!important;
179
- box-shadow:0 1px 4px var(--shadow)!important;
180
- }
181
- [data-testid="stMetricLabel"]{
182
- color:var(--text2)!important;font-size:.72rem!important;
183
- text-transform:uppercase!important;letter-spacing:.1em!important;
184
- }
185
- [data-testid="stMetricValue"]{
186
- color:var(--accent)!important;font-family:var(--mono)!important;
187
- font-size:1.55rem!important;font-weight:700!important;
188
- }
189
-
190
- /* Misc */
191
  hr{border-color:var(--border)!important;margin:12px 0!important;}
192
  [data-testid="stAlert"]{background:var(--bg1)!important;border-radius:8px!important;border-left-width:3px!important;}
193
  [data-testid="stDataFrame"]{border:1px solid var(--border)!important;border-radius:8px!important;overflow:hidden!important;}
194
- .stProgress>div>div{
195
- background:linear-gradient(90deg,var(--accent),var(--accent2))!important;
196
- border-radius:4px!important;
197
- }
198
 
199
- /* ── Custom classes ─────────────────────────────────────────── */
200
  .app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;}
201
- .app-logo{
202
- width:40px;height:40px;border-radius:10px;
203
  background:linear-gradient(135deg,var(--accent),var(--accent2));
204
  display:flex;align-items:center;justify-content:center;
205
- font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);
206
- }
207
  .app-title{font-family:var(--mono)!important;font-size:1.22rem!important;
208
  font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;}
209
- .app-subtitle{font-size:.77rem!important;color:var(--text2)!important;
210
- margin-top:1px;letter-spacing:.03em;}
211
-
212
- .sec-hdr{
213
- font-family:var(--mono);font-size:.65rem;font-weight:700;
214
- letter-spacing:.2em;text-transform:uppercase;
215
- color:var(--accent);margin:14px 0 8px;
216
- display:flex;align-items:center;gap:8px;
217
- }
218
  .sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);}
219
-
220
- .ex-card{
221
- background:var(--bg1);border:1px solid var(--border);
222
- border-radius:10px;padding:12px 15px;margin-bottom:8px;
223
- position:relative;overflow:hidden;
224
- transition:box-shadow .15s,border-color .15s;
225
- }
226
- .ex-card::before{
227
- content:'';position:absolute;left:0;top:0;bottom:0;width:3px;
228
- background:linear-gradient(180deg,var(--accent),var(--accent2));
229
- border-radius:3px 0 0 3px;
230
- }
231
  .ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);}
232
  .ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);}
233
  .ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;}
234
  .ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;}
235
-
236
- .pkd-card{
237
- background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04));
238
- border:1px solid var(--border);border-radius:12px;padding:14px 18px;
239
- box-shadow:0 1px 6px var(--shadow);
240
- }
241
- .pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2);
242
- text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;}
243
- .pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;
244
- color:var(--accent);line-height:1.1;}
245
- .pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px;
246
- font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;}
247
  .badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);}
248
  .badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);}
249
  .badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);}
250
-
251
  .str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;}
252
- .str-fill{height:100%;border-radius:3px;
253
- background:linear-gradient(90deg,var(--accent),var(--accent2));
254
- transition:width .7s cubic-bezier(.4,0,.2,1);}
255
- .str-labels{display:flex;justify-content:space-between;
256
- font-size:.65rem;color:var(--text2);font-family:var(--mono);}
257
-
258
- .ready-badge{
259
- display:inline-flex;align-items:center;gap:6px;
260
- padding:3px 10px;border-radius:20px;
261
  background:rgba(22,163,74,.07);border:1px solid rgba(22,163,74,.22);
262
- font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;
263
- }
264
- .ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%;
265
- background:var(--green);animation:pulse 2s infinite;}
266
- .idle-badge{
267
- display:inline-flex;align-items:center;gap:6px;
268
- padding:3px 10px;border-radius:20px;
269
  background:rgba(100,116,139,.07);border:1px solid rgba(100,116,139,.18);
270
- font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;
271
- }
272
  @keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}}
273
-
274
- .model-section{background:var(--bg2);border:1px solid var(--border);
275
- border-radius:9px;padding:12px 14px;margin-bottom:10px;}
276
- .model-section-title{font-family:var(--mono);font-size:.63rem;font-weight:700;
277
- letter-spacing:.18em;text-transform:uppercase;color:var(--text2);margin-bottom:8px;}
278
-
279
- .ngl-placeholder{
280
- display:flex;flex-direction:column;align-items:center;justify-content:center;
281
- height:420px;background:var(--bg2);border:1px solid var(--border);
282
- border-radius:12px;color:var(--text2);
283
- font-family:var(--mono);font-size:.8rem;text-align:center;line-height:2;
284
- }
285
  </style>
286
  """, unsafe_allow_html=True)
287
 
288
- # postMessage bridge so NGL iframes get theme-change events
289
  st.markdown("""
290
  <script>
291
  (function(){
@@ -294,16 +190,14 @@ st.markdown("""
294
  try{f.contentWindow.postMessage({balmTheme:dark?'dark':'light'},'*');}catch(e){}
295
  });
296
  }
297
- window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){
298
- notify(e.matches);
299
- });
300
  })();
301
  </script>
302
  """, unsafe_allow_html=True)
303
 
304
 
305
  # ══════════════════════════════════════════════════════════════════════════════
306
- # PDB UTILITIES — PDB responses are cached on disk for 1 h
307
  # ══════════════════════════════════════════════════════════════════════════════
308
 
309
  @st.cache_data(show_spinner=False, ttl=3600)
@@ -314,6 +208,23 @@ def _fetch_pdb_cached(pdb_id: str) -> str:
314
  return r.text
315
 
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  def get_chains_from_pdb(pdb_id: str):
318
  try:
319
  pdb_text = _fetch_pdb_cached(pdb_id.strip().upper())
@@ -367,17 +278,18 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
367
  if not pdb_content or not chains: return False
368
  available = sorted(chains.keys())
369
  st.info(f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**")
370
- st.session_state.pdb_content = pdb_content
371
- st.session_state.ig_a = None
372
- st.session_state.ig_b = None
373
- st.session_state.ig_chain_map = None
374
- st.session_state.result = None
375
  st.session_state.chain_info_a = []
376
  st.session_state.chain_info_b = []
377
  st.session_state["_pm"] = mode
 
 
 
378
  if mode == "Protein-Protein":
379
  id_a = chain_a or available[0]
380
- id_b = chain_b or (available[1] if len(available)>1 else available[0])
381
  seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
382
  seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
383
  for m in miss_a: st.warning(f"Chain **{m}** not found (Side A). Available: {', '.join(available)}")
@@ -386,10 +298,11 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
386
  st.session_state["_pb"] = seq_b
387
  st.session_state.chain_info_a = info_a
388
  st.session_state.chain_info_b = info_b
 
389
  else:
390
- id_h = chain_h or "H"
391
- id_l = chain_l or "L"
392
- id_ag_str = chain_ag or next((c for c in available if c not in (id_h,id_l)), available[0])
393
  def _single(cid):
394
  if cid in chains and chains[cid]["seq"]: return cid, chains[cid]
395
  f = next((k for k in chains if k.upper()==cid.upper() and chains[k]["seq"]), None)
@@ -405,6 +318,17 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
405
  st.session_state["_pag"] = seq_ag
406
  st.session_state.chain_info_a = info_ag if info_ag else []
407
  st.session_state.chain_info_b = [{"id":real_h,**ch},{"id":real_l,**cl}]
 
 
 
 
 
 
 
 
 
 
 
408
  return True
409
 
410
 
@@ -444,12 +368,12 @@ def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
444
  seq_a = "".join(c["seq"] for c in chain_info_a) if chain_info_a else ""
445
  seq_b = "".join(c["seq"] for c in chain_info_b) if chain_info_b else ""
446
  if ig_a and chain_info_a:
447
- parts = split_ig_by_chains(ig_a,[c["seq"] for c in chain_info_a])
448
  ig_a_out = [s for sub in parts for s in sub]
449
  else:
450
  ig_a_out = list(ig_a) if ig_a else []
451
  if ig_b and chain_info_b:
452
- parts = split_ig_by_chains(ig_b,[c["seq"] for c in chain_info_b])
453
  ig_b_out = [s for sub in parts for s in sub]
454
  else:
455
  ig_b_out = list(ig_b) if ig_b else []
@@ -458,22 +382,29 @@ def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
458
 
459
  # ══════════════════════════════════════════════════════════════════════════════
460
  # MODEL
 
 
 
461
  # ══════════════════════════════════════════════════════════════════════════════
462
  @st.cache_resource(show_spinner="Loading ESM-2 backbone…")
463
  def _load_esm_base():
464
  from transformers import EsmModel, EsmTokenizer
465
  base = EsmModel.from_pretrained(
466
  "facebook/esm2_t33_650M_UR50D",
467
- torch_dtype="auto", low_cpu_mem_usage=True, use_safetensors=True,
 
 
468
  )
469
  tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
470
  return base, tok
471
 
 
472
  def _download_weights_hf() -> bytes:
473
  from huggingface_hub import hf_hub_download
474
  path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, resume_download=True)
475
  with open(path,"rb") as f: return f.read()
476
 
 
477
  def build_and_load_model(weights_bytes, pkd_bounds):
478
  import torch.nn as nn
479
  import torch.nn.functional as F
@@ -528,6 +459,7 @@ def build_and_load_model(weights_bytes, pkd_bounds):
528
  model.to(device).eval()
529
  return model, device
530
 
 
531
  def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
532
  if st.session_state.model is not None: return True
533
  try:
@@ -544,46 +476,77 @@ def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
544
 
545
  # ══════════════════════════════════════════════════════════════════════════════
546
  # INTEGRATED GRADIENTS
 
 
 
 
 
547
  # ══════════════════════════════════════════════════════════════════════════════
548
- def compute_ig(model, seq_a, seq_b, steps=30):
549
  device = next(model.parameters()).device
550
  tok = model.esm_tokenizer
551
  cls_tok = tok.cls_token
552
  esm = model.esm_model
 
553
  def tokenise(seq):
554
- proc = seq.replace("|",f"{cls_tok}{cls_tok}")
555
- return tok(proc,return_tensors="pt",padding=False,truncation=True,max_length=1024).to(device)
 
 
556
  word_embed = esm.base_model.model.embeddings.word_embeddings
557
- enc_a = tokenise(seq_a); mask_a = enc_a["attention_mask"]
558
- enc_b = tokenise(seq_b); mask_b = enc_b["attention_mask"]
559
- emb_a = word_embed(enc_a["input_ids"]).detach()
560
- emb_b = word_embed(enc_b["input_ids"]).detach()
561
- def encode(embs,mask):
562
- ext = esm.base_model.model.get_extended_attention_mask(mask,embs.shape[:2])
563
- h = esm.base_model.model.encoder(embs,attention_mask=ext).last_hidden_state
564
- m = mask.unsqueeze(-1).expand(h.size()).float()
565
- return (torch.sum(h*m,1)/torch.clamp(m.sum(1),min=1e-9)).float()
566
- def fwd(e_a,e_b):
567
- return model.projection_head(encode(e_a,mask_a),encode(e_b,mask_b))
568
- def riemann(tgt,baseline,fixed,tgt_is_b):
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  grads = []
570
- for a in torch.linspace(0,1,steps,device=device):
571
- x = (baseline+a*(tgt-baseline)).detach().requires_grad_(True)
572
- out = fwd(fixed,x) if tgt_is_b else fwd(x,fixed)
573
  out.sum().backward()
574
- grads.append(x.grad.detach().clone())
575
- avg = torch.stack(grads).mean(0)
576
- return (avg*(tgt-baseline)).abs().sum(-1).squeeze(0).cpu().numpy()
577
- attr_a = riemann(emb_a,torch.zeros_like(emb_a),emb_b,False)
578
- attr_b = riemann(emb_b,torch.zeros_like(emb_b),emb_a,True)
 
 
 
 
 
 
579
  def norm(a):
580
- lo,hi = a.min(),a.max()
581
- return ((a-lo)/(hi-lo+1e-9)).tolist()
 
 
 
 
582
  return norm(attr_a[1:-1]), norm(attr_b[1:-1])
583
 
584
 
585
  # ══════════════════════════════════════════════════════════════════════════════
586
- # NGL VIEWER — theme-reactive light dark
587
  # ══════════════════════════════════════════════════════════════════════════════
588
  def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
589
  ig_json = json.dumps(ig_chain_map) if ig_chain_map else "null"
@@ -595,8 +558,7 @@ def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
595
  *{{box-sizing:border-box;margin:0;padding:0}}
596
  body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
597
  #vp{{width:100%;height:{height}px}}
598
- #hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;
599
- letter-spacing:.04em;transition:color .3s}}
600
  #ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
601
  .cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
602
  border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
@@ -613,57 +575,39 @@ body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:backgrou
613
  </div>
614
  <div id="leg"></div>
615
  <script>
616
- var T = {{
617
- light:{{
618
- bg:'#f0f4fa',
619
- hint:'rgba(51,65,85,.5)',
620
- btn:'rgba(240,244,250,.93)','btn_b':'rgba(37,99,235,.16)',
621
- btn_c:'rgba(71,85,105,.6)',
622
  btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
623
  leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
624
- ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',
625
- null_col:0xe8ecf4,
626
- }},
627
- dark:{{
628
- bg:'#06090f',
629
- hint:'rgba(168,184,216,.4)',
630
- btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',
631
- btn_c:'rgba(168,184,216,.55)',
632
  btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
633
  leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
634
- ig_hi:'#60a5fa',ig_lo:'#172135',
635
- null_col:0x172135,
636
- }}
637
  }};
638
-
639
  function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
640
- var theme = isDark() ? T.dark : T.light;
641
- var stage = null, comp = null, curRep = 'cartoon';
642
- var igData = {ig_json};
643
 
644
- function hex(s,dark){{
645
- // light: #c7d7f0(199,215,240) → #1d4ed8(29,78,216)
646
- // dark: #172135(23,33,53) → #60a5fa(96,165,250)
647
  s=Math.max(0,Math.min(1,s||0));
648
- if(dark) return (Math.round(23+s*73)<<16)|(Math.round(33+s*132)<<8)|Math.round(53+s*197);
649
- return (Math.round(199-s*170)<<16)|(Math.round(215-s*137)<<8)|Math.round(240-s*24);
650
  }}
651
-
652
  function styleUI(dark){{
653
- var Th = dark ? T.dark : T.light;
654
- theme = Th;
655
- document.body.style.background = Th.bg;
656
- document.getElementById('hint').style.color = Th.hint;
657
  document.querySelectorAll('.cb').forEach(function(b){{
658
  b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
659
  b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
660
- b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn; }};
661
  }});
662
  var leg=document.getElementById('leg');
663
  leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
664
- if(stage) stage.setParameters({{backgroundColor:Th.bg}});
665
  }}
666
-
667
  function addRepr(c){{
668
  comp=c;comp.removeAllRepresentations();
669
  var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
@@ -673,43 +617,31 @@ function addRepr(c){{
673
  var cd=igData[atom.chainname];
674
  if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
675
  var i=atom.resno-cd.start;
676
- return(i<0||i>=cd.scores.length)?Th.null_col:hex(cd.scores[i],dark);
677
  }};
678
  }});
679
  comp.addRepresentation(curRep,{{color:sid}});
680
  leg.innerHTML='<span style="color:'+Th.ig_hi+'">■</span> High IG &nbsp;&nbsp;<span style="color:'+Th.ig_lo+'">■</span> Low IG';
681
- }} else {{
682
  comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
683
  leg.innerHTML='Coloured by chain';
684
  }}
685
  stage.autoView();
686
  }}
687
-
688
  function setRep(r){{curRep=r;if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
689
-
690
- function themeSwitch(dark){{
691
- styleUI(dark);
692
- if(comp){{comp.removeAllRepresentations();addRepr(comp);}}
693
- }}
694
-
695
- window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{
696
- themeSwitch(e.matches);
697
- }});
698
- window.addEventListener('message',function(e){{
699
- if(e.data&&e.data.balmTheme) themeSwitch(e.data.balmTheme==='dark');
700
- }});
701
  window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
702
-
703
- stage = new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
704
  styleUI(isDark());
705
-
706
  var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
707
  stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
708
  </script></body></html>"""
709
 
710
 
711
  # ══════════════════════════════════════════════════════════════════════════════
712
- # PLOTLY — light theme (matches default)
713
  # ══════════════════════════════════════════════════════════════════════════════
714
  _PL = dict(paper_bgcolor="rgba(255,255,255,.97)",plot_bgcolor="rgba(241,244,249,1)")
715
  _FONT = dict(family="JetBrains Mono, monospace",color="#475569",size=9)
@@ -725,8 +657,7 @@ def make_heatmap(seq, attr, title):
725
  fig = go.Figure(go.Heatmap(z=rows,text=hover,hoverinfo="text",
726
  colorscale=[[0,"#e8f0fe"],[.4,"#93c5fd"],[.75,"#3b82f6"],[1,"#1d4ed8"]],
727
  zmin=0,zmax=1,xgap=1,ygap=1,
728
- colorbar=dict(thickness=10,len=.9,tickfont=dict(**_FONT),
729
- title=dict(text="IG",font=dict(**_FONT)))))
730
  fig.update_layout(
731
  title=dict(text=title,font=dict(family="JetBrains Mono, monospace",size=10,color="#1d4ed8")),
732
  **_PL,height=200,margin=dict(t=30,b=24,l=55,r=45),
@@ -768,9 +699,6 @@ def residue_strip_html(seq, attr) -> str:
768
  '<div style="line-height:2.2;word-break:break-all">'+"".join(cells)+"</div>")
769
 
770
 
771
- # ══════════════════════════════════════════════════════════════════════════════
772
- # BATCH TEMPLATES
773
- # ══════════════════════════════════════════════════════════════════════════════
774
  PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
775
  ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
776
 
@@ -786,14 +714,10 @@ def main():
786
  if _src in st.session_state:
787
  st.session_state[_dst] = st.session_state.pop(_src)
788
 
789
- # Auto-load model from example click
790
  if st.session_state.get("_auto_load_model") and st.session_state.model is None:
791
  st.session_state["_auto_load_model"] = False
792
- _ensure_model_loaded(
793
- st.session_state.get("pkd_lo",1.0),
794
- st.session_state.get("pkd_hi",16.0))
795
 
796
- # Header
797
  st.markdown("""
798
  <div class="app-header">
799
  <div class="app-logo">🧬</div>
@@ -801,15 +725,14 @@ def main():
801
  <div class="app-title">BALM-PPI Pro</div>
802
  <div class="app-subtitle">ESM-2 · LoRA · Integrated Gradients &nbsp;·&nbsp; Protein Binding Affinity Prediction</div>
803
  </div>
804
- </div>
805
- """, unsafe_allow_html=True)
806
  st.divider()
807
 
808
- # ── SIDEBAR ──────────────────────────────────────────────────────────────
809
  with st.sidebar:
810
  st.markdown("### ⚙️ Model")
811
- st.markdown('<div class="model-section">', unsafe_allow_html=True)
812
- st.markdown('<div class="model-section-title">Weights Source</div>', unsafe_allow_html=True)
813
  model_src = st.radio("src",["HuggingFace (auto-cached)","Upload .pth"],
814
  key="model_src",label_visibility="collapsed")
815
  custom_w = None
@@ -826,12 +749,11 @@ def main():
826
 
827
  if st.button("⚡ Load Model",use_container_width=True,type="primary"):
828
  try:
829
- if model_src == "HuggingFace (auto-cached)":
830
  with st.spinner("📥 Fetching weights (cached after first run)…"):
831
  wb = _download_weights_hf()
832
  else:
833
- if custom_w is None:
834
- st.error("Upload a .pth file first."); st.stop()
835
  wb = custom_w.read()
836
  with st.spinner("🔧 Building ESM-2 + LoRA…"):
837
  m,dev = build_and_load_model(wb,(pkd_lo,pkd_hi))
@@ -842,10 +764,7 @@ def main():
842
  st.error(f"❌ {e}"); st.code(traceback.format_exc())
843
 
844
  if st.session_state.model:
845
- st.markdown(
846
- f'<div class="ready-badge"><span class="ready-dot"></span>'
847
- f'READY &nbsp;·&nbsp; {st.session_state.device}</div>',
848
- unsafe_allow_html=True)
849
  else:
850
  st.markdown('<div class="idle-badge">○ &nbsp;NOT LOADED</div>',unsafe_allow_html=True)
851
 
@@ -854,12 +773,10 @@ def main():
854
  res = st.session_state.result
855
  st.markdown("**Latest Result**")
856
  st.markdown(
857
- f'<div class="pkd-card">'
858
- f'<div class="pkd-lbl">Predicted pKd</div>'
859
  f'<div class="pkd-val">{res["pkd"]:.3f}</div>'
860
  f'<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>'
861
- f'<div style="font-family:var(--mono);font-size:1rem;font-weight:700;color:var(--text0)">'
862
- f'{res["cosine"]:.4f}</div></div>',
863
  unsafe_allow_html=True)
864
 
865
  st.divider()
@@ -869,17 +786,15 @@ def main():
869
  '💻 <a href="https://github.com/rgorantla04/BALM-PPI" style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
870
  '</div>',unsafe_allow_html=True)
871
 
872
- # ── TABS ──────────────────────────────────────────────────────────────────
873
- tab_single, tab_batch, tab_info = st.tabs(
874
- ["🎯 Single Prediction","📂 Batch Prediction","📚 Model Info"])
875
 
876
  with tab_single:
877
  mode_opts = ["Protein-Protein","Antibody-Antigen"]
878
- mode = st.radio("im_r",mode_opts,horizontal=True,
879
- label_visibility="collapsed",key="interaction_mode")
880
  st.markdown("---")
881
 
882
- if mode == "Protein-Protein":
883
  sc1,sc2 = st.columns(2)
884
  with sc1:
885
  st.markdown("**Target Protein — Seq A**")
@@ -889,8 +804,7 @@ def main():
889
  st.markdown("**Binder Protein — Seq B** &nbsp; `|` = chain separator")
890
  seq_b_raw = st.text_area("Binder",label_visibility="collapsed",key="ta_ppi_b",
891
  height=140,placeholder="Paste FASTA or raw sequence…")
892
- seq_a = clean_multi(seq_a_raw)
893
- seq_b = clean_multi(seq_b_raw)
894
  if not st.session_state.pdb_content:
895
  if seq_a:
896
  st.session_state.chain_info_a = [
@@ -904,26 +818,20 @@ def main():
904
  sc1,sc2 = st.columns([2,3])
905
  with sc1:
906
  st.markdown("**Antigen / Target — Seq A**")
907
- ag_raw = st.text_area("Antigen",label_visibility="collapsed",
908
- key="ta_ab_ag",height=140,placeholder="Antigen sequence…")
909
  with sc2:
910
  ah,al = st.columns(2)
911
  with ah:
912
  st.markdown("**Heavy Chain (H)**")
913
- h_raw = st.text_area("Heavy",label_visibility="collapsed",
914
- key="ta_ab_h",height=140,placeholder="VH sequence…")
915
  with al:
916
  st.markdown("**Light Chain (L)**")
917
- l_raw = st.text_area("Light",label_visibility="collapsed",
918
- key="ta_ab_l",height=140,placeholder="VL sequence…")
919
- seq_a = clean_multi(ag_raw)
920
- h_seq = clean_multi(h_raw)
921
- l_seq = clean_multi(l_raw)
922
  seq_b = f"{h_seq}|{l_seq}" if (h_seq or l_seq) else ""
923
  if not st.session_state.pdb_content:
924
  if seq_a:
925
- st.session_state.chain_info_a = [
926
- {"id":"Ag","seq":seq_a,"resnos":list(range(1,len(seq_a)+1))}]
927
  if h_seq or l_seq:
928
  st.session_state.chain_info_b = [
929
  {"id":"H","seq":h_seq,"resnos":list(range(1,len(h_seq)+1))},
@@ -936,12 +844,10 @@ def main():
936
  if st.session_state.model is None:
937
  st.caption("⬅ Load model in sidebar first")
938
  with rc2:
939
- run_ig = st.checkbox("Run Integrated Gradients",value=True,
940
- help="~1–2 min CPU · <10 s GPU")
941
  with rc3:
942
  if st.session_state.model:
943
- st.markdown('<div class="ready-badge" style="margin-top:6px;font-size:.7rem">'
944
- '<span class="ready-dot"></span>READY</div>',unsafe_allow_html=True)
945
 
946
  if run_btn:
947
  if not seq_a or not seq_b:
@@ -958,37 +864,37 @@ def main():
958
  except Exception as e:
959
  st.error(f"Prediction failed: {e}"); st.code(traceback.format_exc())
960
  if run_ig and st.session_state.result:
961
- with st.spinner("Computing Integrated Gradients…"):
962
  try:
963
- ig_a,ig_b = compute_ig(st.session_state.model,seq_a,seq_b)
964
  st.session_state.ig_a = ig_a
965
  st.session_state.ig_b = ig_b
966
  if st.session_state.chain_info_a or st.session_state.chain_info_b:
967
  st.session_state.ig_chain_map = build_ig_chain_map(
968
  st.session_state.chain_info_a,st.session_state.chain_info_b,ig_a,ig_b)
969
  except Exception as e:
970
- st.warning(f"IG failed: {e}")
971
  st.rerun()
972
 
973
  if st.session_state.result:
974
  res = st.session_state.result
975
  pkd,cos = res["pkd"],res["cosine"]
976
  pct = max(0.0,min(100.0,((pkd-pkd_lo)/max(pkd_hi-pkd_lo,1e-9))*100))
977
- strength = ("Weak" if pct<30 else "Moderate" if pct<55 else "Strong" if pct<75 else "Very Strong")
978
- badge_cls = ("badge-weak" if pct<30 else "badge-moderate" if pct<55 else "badge-strong")
979
  st.markdown("---")
980
  m1,m2,m3 = st.columns([1,1,2])
981
  m1.metric("Predicted pKd",f"{pkd:.3f}")
982
  m2.metric("Cosine Similarity",f"{cos:.4f}")
983
  with m3:
984
  st.markdown(
985
- f'<div class="pkd-card">'
986
- f'<div class="str-labels"><span>Weak ({pkd_lo:.0f})</span>'
987
  f'<span style="color:var(--text0);font-weight:600">{strength}</span>'
988
  f'<span>Strong ({pkd_hi:.0f})</span></div>'
989
  f'<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>'
990
- f'<span class="pkd-badge {badge_cls}">{strength}</span>'
991
- f'</div>',unsafe_allow_html=True)
992
 
993
  st.markdown("---")
994
  mode_now = st.session_state.get("interaction_mode","Protein-Protein")
@@ -1013,13 +919,11 @@ def main():
1013
  st.markdown(
1014
  f'<div class="ex-card"><div class="ex-pdb">{ex["label"]}</div>'
1015
  f'<div class="ex-sub">{ex["subtitle"]}</div>'
1016
- f'<div class="ex-desc">{ex["desc"]}</div></div>',
1017
- unsafe_allow_html=True)
1018
  if st.button(f"⬇ Load {ex['pdb']}",key=f"ex_{ex['pdb']}",use_container_width=True):
1019
  ok = populate_from_pdb(ex["pdb"],ex["mode"],
1020
  chain_a=ex.get("chain_a"),chain_b=ex.get("chain_b"),
1021
- chain_h=ex.get("chain_h"),chain_l=ex.get("chain_l"),
1022
- chain_ag=ex.get("chain_ag"))
1023
  if ok:
1024
  if st.session_state.model is None:
1025
  st.session_state["_auto_load_model"] = True
@@ -1029,7 +933,7 @@ def main():
1029
  st.markdown('<div class="sec-hdr">Custom PDB Fetch</div>',unsafe_allow_html=True)
1030
  pdb_in = st.text_input("PDB ID",placeholder="6M0J, 1BRS, 2VXQ…",key="pdb_id_input")
1031
 
1032
- if mode == "Protein-Protein":
1033
  st.caption("Comma-separate chains for multi-chain, e.g. `B,C`")
1034
  fc1,fc2 = st.columns(2)
1035
  ca_in = fc1.text_input("Side A chain(s)",placeholder="A or A,B",key="ca_in")
@@ -1055,34 +959,29 @@ def main():
1055
  if ok: st.rerun()
1056
 
1057
  with right_col:
1058
- vt = st.tabs([
1059
- "🧬 Structure",
1060
- f"📊 {lbl_a} Strip",f"🗺 {lbl_a} Heatmap",
1061
- f"📊 {lbl_b} Strip",f"🗺 {lbl_b} Heatmap",
1062
- "📥 Download",
1063
- ])
1064
 
1065
  with vt[0]:
1066
  if st.session_state.pdb_content:
1067
- st.caption("🎨 IG attribution" if st.session_state.ig_chain_map
1068
- else "🔗 Coloured by chain — run prediction with IG for attribution")
1069
  st.components.v1.html(
1070
- ngl_viewer_html(st.session_state.pdb_content,
1071
- st.session_state.ig_chain_map,height=420),
1072
  height=425,scrolling=False)
1073
  else:
1074
  st.markdown(
1075
- '<div class="ngl-placeholder">'
1076
- '<div style="font-size:2.2rem">🧬</div>'
1077
  '<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
1078
- '<div style="font-size:.72rem;margin-top:4px">Structure will appear here</div>'
1079
  '</div>',unsafe_allow_html=True)
1080
 
1081
  with vt[1]:
1082
  if ig_a_d and seq_a_d:
1083
  st.markdown(residue_strip_html(seq_a_d,ig_a_d),unsafe_allow_html=True)
1084
- st.plotly_chart(top10_bar(seq_a_d,ig_a_d,"#2563eb",f"Top 10 · {lbl_a}"),
1085
- use_container_width=True)
1086
  else:
1087
  st.info("Run prediction with **Integrated Gradients** to see attribution.")
1088
 
@@ -1096,8 +995,7 @@ def main():
1096
  with vt[3]:
1097
  if ig_b_d and seq_b_d:
1098
  st.markdown(residue_strip_html(seq_b_d,ig_b_d),unsafe_allow_html=True)
1099
- st.plotly_chart(top10_bar(seq_b_d,ig_b_d,"#7c3aed",f"Top 10 · {lbl_b}"),
1100
- use_container_width=True)
1101
  else:
1102
  st.info("Run prediction with **Integrated Gradients** to see attribution.")
1103
 
@@ -1112,37 +1010,27 @@ def main():
1112
  if st.session_state.result:
1113
  if ig_a_d or ig_b_d:
1114
  dl1,dl2 = st.columns(2)
1115
- buf = sio.StringIO()
1116
- wc = csv.writer(buf)
1117
  wc.writerow(["chain","position","residue","ig_score"])
1118
- for i,(aa,sc) in enumerate(zip(seq_a_d or "",ig_a_d or [])):
1119
- wc.writerow(["Target",i+1,aa,f"{sc:.6f}"])
1120
- for i,(aa,sc) in enumerate(zip(seq_b_d or "",ig_b_d or [])):
1121
- wc.writerow(["proteina",i+1,aa,f"{sc:.6f}"])
1122
- dl1.download_button("📥 IG Scores CSV",buf.getvalue().encode(),
1123
- "ig_scores.csv","text/csv",use_container_width=True)
1124
- res_dl = st.session_state.result
1125
  summary = {
1126
  "pkd":res_dl["pkd"],"cosine":res_dl["cosine"],
1127
- "Target_length":len(seq_a_d or ""),
1128
- "proteina_length":len(seq_b_d or ""),
1129
  "top10_Target":sorted([{"res":f"{(seq_a_d or '')[i]}{i+1}","ig":(ig_a_d or [])[i]}
1130
- for i in range(min(len(seq_a_d or ""),len(ig_a_d or [])))],
1131
- key=lambda x:-x["ig"])[:10],
1132
  "top10_proteina":sorted([{"res":f"{(seq_b_d or '')[i]}{i+1}","ig":(ig_b_d or [])[i]}
1133
- for i in range(min(len(seq_b_d or ""),len(ig_b_d or [])))],
1134
- key=lambda x:-x["ig"])[:10],
1135
  }
1136
- dl2.download_button("📥 Summary JSON",
1137
- json.dumps(summary,indent=2).encode(),
1138
- "balm_result.json","application/json",
1139
- use_container_width=True)
1140
  else:
1141
  res_dl = st.session_state.result
1142
  st.info("Enable IG and re-run for per-residue CSV.")
1143
  st.download_button("📥 Basic Result JSON",
1144
- json.dumps({"pkd":res_dl["pkd"],"cosine":res_dl["cosine"]},indent=2).encode(),
1145
- "balm_result.json","application/json")
1146
  else:
1147
  st.info("Run a prediction to enable download.")
1148
 
@@ -1161,8 +1049,7 @@ def main():
1161
  df = pd.read_csv(bf)
1162
  st.dataframe(df.head(),use_container_width=True)
1163
  if st.button("🏃 Run Batch",type="primary",use_container_width=True):
1164
- if st.session_state.model is None:
1165
- st.error("Load the model first.")
1166
  else:
1167
  model,results = st.session_state.model,[]
1168
  prog,stat,n = st.progress(0.0),st.empty(),len(df)
@@ -1180,7 +1067,7 @@ def main():
1180
  pkd_t,cos_t = model(sa,sb)
1181
  rr={"pKd":round(float(pkd_t.item()),4),"Cosine":round(float(cos_t.item()),6)}
1182
  if run_ig_b:
1183
- ia,ib=compute_ig(model,sa,sb)
1184
  sap,sbp=sa.replace("|",""),sb.replace("|","")
1185
  rr["top5_Target"]="|".join(f"{sap[j]}{j+1}:{ia[j]:.3f}"
1186
  for j in sorted(range(min(len(sap),len(ia))),key=lambda x:-ia[x])[:5])
@@ -1193,14 +1080,13 @@ def main():
1193
  df_res = pd.concat([df,pd.DataFrame(results)],axis=1)
1194
  st.success(f"✅ {n} predictions complete")
1195
  st.dataframe(df_res,use_container_width=True)
1196
- st.download_button("📥 Download Results CSV",df_res.to_csv(index=False).encode(),
1197
- "balm_batch_results.csv","text/csv",use_container_width=True)
1198
 
1199
  with tab_info:
1200
  st.markdown("### 📚 About BALM-PPI")
1201
  st.markdown("""
1202
- **BALM-PPI** predicts protein–protein and antibody–antigen binding affinity
1203
- using ESM-2 fine-tuned with LoRA, with Integrated Gradients explainability.
1204
 
1205
  | Component | Detail |
1206
  |-----------|--------|
@@ -1209,8 +1095,9 @@ using ESM-2 fine-tuned with LoRA, with Integrated Gradients explainability.
1209
  | Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** |
1210
  | Multi-chain | `\\|` separator → `<cls><cls>` double CLS token |
1211
  | Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min |
1212
- | Explainability | Integrated Gradients, 30-step Riemann approximation |
1213
- | Speed | hf_transfer · torch_dtype=auto · low_cpu_mem_usage · HF disk cache · PDB cache (1 h TTL) |
 
1214
 
1215
  **Examples** · **6M0J** — ACE2 ↔ SARS-CoV-2 RBD &nbsp;·&nbsp; **1BRS** — Barnase ↔ Barstar
1216
 
 
1
  """
2
  BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
3
+ Fixes:
4
+ - IG forced to float32 (torch_dtype="auto" → fp16 kills grads on HF Spaces)
5
+ - IG steps reduced 30→15 for speed with negligible accuracy loss
6
+ - NGL viewer filters PDB to only the selected chains
7
+ - PDB fetches cached for 1 h
8
  """
9
 
 
10
  import os
11
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
12
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
21
  from Bio import PDB
22
  from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
23
 
24
+ st.set_page_config(page_title="BALM-PPI Pro", page_icon="🧬", layout="wide",
25
  initial_sidebar_state="expanded")
26
 
27
  HF_REPO_ID = "Harshit494/BALM-PPI"
 
41
  },
42
  ]
43
 
 
44
  _DEFAULTS: dict = {
45
  "model": None, "device": None,
46
  "result": None, "ig_a": None, "ig_b": None, "ig_chain_map": None,
 
55
  st.session_state[_k] = _v
56
 
57
  # ══════════════════════════════════════════════════════════════════════════════
58
+ # CSS — light-first, dark via @media
59
  # ══════════════════════════════════════════════════════════════════════════════
60
  st.markdown("""
61
  <style>
62
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap');
 
 
63
  :root {
64
  --bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2;
65
  --border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07);
 
68
  --text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8;
69
  --mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif;
70
  }
 
71
  @media (prefers-color-scheme:dark){
72
  :root {
73
  --bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135;
 
77
  --text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155;
78
  }
79
  }
 
 
80
  html,body,[data-testid="stAppViewContainer"],[data-testid="stMain"]{
81
+ background:var(--bg0)!important;font-family:var(--sans)!important;
82
  }
83
  [data-testid="stHeader"]{
84
+ background:var(--bg0)!important;border-bottom:1px solid var(--border)!important;
 
85
  box-shadow:0 1px 8px var(--shadow)!important;
86
  }
87
+ [data-testid="stSidebar"]{background:var(--bg1)!important;border-right:1px solid var(--border)!important;}
 
 
 
88
  [data-testid="stSidebar"] *{font-family:var(--sans)!important;}
 
 
89
  h1,h2,h3,h4,h5,h6{font-family:var(--sans)!important;color:var(--text0)!important;font-weight:700!important;}
90
  p,label,div,span,li{color:var(--text1);}
91
  code,pre{font-family:var(--mono)!important;font-size:.82rem!important;
92
  background:var(--bg2)!important;border-radius:4px!important;color:var(--accent)!important;}
93
+ .stButton>button{font-family:var(--sans)!important;font-weight:600!important;
94
+ letter-spacing:.01em!important;border-radius:8px!important;transition:all .15s ease!important;height:38px!important;}
 
 
 
 
 
95
  .stButton>button[kind="primary"]{
96
  background:linear-gradient(135deg,var(--accent),var(--accent2))!important;
97
+ border:none!important;color:#fff!important;box-shadow:0 2px 12px var(--shadow)!important;}
98
+ .stButton>button[kind="primary"]:hover{box-shadow:0 4px 20px rgba(37,99,235,.28)!important;transform:translateY(-1px)!important;}
99
+ .stButton>button[kind="secondary"]{background:var(--bg1)!important;border:1px solid var(--border)!important;color:var(--text1)!important;}
100
+ .stButton>button[kind="secondary"]:hover{border-color:var(--accent)!important;color:var(--accent)!important;background:rgba(37,99,235,.05)!important;}
101
+ .stTextArea textarea{background:var(--bg0)!important;border:1.5px solid var(--border)!important;
102
+ border-radius:8px!important;color:var(--text0)!important;font-family:var(--mono)!important;
103
+ font-size:.82rem!important;line-height:1.65!important;resize:vertical!important;}
104
+ .stTextArea textarea:focus{border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;}
105
+ .stTextArea label{font-weight:600!important;font-size:.76rem!important;text-transform:uppercase!important;
106
+ letter-spacing:.09em!important;color:var(--text2)!important;}
107
+ .stTextInput input,.stNumberInput input{background:var(--bg0)!important;border:1.5px solid var(--border)!important;
108
+ border-radius:8px!important;color:var(--text0)!important;}
109
+ .stTextInput input:focus,.stNumberInput input:focus{border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;}
110
+ .stTextInput label,.stNumberInput label{font-size:.76rem!important;font-weight:600!important;
111
+ color:var(--text2)!important;text-transform:uppercase!important;letter-spacing:.08em!important;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  .stRadio label,.stCheckbox label{color:var(--text1)!important;font-size:.9rem!important;}
113
+ .stTabs [data-baseweb="tab-list"]{background:var(--bg2)!important;border-radius:10px!important;
114
+ padding:4px!important;gap:3px!important;border:1px solid var(--border)!important;}
115
+ .stTabs [data-baseweb="tab"]{background:transparent!important;border-radius:7px!important;
116
+ color:var(--text2)!important;font-family:var(--sans)!important;font-size:.82rem!important;
117
+ font-weight:500!important;transition:all .14s ease!important;padding:6px 13px!important;}
118
+ .stTabs [aria-selected="true"]{background:var(--bg0)!important;color:var(--accent)!important;
119
+ font-weight:600!important;box-shadow:0 1px 6px var(--shadow)!important;}
120
+ [data-testid="stMetric"]{background:var(--bg1)!important;border:1px solid var(--border)!important;
121
+ border-radius:10px!important;padding:14px 18px!important;box-shadow:0 1px 4px var(--shadow)!important;}
122
+ [data-testid="stMetricLabel"]{color:var(--text2)!important;font-size:.72rem!important;
123
+ text-transform:uppercase!important;letter-spacing:.1em!important;}
124
+ [data-testid="stMetricValue"]{color:var(--accent)!important;font-family:var(--mono)!important;
125
+ font-size:1.55rem!important;font-weight:700!important;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  hr{border-color:var(--border)!important;margin:12px 0!important;}
127
  [data-testid="stAlert"]{background:var(--bg1)!important;border-radius:8px!important;border-left-width:3px!important;}
128
  [data-testid="stDataFrame"]{border:1px solid var(--border)!important;border-radius:8px!important;overflow:hidden!important;}
129
+ .stProgress>div>div{background:linear-gradient(90deg,var(--accent),var(--accent2))!important;border-radius:4px!important;}
 
 
 
130
 
 
131
  .app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;}
132
+ .app-logo{width:40px;height:40px;border-radius:10px;
 
133
  background:linear-gradient(135deg,var(--accent),var(--accent2));
134
  display:flex;align-items:center;justify-content:center;
135
+ font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);}
 
136
  .app-title{font-family:var(--mono)!important;font-size:1.22rem!important;
137
  font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;}
138
+ .app-subtitle{font-size:.77rem!important;color:var(--text2)!important;margin-top:1px;letter-spacing:.03em;}
139
+ .sec-hdr{font-family:var(--mono);font-size:.65rem;font-weight:700;letter-spacing:.2em;
140
+ text-transform:uppercase;color:var(--accent);margin:14px 0 8px;
141
+ display:flex;align-items:center;gap:8px;}
 
 
 
 
 
142
  .sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);}
143
+ .ex-card{background:var(--bg1);border:1px solid var(--border);border-radius:10px;
144
+ padding:12px 15px;margin-bottom:8px;position:relative;overflow:hidden;
145
+ transition:box-shadow .15s,border-color .15s;}
146
+ .ex-card::before{content:'';position:absolute;left:0;top:0;bottom:0;width:3px;
147
+ background:linear-gradient(180deg,var(--accent),var(--accent2));border-radius:3px 0 0 3px;}
 
 
 
 
 
 
 
148
  .ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);}
149
  .ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);}
150
  .ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;}
151
  .ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;}
152
+ .pkd-card{background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04));
153
+ border:1px solid var(--border);border-radius:12px;padding:14px 18px;box-shadow:0 1px 6px var(--shadow);}
154
+ .pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2);text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;}
155
+ .pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;color:var(--accent);line-height:1.1;}
156
+ .pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px;font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;}
 
 
 
 
 
 
 
157
  .badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);}
158
  .badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);}
159
  .badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);}
 
160
  .str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;}
161
+ .str-fill{height:100%;border-radius:3px;background:linear-gradient(90deg,var(--accent),var(--accent2));transition:width .7s cubic-bezier(.4,0,.2,1);}
162
+ .str-labels{display:flex;justify-content:space-between;font-size:.65rem;color:var(--text2);font-family:var(--mono);}
163
+ .ready-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;border-radius:20px;
 
 
 
 
 
 
164
  background:rgba(22,163,74,.07);border:1px solid rgba(22,163,74,.22);
165
+ font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;}
166
+ .ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%;background:var(--green);animation:pulse 2s infinite;}
167
+ .idle-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;border-radius:20px;
 
 
 
 
168
  background:rgba(100,116,139,.07);border:1px solid rgba(100,116,139,.18);
169
+ font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;}
 
170
  @keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}}
171
+ .model-section{background:var(--bg2);border:1px solid var(--border);border-radius:9px;padding:12px 14px;margin-bottom:10px;}
172
+ .model-section-title{font-family:var(--mono);font-size:.63rem;font-weight:700;letter-spacing:.18em;text-transform:uppercase;color:var(--text2);margin-bottom:8px;}
173
+ .ngl-placeholder{display:flex;flex-direction:column;align-items:center;justify-content:center;
174
+ height:420px;background:var(--bg2);border:1px solid var(--border);border-radius:12px;
175
+ color:var(--text2);font-family:var(--mono);font-size:.8rem;text-align:center;line-height:2;}
176
+
177
+ /* IG progress info box */
178
+ .ig-info{background:linear-gradient(135deg,rgba(37,99,235,.04),rgba(124,58,237,.03));
179
+ border:1px solid var(--border);border-radius:8px;padding:10px 14px;
180
+ font-family:var(--mono);font-size:.75rem;color:var(--text2);margin-bottom:8px;}
 
 
181
  </style>
182
  """, unsafe_allow_html=True)
183
 
184
+ # postMessage bridge for NGL iframe theme sync
185
  st.markdown("""
186
  <script>
187
  (function(){
 
190
  try{f.contentWindow.postMessage({balmTheme:dark?'dark':'light'},'*');}catch(e){}
191
  });
192
  }
193
+ window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){notify(e.matches);});
 
 
194
  })();
195
  </script>
196
  """, unsafe_allow_html=True)
197
 
198
 
199
  # ══════════════════════════════════════════════════════════════════════════════
200
+ # PDB UTILITIES
201
  # ══════════════════════════════════════════════════════════════════════════════
202
 
203
  @st.cache_data(show_spinner=False, ttl=3600)
 
208
  return r.text
209
 
210
 
211
+ def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str:
212
+ """
213
+ Strip PDB text to only ATOM/HETATM records for chain IDs in keep_chains.
214
+ Preserves HEADER, TITLE, REMARK lines so NGL can still parse it.
215
+ This keeps the viewer uncluttered — only selected chains are shown.
216
+ """
217
+ out = []
218
+ for line in pdb_text.splitlines():
219
+ rec = line[:6].strip()
220
+ if rec in ("ATOM", "HETATM"):
221
+ chain_col = line[21:22] # column 22 (0-indexed: 21) = chain ID
222
+ if chain_col not in keep_chains:
223
+ continue
224
+ out.append(line)
225
+ return "\n".join(out)
226
+
227
+
228
  def get_chains_from_pdb(pdb_id: str):
229
  try:
230
  pdb_text = _fetch_pdb_cached(pdb_id.strip().upper())
 
278
  if not pdb_content or not chains: return False
279
  available = sorted(chains.keys())
280
  st.info(f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**")
281
+
282
+ st.session_state.ig_a = st.session_state.ig_b = st.session_state.ig_chain_map = None
283
+ st.session_state.result = None
 
 
284
  st.session_state.chain_info_a = []
285
  st.session_state.chain_info_b = []
286
  st.session_state["_pm"] = mode
287
+
288
+ selected_chain_ids: set = set() # ← collect which chains to keep in viewer
289
+
290
  if mode == "Protein-Protein":
291
  id_a = chain_a or available[0]
292
+ id_b = chain_b or (available[1] if len(available) > 1 else available[0])
293
  seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
294
  seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
295
  for m in miss_a: st.warning(f"Chain **{m}** not found (Side A). Available: {', '.join(available)}")
 
298
  st.session_state["_pb"] = seq_b
299
  st.session_state.chain_info_a = info_a
300
  st.session_state.chain_info_b = info_b
301
+ selected_chain_ids = {c["id"] for c in info_a} | {c["id"] for c in info_b}
302
  else:
303
+ id_h = chain_h or "H"
304
+ id_l = chain_l or "L"
305
+ id_ag_str = chain_ag or next((c for c in available if c not in (id_h, id_l)), available[0])
306
  def _single(cid):
307
  if cid in chains and chains[cid]["seq"]: return cid, chains[cid]
308
  f = next((k for k in chains if k.upper()==cid.upper() and chains[k]["seq"]), None)
 
318
  st.session_state["_pag"] = seq_ag
319
  st.session_state.chain_info_a = info_ag if info_ag else []
320
  st.session_state.chain_info_b = [{"id":real_h,**ch},{"id":real_l,**cl}]
321
+ selected_chain_ids = ({c["id"] for c in info_ag} if info_ag else set())
322
+ if ch["seq"]: selected_chain_ids.add(real_h)
323
+ if cl["seq"]: selected_chain_ids.add(real_l)
324
+
325
+ # ── Filter PDB to selected chains only ───────────────────────────────────
326
+ if selected_chain_ids:
327
+ filtered = filter_pdb_to_chains(pdb_content, selected_chain_ids)
328
+ st.session_state.pdb_content = filtered
329
+ else:
330
+ st.session_state.pdb_content = pdb_content
331
+
332
  return True
333
 
334
 
 
368
  seq_a = "".join(c["seq"] for c in chain_info_a) if chain_info_a else ""
369
  seq_b = "".join(c["seq"] for c in chain_info_b) if chain_info_b else ""
370
  if ig_a and chain_info_a:
371
+ parts = split_ig_by_chains(ig_a,[c["seq"] for c in chain_info_a])
372
  ig_a_out = [s for sub in parts for s in sub]
373
  else:
374
  ig_a_out = list(ig_a) if ig_a else []
375
  if ig_b and chain_info_b:
376
+ parts = split_ig_by_chains(ig_b,[c["seq"] for c in chain_info_b])
377
  ig_b_out = [s for sub in parts for s in sub]
378
  else:
379
  ig_b_out = list(ig_b) if ig_b else []
 
382
 
383
  # ══════════════════════════════════════════════════════════════════════════════
384
  # MODEL
385
+ # KEY FIX: do NOT use torch_dtype="auto" — on HF Spaces this loads fp16
386
+ # which silently produces zero gradients and breaks IG entirely.
387
+ # Always load in float32; the extra memory is worth correct attributions.
388
  # ══════════════════════════════════════════════════════════════════════════════
389
  @st.cache_resource(show_spinner="Loading ESM-2 backbone…")
390
  def _load_esm_base():
391
  from transformers import EsmModel, EsmTokenizer
392
  base = EsmModel.from_pretrained(
393
  "facebook/esm2_t33_650M_UR50D",
394
+ # torch_dtype="auto" removed: fp16 kills gradients on HF Spaces / CPU
395
+ # low_cpu_mem_usage removed: creates meta tensors → deepcopy+.to(device) crashes
396
+ use_safetensors=True,
397
  )
398
  tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
399
  return base, tok
400
 
401
+
402
  def _download_weights_hf() -> bytes:
403
  from huggingface_hub import hf_hub_download
404
  path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, resume_download=True)
405
  with open(path,"rb") as f: return f.read()
406
 
407
+
408
  def build_and_load_model(weights_bytes, pkd_bounds):
409
  import torch.nn as nn
410
  import torch.nn.functional as F
 
459
  model.to(device).eval()
460
  return model, device
461
 
462
+
463
  def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
464
  if st.session_state.model is not None: return True
465
  try:
 
476
 
477
  # ══════════════════════════════════════════════════════════════════════════════
478
  # INTEGRATED GRADIENTS
479
+ # KEY FIXES:
480
+ # 1. Explicit .float() casts — model may internally have mixed dtypes after LoRA
481
+ # 2. Steps reduced 30 → 15 (Riemann IG converges fast; halves compute time)
482
+ # 3. word_embed detach + clone to avoid stale graph issues
483
+ # 4. Proper gradient zeroing between steps
484
  # ══════════════════════════════════════════════════════════════════════════════
485
+ def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15):
486
  device = next(model.parameters()).device
487
  tok = model.esm_tokenizer
488
  cls_tok = tok.cls_token
489
  esm = model.esm_model
490
+
491
  def tokenise(seq):
492
+ proc = seq.replace("|", f"{cls_tok}{cls_tok}")
493
+ return tok(proc, return_tensors="pt",
494
+ padding=False, truncation=True, max_length=1024).to(device)
495
+
496
  word_embed = esm.base_model.model.embeddings.word_embeddings
497
+
498
+ enc_a = tokenise(seq_a); mask_a = enc_a["attention_mask"]
499
+ enc_b = tokenise(seq_b); mask_b = enc_b["attention_mask"]
500
+
501
+ # ── Force float32 — critical on HF Spaces where fp16 zeros gradients ────
502
+ emb_a = word_embed(enc_a["input_ids"]).detach().float().clone()
503
+ emb_b = word_embed(enc_b["input_ids"]).detach().float().clone()
504
+ mask_a = mask_a.float()
505
+ mask_b = mask_b.float()
506
+
507
+ def encode(embs, mask):
508
+ # Cast mask for attention
509
+ int_mask = mask.long()
510
+ ext = esm.base_model.model.get_extended_attention_mask(int_mask, embs.shape[:2])
511
+ # Run encoder — force float32 throughout
512
+ h = esm.base_model.model.encoder(
513
+ embs.float(), attention_mask=ext
514
+ ).last_hidden_state.float()
515
+ m = mask.unsqueeze(-1).expand(h.size()).float()
516
+ return (torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9))
517
+
518
+ def fwd(e_a, e_b):
519
+ return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b))
520
+
521
+ def riemann(tgt, baseline, fixed, tgt_is_b):
522
  grads = []
523
+ for alpha in torch.linspace(0, 1, steps, device=device):
524
+ interp = (baseline + alpha * (tgt - baseline)).detach().requires_grad_(True)
525
+ out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed)
526
  out.sum().backward()
527
+ grads.append(interp.grad.detach().float().clone())
528
+ avg_grad = torch.stack(grads).mean(0)
529
+ ig = (avg_grad * (tgt - baseline)).abs().sum(-1).squeeze(0)
530
+ return ig.cpu().numpy()
531
+
532
+ baseline_a = torch.zeros_like(emb_a)
533
+ baseline_b = torch.zeros_like(emb_b)
534
+
535
+ attr_a = riemann(emb_a, baseline_a, emb_b, False)
536
+ attr_b = riemann(emb_b, baseline_b, emb_a, True)
537
+
538
  def norm(a):
539
+ lo, hi = a.min(), a.max()
540
+ if hi - lo < 1e-9:
541
+ return [0.0] * len(a) # flat → return zeros, not NaN
542
+ return ((a - lo) / (hi - lo)).tolist()
543
+
544
+ # Slice off [CLS] and [EOS] tokens (indices 0 and -1)
545
  return norm(attr_a[1:-1]), norm(attr_b[1:-1])
546
 
547
 
548
  # ══════════════════════════════════════════════════════════════════════════════
549
+ # NGL VIEWER — theme-reactive, shows only filtered PDB chains
550
  # ══════════════════════════════════════════════════════════════════════════════
551
  def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
552
  ig_json = json.dumps(ig_chain_map) if ig_chain_map else "null"
 
558
  *{{box-sizing:border-box;margin:0;padding:0}}
559
  body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
560
  #vp{{width:100%;height:{height}px}}
561
+ #hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;letter-spacing:.04em;transition:color .3s}}
 
562
  #ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
563
  .cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
564
  border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
 
575
  </div>
576
  <div id="leg"></div>
577
  <script>
578
+ var T={{
579
+ light:{{bg:'#f0f4fa',hint:'rgba(51,65,85,.5)',
580
+ btn:'rgba(240,244,250,.93)',btn_b:'rgba(37,99,235,.16)',btn_c:'rgba(71,85,105,.6)',
 
 
 
581
  btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
582
  leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
583
+ ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',null_col:0xe8ecf4}},
584
+ dark:{{bg:'#06090f',hint:'rgba(168,184,216,.4)',
585
+ btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',btn_c:'rgba(168,184,216,.55)',
 
 
 
 
 
586
  btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
587
  leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
588
+ ig_hi:'#60a5fa',ig_lo:'#172135',null_col:0x172135}}
 
 
589
  }};
 
590
  function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
591
+ var theme=isDark()?T.dark:T.light,stage=null,comp=null,curRep='cartoon',igData={ig_json};
 
 
592
 
593
+ function igColor(s,dark){{
 
 
594
  s=Math.max(0,Math.min(1,s||0));
595
+ if(dark) return(Math.round(23+s*73)<<16)|(Math.round(33+s*132)<<8)|Math.round(53+s*197);
596
+ return(Math.round(199-s*170)<<16)|(Math.round(215-s*137)<<8)|Math.round(240-s*24);
597
  }}
 
598
  function styleUI(dark){{
599
+ var Th=dark?T.dark:T.light; theme=Th;
600
+ document.body.style.background=Th.bg;
601
+ document.getElementById('hint').style.color=Th.hint;
 
602
  document.querySelectorAll('.cb').forEach(function(b){{
603
  b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
604
  b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
605
+ b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn;}};
606
  }});
607
  var leg=document.getElementById('leg');
608
  leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
609
+ if(stage)stage.setParameters({{backgroundColor:Th.bg}});
610
  }}
 
611
  function addRepr(c){{
612
  comp=c;comp.removeAllRepresentations();
613
  var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
 
617
  var cd=igData[atom.chainname];
618
  if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
619
  var i=atom.resno-cd.start;
620
+ return(i<0||i>=cd.scores.length)?Th.null_col:igColor(cd.scores[i],dark);
621
  }};
622
  }});
623
  comp.addRepresentation(curRep,{{color:sid}});
624
  leg.innerHTML='<span style="color:'+Th.ig_hi+'">■</span> High IG &nbsp;&nbsp;<span style="color:'+Th.ig_lo+'">■</span> Low IG';
625
+ }}else{{
626
  comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
627
  leg.innerHTML='Coloured by chain';
628
  }}
629
  stage.autoView();
630
  }}
 
631
  function setRep(r){{curRep=r;if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
632
+ function themeSwitch(dark){{styleUI(dark);if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
633
+ window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{themeSwitch(e.matches);}});
634
+ window.addEventListener('message',function(e){{if(e.data&&e.data.balmTheme)themeSwitch(e.data.balmTheme==='dark');}});
 
 
 
 
 
 
 
 
 
635
  window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
636
+ stage=new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
 
637
  styleUI(isDark());
 
638
  var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
639
  stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
640
  </script></body></html>"""
641
 
642
 
643
  # ══════════════════════════════════════════════════════════════════════════════
644
+ # PLOTLY
645
  # ══════════════════════════════════════════════════════════════════════════════
646
  _PL = dict(paper_bgcolor="rgba(255,255,255,.97)",plot_bgcolor="rgba(241,244,249,1)")
647
  _FONT = dict(family="JetBrains Mono, monospace",color="#475569",size=9)
 
657
  fig = go.Figure(go.Heatmap(z=rows,text=hover,hoverinfo="text",
658
  colorscale=[[0,"#e8f0fe"],[.4,"#93c5fd"],[.75,"#3b82f6"],[1,"#1d4ed8"]],
659
  zmin=0,zmax=1,xgap=1,ygap=1,
660
+ colorbar=dict(thickness=10,len=.9,tickfont=dict(**_FONT),title=dict(text="IG",font=dict(**_FONT)))))
 
661
  fig.update_layout(
662
  title=dict(text=title,font=dict(family="JetBrains Mono, monospace",size=10,color="#1d4ed8")),
663
  **_PL,height=200,margin=dict(t=30,b=24,l=55,r=45),
 
699
  '<div style="line-height:2.2;word-break:break-all">'+"".join(cells)+"</div>")
700
 
701
 
 
 
 
702
  PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
703
  ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
704
 
 
714
  if _src in st.session_state:
715
  st.session_state[_dst] = st.session_state.pop(_src)
716
 
 
717
  if st.session_state.get("_auto_load_model") and st.session_state.model is None:
718
  st.session_state["_auto_load_model"] = False
719
+ _ensure_model_loaded(st.session_state.get("pkd_lo",1.0),st.session_state.get("pkd_hi",16.0))
 
 
720
 
 
721
  st.markdown("""
722
  <div class="app-header">
723
  <div class="app-logo">🧬</div>
 
725
  <div class="app-title">BALM-PPI Pro</div>
726
  <div class="app-subtitle">ESM-2 · LoRA · Integrated Gradients &nbsp;·&nbsp; Protein Binding Affinity Prediction</div>
727
  </div>
728
+ </div>""", unsafe_allow_html=True)
 
729
  st.divider()
730
 
731
+ # SIDEBAR
732
  with st.sidebar:
733
  st.markdown("### ⚙️ Model")
734
+ st.markdown('<div class="model-section">',unsafe_allow_html=True)
735
+ st.markdown('<div class="model-section-title">Weights Source</div>',unsafe_allow_html=True)
736
  model_src = st.radio("src",["HuggingFace (auto-cached)","Upload .pth"],
737
  key="model_src",label_visibility="collapsed")
738
  custom_w = None
 
749
 
750
  if st.button("⚡ Load Model",use_container_width=True,type="primary"):
751
  try:
752
+ if model_src=="HuggingFace (auto-cached)":
753
  with st.spinner("📥 Fetching weights (cached after first run)…"):
754
  wb = _download_weights_hf()
755
  else:
756
+ if custom_w is None: st.error("Upload a .pth file first."); st.stop()
 
757
  wb = custom_w.read()
758
  with st.spinner("🔧 Building ESM-2 + LoRA…"):
759
  m,dev = build_and_load_model(wb,(pkd_lo,pkd_hi))
 
764
  st.error(f"❌ {e}"); st.code(traceback.format_exc())
765
 
766
  if st.session_state.model:
767
+ st.markdown(f'<div class="ready-badge"><span class="ready-dot"></span>READY &nbsp;·&nbsp; {st.session_state.device}</div>',unsafe_allow_html=True)
 
 
 
768
  else:
769
  st.markdown('<div class="idle-badge">○ &nbsp;NOT LOADED</div>',unsafe_allow_html=True)
770
 
 
773
  res = st.session_state.result
774
  st.markdown("**Latest Result**")
775
  st.markdown(
776
+ f'<div class="pkd-card"><div class="pkd-lbl">Predicted pKd</div>'
 
777
  f'<div class="pkd-val">{res["pkd"]:.3f}</div>'
778
  f'<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>'
779
+ f'<div style="font-family:var(--mono);font-size:1rem;font-weight:700;color:var(--text0)">{res["cosine"]:.4f}</div></div>',
 
780
  unsafe_allow_html=True)
781
 
782
  st.divider()
 
786
  '💻 <a href="https://github.com/rgorantla04/BALM-PPI" style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
787
  '</div>',unsafe_allow_html=True)
788
 
789
+ # TABS
790
+ tab_single,tab_batch,tab_info = st.tabs(["🎯 Single Prediction","📂 Batch Prediction","📚 Model Info"])
 
791
 
792
  with tab_single:
793
  mode_opts = ["Protein-Protein","Antibody-Antigen"]
794
+ mode = st.radio("im_r",mode_opts,horizontal=True,label_visibility="collapsed",key="interaction_mode")
 
795
  st.markdown("---")
796
 
797
+ if mode=="Protein-Protein":
798
  sc1,sc2 = st.columns(2)
799
  with sc1:
800
  st.markdown("**Target Protein — Seq A**")
 
804
  st.markdown("**Binder Protein — Seq B** &nbsp; `|` = chain separator")
805
  seq_b_raw = st.text_area("Binder",label_visibility="collapsed",key="ta_ppi_b",
806
  height=140,placeholder="Paste FASTA or raw sequence…")
807
+ seq_a = clean_multi(seq_a_raw); seq_b = clean_multi(seq_b_raw)
 
808
  if not st.session_state.pdb_content:
809
  if seq_a:
810
  st.session_state.chain_info_a = [
 
818
  sc1,sc2 = st.columns([2,3])
819
  with sc1:
820
  st.markdown("**Antigen / Target — Seq A**")
821
+ ag_raw = st.text_area("Antigen",label_visibility="collapsed",key="ta_ab_ag",height=140,placeholder="Antigen sequence…")
 
822
  with sc2:
823
  ah,al = st.columns(2)
824
  with ah:
825
  st.markdown("**Heavy Chain (H)**")
826
+ h_raw = st.text_area("Heavy",label_visibility="collapsed",key="ta_ab_h",height=140,placeholder="VH sequence…")
 
827
  with al:
828
  st.markdown("**Light Chain (L)**")
829
+ l_raw = st.text_area("Light",label_visibility="collapsed",key="ta_ab_l",height=140,placeholder="VL sequence…")
830
+ seq_a = clean_multi(ag_raw); h_seq = clean_multi(h_raw); l_seq = clean_multi(l_raw)
 
 
 
831
  seq_b = f"{h_seq}|{l_seq}" if (h_seq or l_seq) else ""
832
  if not st.session_state.pdb_content:
833
  if seq_a:
834
+ st.session_state.chain_info_a = [{"id":"Ag","seq":seq_a,"resnos":list(range(1,len(seq_a)+1))}]
 
835
  if h_seq or l_seq:
836
  st.session_state.chain_info_b = [
837
  {"id":"H","seq":h_seq,"resnos":list(range(1,len(h_seq)+1))},
 
844
  if st.session_state.model is None:
845
  st.caption("⬅ Load model in sidebar first")
846
  with rc2:
847
+ run_ig = st.checkbox("Run Integrated Gradients",value=True,help="~45-90s CPU · <5s GPU")
 
848
  with rc3:
849
  if st.session_state.model:
850
+ st.markdown('<div class="ready-badge" style="margin-top:6px;font-size:.7rem"><span class="ready-dot"></span>READY</div>',unsafe_allow_html=True)
 
851
 
852
  if run_btn:
853
  if not seq_a or not seq_b:
 
864
  except Exception as e:
865
  st.error(f"Prediction failed: {e}"); st.code(traceback.format_exc())
866
  if run_ig and st.session_state.result:
867
+ with st.spinner("Computing Integrated Gradients (15 steps)…"):
868
  try:
869
+ ig_a,ig_b = compute_ig(st.session_state.model,seq_a,seq_b,steps=15)
870
  st.session_state.ig_a = ig_a
871
  st.session_state.ig_b = ig_b
872
  if st.session_state.chain_info_a or st.session_state.chain_info_b:
873
  st.session_state.ig_chain_map = build_ig_chain_map(
874
  st.session_state.chain_info_a,st.session_state.chain_info_b,ig_a,ig_b)
875
  except Exception as e:
876
+ st.warning(f"IG failed: {e}\n{traceback.format_exc()}")
877
  st.rerun()
878
 
879
  if st.session_state.result:
880
  res = st.session_state.result
881
  pkd,cos = res["pkd"],res["cosine"]
882
  pct = max(0.0,min(100.0,((pkd-pkd_lo)/max(pkd_hi-pkd_lo,1e-9))*100))
883
+ strength = "Weak" if pct<30 else "Moderate" if pct<55 else "Strong" if pct<75 else "Very Strong"
884
+ badge_cls = "badge-weak" if pct<30 else "badge-moderate" if pct<55 else "badge-strong"
885
  st.markdown("---")
886
  m1,m2,m3 = st.columns([1,1,2])
887
  m1.metric("Predicted pKd",f"{pkd:.3f}")
888
  m2.metric("Cosine Similarity",f"{cos:.4f}")
889
  with m3:
890
  st.markdown(
891
+ f'<div class="pkd-card"><div class="str-labels">'
892
+ f'<span>Weak ({pkd_lo:.0f})</span>'
893
  f'<span style="color:var(--text0);font-weight:600">{strength}</span>'
894
  f'<span>Strong ({pkd_hi:.0f})</span></div>'
895
  f'<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>'
896
+ f'<span class="pkd-badge {badge_cls}">{strength}</span></div>',
897
+ unsafe_allow_html=True)
898
 
899
  st.markdown("---")
900
  mode_now = st.session_state.get("interaction_mode","Protein-Protein")
 
919
  st.markdown(
920
  f'<div class="ex-card"><div class="ex-pdb">{ex["label"]}</div>'
921
  f'<div class="ex-sub">{ex["subtitle"]}</div>'
922
+ f'<div class="ex-desc">{ex["desc"]}</div></div>',unsafe_allow_html=True)
 
923
  if st.button(f"⬇ Load {ex['pdb']}",key=f"ex_{ex['pdb']}",use_container_width=True):
924
  ok = populate_from_pdb(ex["pdb"],ex["mode"],
925
  chain_a=ex.get("chain_a"),chain_b=ex.get("chain_b"),
926
+ chain_h=ex.get("chain_h"),chain_l=ex.get("chain_l"),chain_ag=ex.get("chain_ag"))
 
927
  if ok:
928
  if st.session_state.model is None:
929
  st.session_state["_auto_load_model"] = True
 
933
  st.markdown('<div class="sec-hdr">Custom PDB Fetch</div>',unsafe_allow_html=True)
934
  pdb_in = st.text_input("PDB ID",placeholder="6M0J, 1BRS, 2VXQ…",key="pdb_id_input")
935
 
936
+ if mode=="Protein-Protein":
937
  st.caption("Comma-separate chains for multi-chain, e.g. `B,C`")
938
  fc1,fc2 = st.columns(2)
939
  ca_in = fc1.text_input("Side A chain(s)",placeholder="A or A,B",key="ca_in")
 
959
  if ok: st.rerun()
960
 
961
  with right_col:
962
+ vt = st.tabs(["🧬 Structure",
963
+ f"📊 {lbl_a} Strip",f"🗺 {lbl_a} Heatmap",
964
+ f"📊 {lbl_b} Strip",f"🗺 {lbl_b} Heatmap",
965
+ "📥 Download"])
 
 
966
 
967
  with vt[0]:
968
  if st.session_state.pdb_content:
969
+ st.caption("🎨 IG attribution (selected chains only)" if st.session_state.ig_chain_map
970
+ else "🔗 Selected chains — run prediction with IG for attribution colouring")
971
  st.components.v1.html(
972
+ ngl_viewer_html(st.session_state.pdb_content,st.session_state.ig_chain_map,height=420),
 
973
  height=425,scrolling=False)
974
  else:
975
  st.markdown(
976
+ '<div class="ngl-placeholder"><div style="font-size:2.2rem">🧬</div>'
 
977
  '<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
978
+ '<div style="font-size:.72rem;margin-top:4px">Only selected chains will be shown</div>'
979
  '</div>',unsafe_allow_html=True)
980
 
981
  with vt[1]:
982
  if ig_a_d and seq_a_d:
983
  st.markdown(residue_strip_html(seq_a_d,ig_a_d),unsafe_allow_html=True)
984
+ st.plotly_chart(top10_bar(seq_a_d,ig_a_d,"#2563eb",f"Top 10 · {lbl_a}"),use_container_width=True)
 
985
  else:
986
  st.info("Run prediction with **Integrated Gradients** to see attribution.")
987
 
 
995
  with vt[3]:
996
  if ig_b_d and seq_b_d:
997
  st.markdown(residue_strip_html(seq_b_d,ig_b_d),unsafe_allow_html=True)
998
+ st.plotly_chart(top10_bar(seq_b_d,ig_b_d,"#7c3aed",f"Top 10 · {lbl_b}"),use_container_width=True)
 
999
  else:
1000
  st.info("Run prediction with **Integrated Gradients** to see attribution.")
1001
 
 
1010
  if st.session_state.result:
1011
  if ig_a_d or ig_b_d:
1012
  dl1,dl2 = st.columns(2)
1013
+ buf = sio.StringIO(); wc = csv.writer(buf)
 
1014
  wc.writerow(["chain","position","residue","ig_score"])
1015
+ for i,(aa,sc) in enumerate(zip(seq_a_d or "",ig_a_d or [])): wc.writerow(["Target",i+1,aa,f"{sc:.6f}"])
1016
+ for i,(aa,sc) in enumerate(zip(seq_b_d or "",ig_b_d or [])): wc.writerow(["proteina",i+1,aa,f"{sc:.6f}"])
1017
+ dl1.download_button("📥 IG Scores CSV",buf.getvalue().encode(),"ig_scores.csv","text/csv",use_container_width=True)
1018
+ res_dl = st.session_state.result
 
 
 
1019
  summary = {
1020
  "pkd":res_dl["pkd"],"cosine":res_dl["cosine"],
1021
+ "Target_length":len(seq_a_d or ""),"proteina_length":len(seq_b_d or ""),
 
1022
  "top10_Target":sorted([{"res":f"{(seq_a_d or '')[i]}{i+1}","ig":(ig_a_d or [])[i]}
1023
+ for i in range(min(len(seq_a_d or ""),len(ig_a_d or [])))],key=lambda x:-x["ig"])[:10],
 
1024
  "top10_proteina":sorted([{"res":f"{(seq_b_d or '')[i]}{i+1}","ig":(ig_b_d or [])[i]}
1025
+ for i in range(min(len(seq_b_d or ""),len(ig_b_d or [])))],key=lambda x:-x["ig"])[:10],
 
1026
  }
1027
+ dl2.download_button("📥 Summary JSON",json.dumps(summary,indent=2).encode(),"balm_result.json","application/json",use_container_width=True)
 
 
 
1028
  else:
1029
  res_dl = st.session_state.result
1030
  st.info("Enable IG and re-run for per-residue CSV.")
1031
  st.download_button("📥 Basic Result JSON",
1032
+ json.dumps({"pkd":res_dl["pkd"],"cosine":res_dl["cosine"]},indent=2).encode(),
1033
+ "balm_result.json","application/json")
1034
  else:
1035
  st.info("Run a prediction to enable download.")
1036
 
 
1049
  df = pd.read_csv(bf)
1050
  st.dataframe(df.head(),use_container_width=True)
1051
  if st.button("🏃 Run Batch",type="primary",use_container_width=True):
1052
+ if st.session_state.model is None: st.error("Load the model first.")
 
1053
  else:
1054
  model,results = st.session_state.model,[]
1055
  prog,stat,n = st.progress(0.0),st.empty(),len(df)
 
1067
  pkd_t,cos_t = model(sa,sb)
1068
  rr={"pKd":round(float(pkd_t.item()),4),"Cosine":round(float(cos_t.item()),6)}
1069
  if run_ig_b:
1070
+ ia,ib=compute_ig(model,sa,sb,steps=15)
1071
  sap,sbp=sa.replace("|",""),sb.replace("|","")
1072
  rr["top5_Target"]="|".join(f"{sap[j]}{j+1}:{ia[j]:.3f}"
1073
  for j in sorted(range(min(len(sap),len(ia))),key=lambda x:-ia[x])[:5])
 
1080
  df_res = pd.concat([df,pd.DataFrame(results)],axis=1)
1081
  st.success(f"✅ {n} predictions complete")
1082
  st.dataframe(df_res,use_container_width=True)
1083
+ st.download_button("📥 Download Results CSV",df_res.to_csv(index=False).encode(),"balm_batch_results.csv","text/csv",use_container_width=True)
 
1084
 
1085
  with tab_info:
1086
  st.markdown("### 📚 About BALM-PPI")
1087
  st.markdown("""
1088
+ **BALM-PPI** predicts protein–protein and antibody–antigen binding affinity using ESM-2 + LoRA
1089
+ with Integrated Gradients explainability.
1090
 
1091
  | Component | Detail |
1092
  |-----------|--------|
 
1095
  | Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** |
1096
  | Multi-chain | `\\|` separator → `<cls><cls>` double CLS token |
1097
  | Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min |
1098
+ | IG | Integrated Gradients, **15-step** Riemann (float32-safe, HF Spaces compatible) |
1099
+ | Viewer | PDB filtered to **selected chains only** |
1100
+ | Speed | hf_transfer · use_safetensors · HF disk cache · PDB cache (1 h TTL) |
1101
 
1102
  **Examples** · **6M0J** — ACE2 ↔ SARS-CoV-2 RBD &nbsp;·&nbsp; **1BRS** — Barnase ↔ Barstar
1103