Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
"""
|
| 2 |
BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
# ── Speed tweaks — must be BEFORE any HF/torch imports ─────────────────────
|
| 8 |
import os
|
| 9 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 10 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
@@ -19,7 +21,7 @@ import requests
|
|
| 19 |
from Bio import PDB
|
| 20 |
from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
|
| 21 |
|
| 22 |
-
st.set_page_config(page_title="BALM-PPI ", page_icon="🧬", layout="wide",
|
| 23 |
initial_sidebar_state="expanded")
|
| 24 |
|
| 25 |
HF_REPO_ID = "Harshit494/BALM-PPI"
|
|
@@ -39,7 +41,6 @@ EXAMPLES = [
|
|
| 39 |
},
|
| 40 |
]
|
| 41 |
|
| 42 |
-
# ── Session state ─────────────────────────────────────────────────────────
|
| 43 |
_DEFAULTS: dict = {
|
| 44 |
"model": None, "device": None,
|
| 45 |
"result": None, "ig_a": None, "ig_b": None, "ig_chain_map": None,
|
|
@@ -54,13 +55,11 @@ for _k, _v in _DEFAULTS.items():
|
|
| 54 |
st.session_state[_k] = _v
|
| 55 |
|
| 56 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 57 |
-
# CSS — light-first
|
| 58 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 59 |
st.markdown("""
|
| 60 |
<style>
|
| 61 |
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap');
|
| 62 |
-
|
| 63 |
-
/* LIGHT tokens */
|
| 64 |
:root {
|
| 65 |
--bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2;
|
| 66 |
--border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07);
|
|
@@ -69,7 +68,6 @@ st.markdown("""
|
|
| 69 |
--text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8;
|
| 70 |
--mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif;
|
| 71 |
}
|
| 72 |
-
/* DARK tokens */
|
| 73 |
@media (prefers-color-scheme:dark){
|
| 74 |
:root {
|
| 75 |
--bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135;
|
|
@@ -79,213 +77,111 @@ st.markdown("""
|
|
| 79 |
--text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155;
|
| 80 |
}
|
| 81 |
}
|
| 82 |
-
|
| 83 |
-
/* Global */
|
| 84 |
html,body,[data-testid="stAppViewContainer"],[data-testid="stMain"]{
|
| 85 |
-
background:var(--bg0)!important;
|
| 86 |
}
|
| 87 |
[data-testid="stHeader"]{
|
| 88 |
-
background:var(--bg0)!important;
|
| 89 |
-
border-bottom:1px solid var(--border)!important;
|
| 90 |
box-shadow:0 1px 8px var(--shadow)!important;
|
| 91 |
}
|
| 92 |
-
[data-testid="stSidebar"]{
|
| 93 |
-
background:var(--bg1)!important;
|
| 94 |
-
border-right:1px solid var(--border)!important;
|
| 95 |
-
}
|
| 96 |
[data-testid="stSidebar"] *{font-family:var(--sans)!important;}
|
| 97 |
-
|
| 98 |
-
/* Typography */
|
| 99 |
h1,h2,h3,h4,h5,h6{font-family:var(--sans)!important;color:var(--text0)!important;font-weight:700!important;}
|
| 100 |
p,label,div,span,li{color:var(--text1);}
|
| 101 |
code,pre{font-family:var(--mono)!important;font-size:.82rem!important;
|
| 102 |
background:var(--bg2)!important;border-radius:4px!important;color:var(--accent)!important;}
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
.stButton>button{
|
| 106 |
-
font-family:var(--sans)!important;font-weight:600!important;
|
| 107 |
-
letter-spacing:.01em!important;border-radius:8px!important;
|
| 108 |
-
transition:all .15s ease!important;height:38px!important;
|
| 109 |
-
}
|
| 110 |
.stButton>button[kind="primary"]{
|
| 111 |
background:linear-gradient(135deg,var(--accent),var(--accent2))!important;
|
| 112 |
-
border:none!important;color:#fff!important;
|
| 113 |
-
|
| 114 |
-
}
|
| 115 |
-
.stButton>button[kind="
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
}
|
| 119 |
-
.
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
}
|
| 124 |
-
.
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
/* Text areas */
|
| 130 |
-
.stTextArea textarea{
|
| 131 |
-
background:var(--bg0)!important;border:1.5px solid var(--border)!important;
|
| 132 |
-
border-radius:8px!important;color:var(--text0)!important;
|
| 133 |
-
font-family:var(--mono)!important;font-size:.82rem!important;
|
| 134 |
-
line-height:1.65!important;resize:vertical!important;
|
| 135 |
-
}
|
| 136 |
-
.stTextArea textarea:focus{
|
| 137 |
-
border-color:var(--accent)!important;
|
| 138 |
-
box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;
|
| 139 |
-
}
|
| 140 |
-
.stTextArea label{font-weight:600!important;font-size:.76rem!important;
|
| 141 |
-
text-transform:uppercase!important;letter-spacing:.09em!important;color:var(--text2)!important;}
|
| 142 |
-
|
| 143 |
-
/* Inputs */
|
| 144 |
-
.stTextInput input,.stNumberInput input{
|
| 145 |
-
background:var(--bg0)!important;border:1.5px solid var(--border)!important;
|
| 146 |
-
border-radius:8px!important;color:var(--text0)!important;
|
| 147 |
-
}
|
| 148 |
-
.stTextInput input:focus,.stNumberInput input:focus{
|
| 149 |
-
border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;
|
| 150 |
-
}
|
| 151 |
-
.stTextInput label,.stNumberInput label{
|
| 152 |
-
font-size:.76rem!important;font-weight:600!important;
|
| 153 |
-
color:var(--text2)!important;text-transform:uppercase!important;letter-spacing:.08em!important;
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
/* Radio/checkbox */
|
| 157 |
.stRadio label,.stCheckbox label{color:var(--text1)!important;font-size:.9rem!important;}
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
.stTabs [data-baseweb="tab
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
.
|
| 171 |
-
background:var(--bg0)!important;color:var(--accent)!important;
|
| 172 |
-
font-weight:600!important;box-shadow:0 1px 6px var(--shadow)!important;
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
/* Metrics */
|
| 176 |
-
[data-testid="stMetric"]{
|
| 177 |
-
background:var(--bg1)!important;border:1px solid var(--border)!important;
|
| 178 |
-
border-radius:10px!important;padding:14px 18px!important;
|
| 179 |
-
box-shadow:0 1px 4px var(--shadow)!important;
|
| 180 |
-
}
|
| 181 |
-
[data-testid="stMetricLabel"]{
|
| 182 |
-
color:var(--text2)!important;font-size:.72rem!important;
|
| 183 |
-
text-transform:uppercase!important;letter-spacing:.1em!important;
|
| 184 |
-
}
|
| 185 |
-
[data-testid="stMetricValue"]{
|
| 186 |
-
color:var(--accent)!important;font-family:var(--mono)!important;
|
| 187 |
-
font-size:1.55rem!important;font-weight:700!important;
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
/* Misc */
|
| 191 |
hr{border-color:var(--border)!important;margin:12px 0!important;}
|
| 192 |
[data-testid="stAlert"]{background:var(--bg1)!important;border-radius:8px!important;border-left-width:3px!important;}
|
| 193 |
[data-testid="stDataFrame"]{border:1px solid var(--border)!important;border-radius:8px!important;overflow:hidden!important;}
|
| 194 |
-
.stProgress>div>div{
|
| 195 |
-
background:linear-gradient(90deg,var(--accent),var(--accent2))!important;
|
| 196 |
-
border-radius:4px!important;
|
| 197 |
-
}
|
| 198 |
|
| 199 |
-
/* ── Custom classes ─────────────────────────────────────────── */
|
| 200 |
.app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;}
|
| 201 |
-
.app-logo{
|
| 202 |
-
width:40px;height:40px;border-radius:10px;
|
| 203 |
background:linear-gradient(135deg,var(--accent),var(--accent2));
|
| 204 |
display:flex;align-items:center;justify-content:center;
|
| 205 |
-
font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);
|
| 206 |
-
}
|
| 207 |
.app-title{font-family:var(--mono)!important;font-size:1.22rem!important;
|
| 208 |
font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;}
|
| 209 |
-
.app-subtitle{font-size:.77rem!important;color:var(--text2)!important;
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
font-family:var(--mono);font-size:.65rem;font-weight:700;
|
| 214 |
-
letter-spacing:.2em;text-transform:uppercase;
|
| 215 |
-
color:var(--accent);margin:14px 0 8px;
|
| 216 |
-
display:flex;align-items:center;gap:8px;
|
| 217 |
-
}
|
| 218 |
.sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);}
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
transition:box-shadow .15s,border-color .15s;
|
| 225 |
-
}
|
| 226 |
-
.ex-card::before{
|
| 227 |
-
content:'';position:absolute;left:0;top:0;bottom:0;width:3px;
|
| 228 |
-
background:linear-gradient(180deg,var(--accent),var(--accent2));
|
| 229 |
-
border-radius:3px 0 0 3px;
|
| 230 |
-
}
|
| 231 |
.ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);}
|
| 232 |
.ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);}
|
| 233 |
.ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;}
|
| 234 |
.ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;}
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
}
|
| 241 |
-
.pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2);
|
| 242 |
-
text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;}
|
| 243 |
-
.pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;
|
| 244 |
-
color:var(--accent);line-height:1.1;}
|
| 245 |
-
.pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px;
|
| 246 |
-
font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;}
|
| 247 |
.badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);}
|
| 248 |
.badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);}
|
| 249 |
.badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);}
|
| 250 |
-
|
| 251 |
.str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;}
|
| 252 |
-
.str-fill{height:100%;border-radius:3px;
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
.str-labels{display:flex;justify-content:space-between;
|
| 256 |
-
font-size:.65rem;color:var(--text2);font-family:var(--mono);}
|
| 257 |
-
|
| 258 |
-
.ready-badge{
|
| 259 |
-
display:inline-flex;align-items:center;gap:6px;
|
| 260 |
-
padding:3px 10px;border-radius:20px;
|
| 261 |
background:rgba(22,163,74,.07);border:1px solid rgba(22,163,74,.22);
|
| 262 |
-
font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;
|
| 263 |
-
}
|
| 264 |
-
.
|
| 265 |
-
background:var(--green);animation:pulse 2s infinite;}
|
| 266 |
-
.idle-badge{
|
| 267 |
-
display:inline-flex;align-items:center;gap:6px;
|
| 268 |
-
padding:3px 10px;border-radius:20px;
|
| 269 |
background:rgba(100,116,139,.07);border:1px solid rgba(100,116,139,.18);
|
| 270 |
-
font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;
|
| 271 |
-
}
|
| 272 |
@keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}}
|
| 273 |
-
|
| 274 |
-
.model-section{
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
font-family:var(--mono);font-size:.8rem;text-align:center;line-height:2;
|
| 284 |
-
}
|
| 285 |
</style>
|
| 286 |
""", unsafe_allow_html=True)
|
| 287 |
|
| 288 |
-
# postMessage bridge
|
| 289 |
st.markdown("""
|
| 290 |
<script>
|
| 291 |
(function(){
|
|
@@ -294,16 +190,14 @@ st.markdown("""
|
|
| 294 |
try{f.contentWindow.postMessage({balmTheme:dark?'dark':'light'},'*');}catch(e){}
|
| 295 |
});
|
| 296 |
}
|
| 297 |
-
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){
|
| 298 |
-
notify(e.matches);
|
| 299 |
-
});
|
| 300 |
})();
|
| 301 |
</script>
|
| 302 |
""", unsafe_allow_html=True)
|
| 303 |
|
| 304 |
|
| 305 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 306 |
-
# PDB UTILITIES
|
| 307 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 308 |
|
| 309 |
@st.cache_data(show_spinner=False, ttl=3600)
|
|
@@ -314,6 +208,23 @@ def _fetch_pdb_cached(pdb_id: str) -> str:
|
|
| 314 |
return r.text
|
| 315 |
|
| 316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
def get_chains_from_pdb(pdb_id: str):
|
| 318 |
try:
|
| 319 |
pdb_text = _fetch_pdb_cached(pdb_id.strip().upper())
|
|
@@ -367,17 +278,18 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
|
|
| 367 |
if not pdb_content or not chains: return False
|
| 368 |
available = sorted(chains.keys())
|
| 369 |
st.info(f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**")
|
| 370 |
-
|
| 371 |
-
st.session_state.ig_a
|
| 372 |
-
st.session_state.
|
| 373 |
-
st.session_state.ig_chain_map = None
|
| 374 |
-
st.session_state.result = None
|
| 375 |
st.session_state.chain_info_a = []
|
| 376 |
st.session_state.chain_info_b = []
|
| 377 |
st.session_state["_pm"] = mode
|
|
|
|
|
|
|
|
|
|
| 378 |
if mode == "Protein-Protein":
|
| 379 |
id_a = chain_a or available[0]
|
| 380 |
-
id_b = chain_b or (available[1] if len(available)>1 else available[0])
|
| 381 |
seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
|
| 382 |
seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
|
| 383 |
for m in miss_a: st.warning(f"Chain **{m}** not found (Side A). Available: {', '.join(available)}")
|
|
@@ -386,10 +298,11 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
|
|
| 386 |
st.session_state["_pb"] = seq_b
|
| 387 |
st.session_state.chain_info_a = info_a
|
| 388 |
st.session_state.chain_info_b = info_b
|
|
|
|
| 389 |
else:
|
| 390 |
-
id_h = chain_h
|
| 391 |
-
id_l = chain_l
|
| 392 |
-
id_ag_str = chain_ag or next((c for c in available if c not in (id_h,id_l)), available[0])
|
| 393 |
def _single(cid):
|
| 394 |
if cid in chains and chains[cid]["seq"]: return cid, chains[cid]
|
| 395 |
f = next((k for k in chains if k.upper()==cid.upper() and chains[k]["seq"]), None)
|
|
@@ -405,6 +318,17 @@ def populate_from_pdb(pdb_id, mode, chain_a=None, chain_b=None,
|
|
| 405 |
st.session_state["_pag"] = seq_ag
|
| 406 |
st.session_state.chain_info_a = info_ag if info_ag else []
|
| 407 |
st.session_state.chain_info_b = [{"id":real_h,**ch},{"id":real_l,**cl}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
return True
|
| 409 |
|
| 410 |
|
|
@@ -444,12 +368,12 @@ def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
|
|
| 444 |
seq_a = "".join(c["seq"] for c in chain_info_a) if chain_info_a else ""
|
| 445 |
seq_b = "".join(c["seq"] for c in chain_info_b) if chain_info_b else ""
|
| 446 |
if ig_a and chain_info_a:
|
| 447 |
-
parts
|
| 448 |
ig_a_out = [s for sub in parts for s in sub]
|
| 449 |
else:
|
| 450 |
ig_a_out = list(ig_a) if ig_a else []
|
| 451 |
if ig_b and chain_info_b:
|
| 452 |
-
parts
|
| 453 |
ig_b_out = [s for sub in parts for s in sub]
|
| 454 |
else:
|
| 455 |
ig_b_out = list(ig_b) if ig_b else []
|
|
@@ -458,22 +382,29 @@ def flat_ig_for_display(ig_a, ig_b, chain_info_a, chain_info_b):
|
|
| 458 |
|
| 459 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 460 |
# MODEL
|
|
|
|
|
|
|
|
|
|
| 461 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 462 |
@st.cache_resource(show_spinner="Loading ESM-2 backbone…")
|
| 463 |
def _load_esm_base():
|
| 464 |
from transformers import EsmModel, EsmTokenizer
|
| 465 |
base = EsmModel.from_pretrained(
|
| 466 |
"facebook/esm2_t33_650M_UR50D",
|
| 467 |
-
torch_dtype="auto"
|
|
|
|
|
|
|
| 468 |
)
|
| 469 |
tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 470 |
return base, tok
|
| 471 |
|
|
|
|
| 472 |
def _download_weights_hf() -> bytes:
|
| 473 |
from huggingface_hub import hf_hub_download
|
| 474 |
path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, resume_download=True)
|
| 475 |
with open(path,"rb") as f: return f.read()
|
| 476 |
|
|
|
|
| 477 |
def build_and_load_model(weights_bytes, pkd_bounds):
|
| 478 |
import torch.nn as nn
|
| 479 |
import torch.nn.functional as F
|
|
@@ -528,6 +459,7 @@ def build_and_load_model(weights_bytes, pkd_bounds):
|
|
| 528 |
model.to(device).eval()
|
| 529 |
return model, device
|
| 530 |
|
|
|
|
| 531 |
def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
|
| 532 |
if st.session_state.model is not None: return True
|
| 533 |
try:
|
|
@@ -544,46 +476,77 @@ def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
|
|
| 544 |
|
| 545 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 546 |
# INTEGRATED GRADIENTS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 548 |
-
def compute_ig(model, seq_a, seq_b, steps=
|
| 549 |
device = next(model.parameters()).device
|
| 550 |
tok = model.esm_tokenizer
|
| 551 |
cls_tok = tok.cls_token
|
| 552 |
esm = model.esm_model
|
|
|
|
| 553 |
def tokenise(seq):
|
| 554 |
-
proc = seq.replace("|",f"{cls_tok}{cls_tok}")
|
| 555 |
-
return tok(proc,return_tensors="pt",
|
|
|
|
|
|
|
| 556 |
word_embed = esm.base_model.model.embeddings.word_embeddings
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
grads = []
|
| 570 |
-
for
|
| 571 |
-
|
| 572 |
-
out = fwd(fixed,
|
| 573 |
out.sum().backward()
|
| 574 |
-
grads.append(
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
def norm(a):
|
| 580 |
-
lo,hi = a.min(),a.max()
|
| 581 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
return norm(attr_a[1:-1]), norm(attr_b[1:-1])
|
| 583 |
|
| 584 |
|
| 585 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 586 |
-
# NGL VIEWER — theme-reactive
|
| 587 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 588 |
def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
|
| 589 |
ig_json = json.dumps(ig_chain_map) if ig_chain_map else "null"
|
|
@@ -595,8 +558,7 @@ def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
|
|
| 595 |
*{{box-sizing:border-box;margin:0;padding:0}}
|
| 596 |
body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
|
| 597 |
#vp{{width:100%;height:{height}px}}
|
| 598 |
-
#hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;
|
| 599 |
-
letter-spacing:.04em;transition:color .3s}}
|
| 600 |
#ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
|
| 601 |
.cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
|
| 602 |
border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
|
|
@@ -613,57 +575,39 @@ body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:backgrou
|
|
| 613 |
</div>
|
| 614 |
<div id="leg"></div>
|
| 615 |
<script>
|
| 616 |
-
var T
|
| 617 |
-
light:{{
|
| 618 |
-
|
| 619 |
-
hint:'rgba(51,65,85,.5)',
|
| 620 |
-
btn:'rgba(240,244,250,.93)','btn_b':'rgba(37,99,235,.16)',
|
| 621 |
-
btn_c:'rgba(71,85,105,.6)',
|
| 622 |
btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
|
| 623 |
leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
|
| 624 |
-
ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
dark:{{
|
| 628 |
-
bg:'#06090f',
|
| 629 |
-
hint:'rgba(168,184,216,.4)',
|
| 630 |
-
btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',
|
| 631 |
-
btn_c:'rgba(168,184,216,.55)',
|
| 632 |
btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
|
| 633 |
leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
|
| 634 |
-
ig_hi:'#60a5fa',ig_lo:'#172135',
|
| 635 |
-
null_col:0x172135,
|
| 636 |
-
}}
|
| 637 |
}};
|
| 638 |
-
|
| 639 |
function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
|
| 640 |
-
var theme
|
| 641 |
-
var stage = null, comp = null, curRep = 'cartoon';
|
| 642 |
-
var igData = {ig_json};
|
| 643 |
|
| 644 |
-
function
|
| 645 |
-
// light: #c7d7f0(199,215,240) → #1d4ed8(29,78,216)
|
| 646 |
-
// dark: #172135(23,33,53) → #60a5fa(96,165,250)
|
| 647 |
s=Math.max(0,Math.min(1,s||0));
|
| 648 |
-
if(dark) return
|
| 649 |
-
return
|
| 650 |
}}
|
| 651 |
-
|
| 652 |
function styleUI(dark){{
|
| 653 |
-
var Th
|
| 654 |
-
|
| 655 |
-
document.
|
| 656 |
-
document.getElementById('hint').style.color = Th.hint;
|
| 657 |
document.querySelectorAll('.cb').forEach(function(b){{
|
| 658 |
b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
|
| 659 |
b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
|
| 660 |
-
b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn;
|
| 661 |
}});
|
| 662 |
var leg=document.getElementById('leg');
|
| 663 |
leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
|
| 664 |
-
if(stage)
|
| 665 |
}}
|
| 666 |
-
|
| 667 |
function addRepr(c){{
|
| 668 |
comp=c;comp.removeAllRepresentations();
|
| 669 |
var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
|
|
@@ -673,43 +617,31 @@ function addRepr(c){{
|
|
| 673 |
var cd=igData[atom.chainname];
|
| 674 |
if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
|
| 675 |
var i=atom.resno-cd.start;
|
| 676 |
-
return(i<0||i>=cd.scores.length)?Th.null_col:
|
| 677 |
}};
|
| 678 |
}});
|
| 679 |
comp.addRepresentation(curRep,{{color:sid}});
|
| 680 |
leg.innerHTML='<span style="color:'+Th.ig_hi+'">■</span> High IG <span style="color:'+Th.ig_lo+'">■</span> Low IG';
|
| 681 |
-
}}
|
| 682 |
comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
|
| 683 |
leg.innerHTML='Coloured by chain';
|
| 684 |
}}
|
| 685 |
stage.autoView();
|
| 686 |
}}
|
| 687 |
-
|
| 688 |
function setRep(r){{curRep=r;if(comp){{comp.removeAllRepresentations();addRepr(comp);}}}}
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
if(comp){{comp.removeAllRepresentations();addRepr(comp);}}
|
| 693 |
-
}}
|
| 694 |
-
|
| 695 |
-
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{
|
| 696 |
-
themeSwitch(e.matches);
|
| 697 |
-
}});
|
| 698 |
-
window.addEventListener('message',function(e){{
|
| 699 |
-
if(e.data&&e.data.balmTheme) themeSwitch(e.data.balmTheme==='dark');
|
| 700 |
-
}});
|
| 701 |
window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
|
| 702 |
-
|
| 703 |
-
stage = new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
|
| 704 |
styleUI(isDark());
|
| 705 |
-
|
| 706 |
var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
|
| 707 |
stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
|
| 708 |
</script></body></html>"""
|
| 709 |
|
| 710 |
|
| 711 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 712 |
-
# PLOTLY
|
| 713 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 714 |
_PL = dict(paper_bgcolor="rgba(255,255,255,.97)",plot_bgcolor="rgba(241,244,249,1)")
|
| 715 |
_FONT = dict(family="JetBrains Mono, monospace",color="#475569",size=9)
|
|
@@ -725,8 +657,7 @@ def make_heatmap(seq, attr, title):
|
|
| 725 |
fig = go.Figure(go.Heatmap(z=rows,text=hover,hoverinfo="text",
|
| 726 |
colorscale=[[0,"#e8f0fe"],[.4,"#93c5fd"],[.75,"#3b82f6"],[1,"#1d4ed8"]],
|
| 727 |
zmin=0,zmax=1,xgap=1,ygap=1,
|
| 728 |
-
colorbar=dict(thickness=10,len=.9,tickfont=dict(**_FONT),
|
| 729 |
-
title=dict(text="IG",font=dict(**_FONT)))))
|
| 730 |
fig.update_layout(
|
| 731 |
title=dict(text=title,font=dict(family="JetBrains Mono, monospace",size=10,color="#1d4ed8")),
|
| 732 |
**_PL,height=200,margin=dict(t=30,b=24,l=55,r=45),
|
|
@@ -768,9 +699,6 @@ def residue_strip_html(seq, attr) -> str:
|
|
| 768 |
'<div style="line-height:2.2;word-break:break-all">'+"".join(cells)+"</div>")
|
| 769 |
|
| 770 |
|
| 771 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 772 |
-
# BATCH TEMPLATES
|
| 773 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 774 |
PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
|
| 775 |
ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
|
| 776 |
|
|
@@ -786,14 +714,10 @@ def main():
|
|
| 786 |
if _src in st.session_state:
|
| 787 |
st.session_state[_dst] = st.session_state.pop(_src)
|
| 788 |
|
| 789 |
-
# Auto-load model from example click
|
| 790 |
if st.session_state.get("_auto_load_model") and st.session_state.model is None:
|
| 791 |
st.session_state["_auto_load_model"] = False
|
| 792 |
-
_ensure_model_loaded(
|
| 793 |
-
st.session_state.get("pkd_lo",1.0),
|
| 794 |
-
st.session_state.get("pkd_hi",16.0))
|
| 795 |
|
| 796 |
-
# Header
|
| 797 |
st.markdown("""
|
| 798 |
<div class="app-header">
|
| 799 |
<div class="app-logo">🧬</div>
|
|
@@ -801,15 +725,14 @@ def main():
|
|
| 801 |
<div class="app-title">BALM-PPI Pro</div>
|
| 802 |
<div class="app-subtitle">ESM-2 · LoRA · Integrated Gradients · Protein Binding Affinity Prediction</div>
|
| 803 |
</div>
|
| 804 |
-
</div>
|
| 805 |
-
""", unsafe_allow_html=True)
|
| 806 |
st.divider()
|
| 807 |
|
| 808 |
-
#
|
| 809 |
with st.sidebar:
|
| 810 |
st.markdown("### ⚙️ Model")
|
| 811 |
-
st.markdown('<div class="model-section">',
|
| 812 |
-
st.markdown('<div class="model-section-title">Weights Source</div>',
|
| 813 |
model_src = st.radio("src",["HuggingFace (auto-cached)","Upload .pth"],
|
| 814 |
key="model_src",label_visibility="collapsed")
|
| 815 |
custom_w = None
|
|
@@ -826,12 +749,11 @@ def main():
|
|
| 826 |
|
| 827 |
if st.button("⚡ Load Model",use_container_width=True,type="primary"):
|
| 828 |
try:
|
| 829 |
-
if model_src
|
| 830 |
with st.spinner("📥 Fetching weights (cached after first run)…"):
|
| 831 |
wb = _download_weights_hf()
|
| 832 |
else:
|
| 833 |
-
if custom_w is None:
|
| 834 |
-
st.error("Upload a .pth file first."); st.stop()
|
| 835 |
wb = custom_w.read()
|
| 836 |
with st.spinner("🔧 Building ESM-2 + LoRA…"):
|
| 837 |
m,dev = build_and_load_model(wb,(pkd_lo,pkd_hi))
|
|
@@ -842,10 +764,7 @@ def main():
|
|
| 842 |
st.error(f"❌ {e}"); st.code(traceback.format_exc())
|
| 843 |
|
| 844 |
if st.session_state.model:
|
| 845 |
-
st.markdown(
|
| 846 |
-
f'<div class="ready-badge"><span class="ready-dot"></span>'
|
| 847 |
-
f'READY · {st.session_state.device}</div>',
|
| 848 |
-
unsafe_allow_html=True)
|
| 849 |
else:
|
| 850 |
st.markdown('<div class="idle-badge">○ NOT LOADED</div>',unsafe_allow_html=True)
|
| 851 |
|
|
@@ -854,12 +773,10 @@ def main():
|
|
| 854 |
res = st.session_state.result
|
| 855 |
st.markdown("**Latest Result**")
|
| 856 |
st.markdown(
|
| 857 |
-
f'<div class="pkd-card">'
|
| 858 |
-
f'<div class="pkd-lbl">Predicted pKd</div>'
|
| 859 |
f'<div class="pkd-val">{res["pkd"]:.3f}</div>'
|
| 860 |
f'<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>'
|
| 861 |
-
f'<div style="font-family:var(--mono);font-size:1rem;font-weight:700;color:var(--text0)">'
|
| 862 |
-
f'{res["cosine"]:.4f}</div></div>',
|
| 863 |
unsafe_allow_html=True)
|
| 864 |
|
| 865 |
st.divider()
|
|
@@ -869,17 +786,15 @@ def main():
|
|
| 869 |
'💻 <a href="https://github.com/rgorantla04/BALM-PPI" style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
|
| 870 |
'</div>',unsafe_allow_html=True)
|
| 871 |
|
| 872 |
-
#
|
| 873 |
-
tab_single,
|
| 874 |
-
["🎯 Single Prediction","📂 Batch Prediction","📚 Model Info"])
|
| 875 |
|
| 876 |
with tab_single:
|
| 877 |
mode_opts = ["Protein-Protein","Antibody-Antigen"]
|
| 878 |
-
mode = st.radio("im_r",mode_opts,horizontal=True,
|
| 879 |
-
label_visibility="collapsed",key="interaction_mode")
|
| 880 |
st.markdown("---")
|
| 881 |
|
| 882 |
-
if mode
|
| 883 |
sc1,sc2 = st.columns(2)
|
| 884 |
with sc1:
|
| 885 |
st.markdown("**Target Protein — Seq A**")
|
|
@@ -889,8 +804,7 @@ def main():
|
|
| 889 |
st.markdown("**Binder Protein — Seq B** `|` = chain separator")
|
| 890 |
seq_b_raw = st.text_area("Binder",label_visibility="collapsed",key="ta_ppi_b",
|
| 891 |
height=140,placeholder="Paste FASTA or raw sequence…")
|
| 892 |
-
seq_a = clean_multi(seq_a_raw)
|
| 893 |
-
seq_b = clean_multi(seq_b_raw)
|
| 894 |
if not st.session_state.pdb_content:
|
| 895 |
if seq_a:
|
| 896 |
st.session_state.chain_info_a = [
|
|
@@ -904,26 +818,20 @@ def main():
|
|
| 904 |
sc1,sc2 = st.columns([2,3])
|
| 905 |
with sc1:
|
| 906 |
st.markdown("**Antigen / Target — Seq A**")
|
| 907 |
-
ag_raw = st.text_area("Antigen",label_visibility="collapsed",
|
| 908 |
-
key="ta_ab_ag",height=140,placeholder="Antigen sequence…")
|
| 909 |
with sc2:
|
| 910 |
ah,al = st.columns(2)
|
| 911 |
with ah:
|
| 912 |
st.markdown("**Heavy Chain (H)**")
|
| 913 |
-
h_raw = st.text_area("Heavy",label_visibility="collapsed",
|
| 914 |
-
key="ta_ab_h",height=140,placeholder="VH sequence…")
|
| 915 |
with al:
|
| 916 |
st.markdown("**Light Chain (L)**")
|
| 917 |
-
l_raw = st.text_area("Light",label_visibility="collapsed",
|
| 918 |
-
|
| 919 |
-
seq_a = clean_multi(ag_raw)
|
| 920 |
-
h_seq = clean_multi(h_raw)
|
| 921 |
-
l_seq = clean_multi(l_raw)
|
| 922 |
seq_b = f"{h_seq}|{l_seq}" if (h_seq or l_seq) else ""
|
| 923 |
if not st.session_state.pdb_content:
|
| 924 |
if seq_a:
|
| 925 |
-
st.session_state.chain_info_a = [
|
| 926 |
-
{"id":"Ag","seq":seq_a,"resnos":list(range(1,len(seq_a)+1))}]
|
| 927 |
if h_seq or l_seq:
|
| 928 |
st.session_state.chain_info_b = [
|
| 929 |
{"id":"H","seq":h_seq,"resnos":list(range(1,len(h_seq)+1))},
|
|
@@ -936,12 +844,10 @@ def main():
|
|
| 936 |
if st.session_state.model is None:
|
| 937 |
st.caption("⬅ Load model in sidebar first")
|
| 938 |
with rc2:
|
| 939 |
-
run_ig = st.checkbox("Run Integrated Gradients",value=True,
|
| 940 |
-
help="~1–2 min CPU · <10 s GPU")
|
| 941 |
with rc3:
|
| 942 |
if st.session_state.model:
|
| 943 |
-
st.markdown('<div class="ready-badge" style="margin-top:6px;font-size:.7rem">'
|
| 944 |
-
'<span class="ready-dot"></span>READY</div>',unsafe_allow_html=True)
|
| 945 |
|
| 946 |
if run_btn:
|
| 947 |
if not seq_a or not seq_b:
|
|
@@ -958,37 +864,37 @@ def main():
|
|
| 958 |
except Exception as e:
|
| 959 |
st.error(f"Prediction failed: {e}"); st.code(traceback.format_exc())
|
| 960 |
if run_ig and st.session_state.result:
|
| 961 |
-
with st.spinner("Computing Integrated Gradients…"):
|
| 962 |
try:
|
| 963 |
-
ig_a,ig_b = compute_ig(st.session_state.model,seq_a,seq_b)
|
| 964 |
st.session_state.ig_a = ig_a
|
| 965 |
st.session_state.ig_b = ig_b
|
| 966 |
if st.session_state.chain_info_a or st.session_state.chain_info_b:
|
| 967 |
st.session_state.ig_chain_map = build_ig_chain_map(
|
| 968 |
st.session_state.chain_info_a,st.session_state.chain_info_b,ig_a,ig_b)
|
| 969 |
except Exception as e:
|
| 970 |
-
st.warning(f"IG failed: {e}")
|
| 971 |
st.rerun()
|
| 972 |
|
| 973 |
if st.session_state.result:
|
| 974 |
res = st.session_state.result
|
| 975 |
pkd,cos = res["pkd"],res["cosine"]
|
| 976 |
pct = max(0.0,min(100.0,((pkd-pkd_lo)/max(pkd_hi-pkd_lo,1e-9))*100))
|
| 977 |
-
strength =
|
| 978 |
-
badge_cls =
|
| 979 |
st.markdown("---")
|
| 980 |
m1,m2,m3 = st.columns([1,1,2])
|
| 981 |
m1.metric("Predicted pKd",f"{pkd:.3f}")
|
| 982 |
m2.metric("Cosine Similarity",f"{cos:.4f}")
|
| 983 |
with m3:
|
| 984 |
st.markdown(
|
| 985 |
-
f'<div class="pkd-card">'
|
| 986 |
-
f'<
|
| 987 |
f'<span style="color:var(--text0);font-weight:600">{strength}</span>'
|
| 988 |
f'<span>Strong ({pkd_hi:.0f})</span></div>'
|
| 989 |
f'<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>'
|
| 990 |
-
f'<span class="pkd-badge {badge_cls}">{strength}</span>'
|
| 991 |
-
|
| 992 |
|
| 993 |
st.markdown("---")
|
| 994 |
mode_now = st.session_state.get("interaction_mode","Protein-Protein")
|
|
@@ -1013,13 +919,11 @@ def main():
|
|
| 1013 |
st.markdown(
|
| 1014 |
f'<div class="ex-card"><div class="ex-pdb">{ex["label"]}</div>'
|
| 1015 |
f'<div class="ex-sub">{ex["subtitle"]}</div>'
|
| 1016 |
-
f'<div class="ex-desc">{ex["desc"]}</div></div>',
|
| 1017 |
-
unsafe_allow_html=True)
|
| 1018 |
if st.button(f"⬇ Load {ex['pdb']}",key=f"ex_{ex['pdb']}",use_container_width=True):
|
| 1019 |
ok = populate_from_pdb(ex["pdb"],ex["mode"],
|
| 1020 |
chain_a=ex.get("chain_a"),chain_b=ex.get("chain_b"),
|
| 1021 |
-
chain_h=ex.get("chain_h"),chain_l=ex.get("chain_l"),
|
| 1022 |
-
chain_ag=ex.get("chain_ag"))
|
| 1023 |
if ok:
|
| 1024 |
if st.session_state.model is None:
|
| 1025 |
st.session_state["_auto_load_model"] = True
|
|
@@ -1029,7 +933,7 @@ def main():
|
|
| 1029 |
st.markdown('<div class="sec-hdr">Custom PDB Fetch</div>',unsafe_allow_html=True)
|
| 1030 |
pdb_in = st.text_input("PDB ID",placeholder="6M0J, 1BRS, 2VXQ…",key="pdb_id_input")
|
| 1031 |
|
| 1032 |
-
if mode
|
| 1033 |
st.caption("Comma-separate chains for multi-chain, e.g. `B,C`")
|
| 1034 |
fc1,fc2 = st.columns(2)
|
| 1035 |
ca_in = fc1.text_input("Side A chain(s)",placeholder="A or A,B",key="ca_in")
|
|
@@ -1055,34 +959,29 @@ def main():
|
|
| 1055 |
if ok: st.rerun()
|
| 1056 |
|
| 1057 |
with right_col:
|
| 1058 |
-
vt = st.tabs([
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
"📥 Download",
|
| 1063 |
-
])
|
| 1064 |
|
| 1065 |
with vt[0]:
|
| 1066 |
if st.session_state.pdb_content:
|
| 1067 |
-
st.caption("🎨 IG attribution" if st.session_state.ig_chain_map
|
| 1068 |
-
else "🔗
|
| 1069 |
st.components.v1.html(
|
| 1070 |
-
ngl_viewer_html(st.session_state.pdb_content,
|
| 1071 |
-
st.session_state.ig_chain_map,height=420),
|
| 1072 |
height=425,scrolling=False)
|
| 1073 |
else:
|
| 1074 |
st.markdown(
|
| 1075 |
-
'<div class="ngl-placeholder">'
|
| 1076 |
-
'<div style="font-size:2.2rem">🧬</div>'
|
| 1077 |
'<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
|
| 1078 |
-
'<div style="font-size:.72rem;margin-top:4px">
|
| 1079 |
'</div>',unsafe_allow_html=True)
|
| 1080 |
|
| 1081 |
with vt[1]:
|
| 1082 |
if ig_a_d and seq_a_d:
|
| 1083 |
st.markdown(residue_strip_html(seq_a_d,ig_a_d),unsafe_allow_html=True)
|
| 1084 |
-
st.plotly_chart(top10_bar(seq_a_d,ig_a_d,"#2563eb",f"Top 10 · {lbl_a}"),
|
| 1085 |
-
use_container_width=True)
|
| 1086 |
else:
|
| 1087 |
st.info("Run prediction with **Integrated Gradients** to see attribution.")
|
| 1088 |
|
|
@@ -1096,8 +995,7 @@ def main():
|
|
| 1096 |
with vt[3]:
|
| 1097 |
if ig_b_d and seq_b_d:
|
| 1098 |
st.markdown(residue_strip_html(seq_b_d,ig_b_d),unsafe_allow_html=True)
|
| 1099 |
-
st.plotly_chart(top10_bar(seq_b_d,ig_b_d,"#7c3aed",f"Top 10 · {lbl_b}"),
|
| 1100 |
-
use_container_width=True)
|
| 1101 |
else:
|
| 1102 |
st.info("Run prediction with **Integrated Gradients** to see attribution.")
|
| 1103 |
|
|
@@ -1112,37 +1010,27 @@ def main():
|
|
| 1112 |
if st.session_state.result:
|
| 1113 |
if ig_a_d or ig_b_d:
|
| 1114 |
dl1,dl2 = st.columns(2)
|
| 1115 |
-
buf = sio.StringIO()
|
| 1116 |
-
wc = csv.writer(buf)
|
| 1117 |
wc.writerow(["chain","position","residue","ig_score"])
|
| 1118 |
-
for i,(aa,sc) in enumerate(zip(seq_a_d or "",ig_a_d or [])):
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
dl1.download_button("📥 IG Scores CSV",buf.getvalue().encode(),
|
| 1123 |
-
"ig_scores.csv","text/csv",use_container_width=True)
|
| 1124 |
-
res_dl = st.session_state.result
|
| 1125 |
summary = {
|
| 1126 |
"pkd":res_dl["pkd"],"cosine":res_dl["cosine"],
|
| 1127 |
-
"Target_length":len(seq_a_d or ""),
|
| 1128 |
-
"proteina_length":len(seq_b_d or ""),
|
| 1129 |
"top10_Target":sorted([{"res":f"{(seq_a_d or '')[i]}{i+1}","ig":(ig_a_d or [])[i]}
|
| 1130 |
-
for i in range(min(len(seq_a_d or ""),len(ig_a_d or [])))],
|
| 1131 |
-
key=lambda x:-x["ig"])[:10],
|
| 1132 |
"top10_proteina":sorted([{"res":f"{(seq_b_d or '')[i]}{i+1}","ig":(ig_b_d or [])[i]}
|
| 1133 |
-
for i in range(min(len(seq_b_d or ""),len(ig_b_d or [])))],
|
| 1134 |
-
key=lambda x:-x["ig"])[:10],
|
| 1135 |
}
|
| 1136 |
-
dl2.download_button("📥 Summary JSON",
|
| 1137 |
-
json.dumps(summary,indent=2).encode(),
|
| 1138 |
-
"balm_result.json","application/json",
|
| 1139 |
-
use_container_width=True)
|
| 1140 |
else:
|
| 1141 |
res_dl = st.session_state.result
|
| 1142 |
st.info("Enable IG and re-run for per-residue CSV.")
|
| 1143 |
st.download_button("📥 Basic Result JSON",
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
else:
|
| 1147 |
st.info("Run a prediction to enable download.")
|
| 1148 |
|
|
@@ -1161,8 +1049,7 @@ def main():
|
|
| 1161 |
df = pd.read_csv(bf)
|
| 1162 |
st.dataframe(df.head(),use_container_width=True)
|
| 1163 |
if st.button("🏃 Run Batch",type="primary",use_container_width=True):
|
| 1164 |
-
if st.session_state.model is None:
|
| 1165 |
-
st.error("Load the model first.")
|
| 1166 |
else:
|
| 1167 |
model,results = st.session_state.model,[]
|
| 1168 |
prog,stat,n = st.progress(0.0),st.empty(),len(df)
|
|
@@ -1180,7 +1067,7 @@ def main():
|
|
| 1180 |
pkd_t,cos_t = model(sa,sb)
|
| 1181 |
rr={"pKd":round(float(pkd_t.item()),4),"Cosine":round(float(cos_t.item()),6)}
|
| 1182 |
if run_ig_b:
|
| 1183 |
-
ia,ib=compute_ig(model,sa,sb)
|
| 1184 |
sap,sbp=sa.replace("|",""),sb.replace("|","")
|
| 1185 |
rr["top5_Target"]="|".join(f"{sap[j]}{j+1}:{ia[j]:.3f}"
|
| 1186 |
for j in sorted(range(min(len(sap),len(ia))),key=lambda x:-ia[x])[:5])
|
|
@@ -1193,14 +1080,13 @@ def main():
|
|
| 1193 |
df_res = pd.concat([df,pd.DataFrame(results)],axis=1)
|
| 1194 |
st.success(f"✅ {n} predictions complete")
|
| 1195 |
st.dataframe(df_res,use_container_width=True)
|
| 1196 |
-
st.download_button("📥 Download Results CSV",df_res.to_csv(index=False).encode(),
|
| 1197 |
-
"balm_batch_results.csv","text/csv",use_container_width=True)
|
| 1198 |
|
| 1199 |
with tab_info:
|
| 1200 |
st.markdown("### 📚 About BALM-PPI")
|
| 1201 |
st.markdown("""
|
| 1202 |
-
**BALM-PPI** predicts protein–protein and antibody–antigen binding affinity
|
| 1203 |
-
|
| 1204 |
|
| 1205 |
| Component | Detail |
|
| 1206 |
|-----------|--------|
|
|
@@ -1209,8 +1095,9 @@ using ESM-2 fine-tuned with LoRA, with Integrated Gradients explainability.
|
|
| 1209 |
| Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** |
|
| 1210 |
| Multi-chain | `\\|` separator → `<cls><cls>` double CLS token |
|
| 1211 |
| Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min |
|
| 1212 |
-
|
|
| 1213 |
-
|
|
|
|
|
| 1214 |
|
| 1215 |
**Examples** · **6M0J** — ACE2 ↔ SARS-CoV-2 RBD · **1BRS** — Barnase ↔ Barstar
|
| 1216 |
|
|
|
|
| 1 |
"""
|
| 2 |
BALM-PPI Pro · ESM-2 + LoRA + Integrated Gradients
|
| 3 |
+
Fixes:
|
| 4 |
+
- IG forced to float32 (torch_dtype="auto" → fp16 kills grads on HF Spaces)
|
| 5 |
+
- IG steps reduced 30→15 for speed with negligible accuracy loss
|
| 6 |
+
- NGL viewer filters PDB to only the selected chains
|
| 7 |
+
- PDB fetches cached for 1 h
|
| 8 |
"""
|
| 9 |
|
|
|
|
| 10 |
import os
|
| 11 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 12 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
| 21 |
from Bio import PDB
|
| 22 |
from Bio.Data.PDBData import protein_letters_3to1 as THREE_TO_ONE
|
| 23 |
|
| 24 |
+
st.set_page_config(page_title="BALM-PPI Pro", page_icon="🧬", layout="wide",
|
| 25 |
initial_sidebar_state="expanded")
|
| 26 |
|
| 27 |
HF_REPO_ID = "Harshit494/BALM-PPI"
|
|
|
|
| 41 |
},
|
| 42 |
]
|
| 43 |
|
|
|
|
| 44 |
_DEFAULTS: dict = {
|
| 45 |
"model": None, "device": None,
|
| 46 |
"result": None, "ig_a": None, "ig_b": None, "ig_chain_map": None,
|
|
|
|
| 55 |
st.session_state[_k] = _v
|
| 56 |
|
| 57 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 58 |
+
# CSS — light-first, dark via @media
|
| 59 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 60 |
st.markdown("""
|
| 61 |
<style>
|
| 62 |
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&family=Inter:wght@300;400;500;600;700&display=swap');
|
|
|
|
|
|
|
| 63 |
:root {
|
| 64 |
--bg0:#ffffff; --bg1:#f8f9fc; --bg2:#f1f4f9; --bg3:#e4e9f2;
|
| 65 |
--border:rgba(37,99,235,0.13); --shadow:rgba(37,99,235,0.07);
|
|
|
|
| 68 |
--text0:#0f172a; --text1:#334155; --text2:#64748b; --text3:#94a3b8;
|
| 69 |
--mono:'JetBrains Mono',monospace; --sans:'Inter',sans-serif;
|
| 70 |
}
|
|
|
|
| 71 |
@media (prefers-color-scheme:dark){
|
| 72 |
:root {
|
| 73 |
--bg0:#06090f; --bg1:#0b1120; --bg2:#101828; --bg3:#172135;
|
|
|
|
| 77 |
--text0:#f1f5f9; --text1:#cbd5e1; --text2:#64748b; --text3:#334155;
|
| 78 |
}
|
| 79 |
}
|
|
|
|
|
|
|
| 80 |
html,body,[data-testid="stAppViewContainer"],[data-testid="stMain"]{
|
| 81 |
+
background:var(--bg0)!important;font-family:var(--sans)!important;
|
| 82 |
}
|
| 83 |
[data-testid="stHeader"]{
|
| 84 |
+
background:var(--bg0)!important;border-bottom:1px solid var(--border)!important;
|
|
|
|
| 85 |
box-shadow:0 1px 8px var(--shadow)!important;
|
| 86 |
}
|
| 87 |
+
[data-testid="stSidebar"]{background:var(--bg1)!important;border-right:1px solid var(--border)!important;}
|
|
|
|
|
|
|
|
|
|
| 88 |
[data-testid="stSidebar"] *{font-family:var(--sans)!important;}
|
|
|
|
|
|
|
| 89 |
h1,h2,h3,h4,h5,h6{font-family:var(--sans)!important;color:var(--text0)!important;font-weight:700!important;}
|
| 90 |
p,label,div,span,li{color:var(--text1);}
|
| 91 |
code,pre{font-family:var(--mono)!important;font-size:.82rem!important;
|
| 92 |
background:var(--bg2)!important;border-radius:4px!important;color:var(--accent)!important;}
|
| 93 |
+
.stButton>button{font-family:var(--sans)!important;font-weight:600!important;
|
| 94 |
+
letter-spacing:.01em!important;border-radius:8px!important;transition:all .15s ease!important;height:38px!important;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
.stButton>button[kind="primary"]{
|
| 96 |
background:linear-gradient(135deg,var(--accent),var(--accent2))!important;
|
| 97 |
+
border:none!important;color:#fff!important;box-shadow:0 2px 12px var(--shadow)!important;}
|
| 98 |
+
.stButton>button[kind="primary"]:hover{box-shadow:0 4px 20px rgba(37,99,235,.28)!important;transform:translateY(-1px)!important;}
|
| 99 |
+
.stButton>button[kind="secondary"]{background:var(--bg1)!important;border:1px solid var(--border)!important;color:var(--text1)!important;}
|
| 100 |
+
.stButton>button[kind="secondary"]:hover{border-color:var(--accent)!important;color:var(--accent)!important;background:rgba(37,99,235,.05)!important;}
|
| 101 |
+
.stTextArea textarea{background:var(--bg0)!important;border:1.5px solid var(--border)!important;
|
| 102 |
+
border-radius:8px!important;color:var(--text0)!important;font-family:var(--mono)!important;
|
| 103 |
+
font-size:.82rem!important;line-height:1.65!important;resize:vertical!important;}
|
| 104 |
+
.stTextArea textarea:focus{border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;}
|
| 105 |
+
.stTextArea label{font-weight:600!important;font-size:.76rem!important;text-transform:uppercase!important;
|
| 106 |
+
letter-spacing:.09em!important;color:var(--text2)!important;}
|
| 107 |
+
.stTextInput input,.stNumberInput input{background:var(--bg0)!important;border:1.5px solid var(--border)!important;
|
| 108 |
+
border-radius:8px!important;color:var(--text0)!important;}
|
| 109 |
+
.stTextInput input:focus,.stNumberInput input:focus{border-color:var(--accent)!important;box-shadow:0 0 0 3px rgba(37,99,235,.1)!important;}
|
| 110 |
+
.stTextInput label,.stNumberInput label{font-size:.76rem!important;font-weight:600!important;
|
| 111 |
+
color:var(--text2)!important;text-transform:uppercase!important;letter-spacing:.08em!important;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
.stRadio label,.stCheckbox label{color:var(--text1)!important;font-size:.9rem!important;}
|
| 113 |
+
.stTabs [data-baseweb="tab-list"]{background:var(--bg2)!important;border-radius:10px!important;
|
| 114 |
+
padding:4px!important;gap:3px!important;border:1px solid var(--border)!important;}
|
| 115 |
+
.stTabs [data-baseweb="tab"]{background:transparent!important;border-radius:7px!important;
|
| 116 |
+
color:var(--text2)!important;font-family:var(--sans)!important;font-size:.82rem!important;
|
| 117 |
+
font-weight:500!important;transition:all .14s ease!important;padding:6px 13px!important;}
|
| 118 |
+
.stTabs [aria-selected="true"]{background:var(--bg0)!important;color:var(--accent)!important;
|
| 119 |
+
font-weight:600!important;box-shadow:0 1px 6px var(--shadow)!important;}
|
| 120 |
+
[data-testid="stMetric"]{background:var(--bg1)!important;border:1px solid var(--border)!important;
|
| 121 |
+
border-radius:10px!important;padding:14px 18px!important;box-shadow:0 1px 4px var(--shadow)!important;}
|
| 122 |
+
[data-testid="stMetricLabel"]{color:var(--text2)!important;font-size:.72rem!important;
|
| 123 |
+
text-transform:uppercase!important;letter-spacing:.1em!important;}
|
| 124 |
+
[data-testid="stMetricValue"]{color:var(--accent)!important;font-family:var(--mono)!important;
|
| 125 |
+
font-size:1.55rem!important;font-weight:700!important;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
hr{border-color:var(--border)!important;margin:12px 0!important;}
|
| 127 |
[data-testid="stAlert"]{background:var(--bg1)!important;border-radius:8px!important;border-left-width:3px!important;}
|
| 128 |
[data-testid="stDataFrame"]{border:1px solid var(--border)!important;border-radius:8px!important;overflow:hidden!important;}
|
| 129 |
+
.stProgress>div>div{background:linear-gradient(90deg,var(--accent),var(--accent2))!important;border-radius:4px!important;}
|
|
|
|
|
|
|
|
|
|
| 130 |
|
|
|
|
| 131 |
.app-header{display:flex;align-items:center;gap:14px;padding:4px 0 14px;}
|
| 132 |
+
.app-logo{width:40px;height:40px;border-radius:10px;
|
|
|
|
| 133 |
background:linear-gradient(135deg,var(--accent),var(--accent2));
|
| 134 |
display:flex;align-items:center;justify-content:center;
|
| 135 |
+
font-size:20px;flex-shrink:0;box-shadow:0 2px 12px rgba(37,99,235,.22);}
|
|
|
|
| 136 |
.app-title{font-family:var(--mono)!important;font-size:1.22rem!important;
|
| 137 |
font-weight:700!important;color:var(--text0)!important;letter-spacing:-.02em;}
|
| 138 |
+
.app-subtitle{font-size:.77rem!important;color:var(--text2)!important;margin-top:1px;letter-spacing:.03em;}
|
| 139 |
+
.sec-hdr{font-family:var(--mono);font-size:.65rem;font-weight:700;letter-spacing:.2em;
|
| 140 |
+
text-transform:uppercase;color:var(--accent);margin:14px 0 8px;
|
| 141 |
+
display:flex;align-items:center;gap:8px;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
.sec-hdr::after{content:'';flex:1;height:1px;background:var(--border);}
|
| 143 |
+
.ex-card{background:var(--bg1);border:1px solid var(--border);border-radius:10px;
|
| 144 |
+
padding:12px 15px;margin-bottom:8px;position:relative;overflow:hidden;
|
| 145 |
+
transition:box-shadow .15s,border-color .15s;}
|
| 146 |
+
.ex-card::before{content:'';position:absolute;left:0;top:0;bottom:0;width:3px;
|
| 147 |
+
background:linear-gradient(180deg,var(--accent),var(--accent2));border-radius:3px 0 0 3px;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
.ex-card:hover{border-color:rgba(37,99,235,.3);box-shadow:0 2px 12px var(--shadow);}
|
| 149 |
.ex-pdb{font-family:var(--mono);font-size:.95rem;font-weight:700;color:var(--accent);}
|
| 150 |
.ex-sub{font-size:.8rem;font-weight:600;color:var(--text1);margin:2px 0;}
|
| 151 |
.ex-desc{font-size:.72rem;line-height:1.5;color:var(--text2);margin-top:3px;}
|
| 152 |
+
.pkd-card{background:linear-gradient(135deg,rgba(37,99,235,.05),rgba(124,58,237,.04));
|
| 153 |
+
border:1px solid var(--border);border-radius:12px;padding:14px 18px;box-shadow:0 1px 6px var(--shadow);}
|
| 154 |
+
.pkd-lbl{font-family:var(--mono);font-size:.67rem;color:var(--text2);text-transform:uppercase;letter-spacing:.12em;margin-bottom:2px;}
|
| 155 |
+
.pkd-val{font-family:var(--mono);font-size:2rem;font-weight:700;color:var(--accent);line-height:1.1;}
|
| 156 |
+
.pkd-badge{display:inline-block;padding:2px 9px;border-radius:20px;font-size:.7rem;font-weight:600;font-family:var(--mono);margin-top:5px;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
.badge-weak{background:rgba(220,38,38,.08);color:var(--red);border:1px solid rgba(220,38,38,.22);}
|
| 158 |
.badge-moderate{background:rgba(217,119,6,.08);color:var(--amber);border:1px solid rgba(217,119,6,.22);}
|
| 159 |
.badge-strong{background:rgba(22,163,74,.08);color:var(--green);border:1px solid rgba(22,163,74,.22);}
|
|
|
|
| 160 |
.str-bar{height:5px;border-radius:3px;background:var(--bg3);overflow:hidden;margin:8px 0 3px;}
|
| 161 |
+
.str-fill{height:100%;border-radius:3px;background:linear-gradient(90deg,var(--accent),var(--accent2));transition:width .7s cubic-bezier(.4,0,.2,1);}
|
| 162 |
+
.str-labels{display:flex;justify-content:space-between;font-size:.65rem;color:var(--text2);font-family:var(--mono);}
|
| 163 |
+
.ready-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;border-radius:20px;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
background:rgba(22,163,74,.07);border:1px solid rgba(22,163,74,.22);
|
| 165 |
+
font-family:var(--mono);font-size:.72rem;color:var(--green);font-weight:600;}
|
| 166 |
+
.ready-dot{display:inline-block;width:6px;height:6px;border-radius:50%;background:var(--green);animation:pulse 2s infinite;}
|
| 167 |
+
.idle-badge{display:inline-flex;align-items:center;gap:6px;padding:3px 10px;border-radius:20px;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
background:rgba(100,116,139,.07);border:1px solid rgba(100,116,139,.18);
|
| 169 |
+
font-family:var(--mono);font-size:.72rem;color:var(--text2);font-weight:600;}
|
|
|
|
| 170 |
@keyframes pulse{0%,100%{opacity:1;transform:scale(1)}50%{opacity:.5;transform:scale(.85)}}
|
| 171 |
+
.model-section{background:var(--bg2);border:1px solid var(--border);border-radius:9px;padding:12px 14px;margin-bottom:10px;}
|
| 172 |
+
.model-section-title{font-family:var(--mono);font-size:.63rem;font-weight:700;letter-spacing:.18em;text-transform:uppercase;color:var(--text2);margin-bottom:8px;}
|
| 173 |
+
.ngl-placeholder{display:flex;flex-direction:column;align-items:center;justify-content:center;
|
| 174 |
+
height:420px;background:var(--bg2);border:1px solid var(--border);border-radius:12px;
|
| 175 |
+
color:var(--text2);font-family:var(--mono);font-size:.8rem;text-align:center;line-height:2;}
|
| 176 |
+
|
| 177 |
+
/* IG progress info box */
|
| 178 |
+
.ig-info{background:linear-gradient(135deg,rgba(37,99,235,.04),rgba(124,58,237,.03));
|
| 179 |
+
border:1px solid var(--border);border-radius:8px;padding:10px 14px;
|
| 180 |
+
font-family:var(--mono);font-size:.75rem;color:var(--text2);margin-bottom:8px;}
|
|
|
|
|
|
|
| 181 |
</style>
|
| 182 |
""", unsafe_allow_html=True)
|
| 183 |
|
| 184 |
+
# postMessage bridge for NGL iframe theme sync
|
| 185 |
st.markdown("""
|
| 186 |
<script>
|
| 187 |
(function(){
|
|
|
|
| 190 |
try{f.contentWindow.postMessage({balmTheme:dark?'dark':'light'},'*');}catch(e){}
|
| 191 |
});
|
| 192 |
}
|
| 193 |
+
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){notify(e.matches);});
|
|
|
|
|
|
|
| 194 |
})();
|
| 195 |
</script>
|
| 196 |
""", unsafe_allow_html=True)
|
| 197 |
|
| 198 |
|
| 199 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 200 |
+
# PDB UTILITIES
|
| 201 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 202 |
|
| 203 |
@st.cache_data(show_spinner=False, ttl=3600)
|
|
|
|
| 208 |
return r.text
|
| 209 |
|
| 210 |
|
| 211 |
+
def filter_pdb_to_chains(pdb_text: str, keep_chains: set) -> str:
    """Return *pdb_text* reduced to the coordinate records of *keep_chains*.

    ATOM and HETATM records are kept only when the chain identifier in
    column 22 (string index 21) is a member of ``keep_chains``.  ANISOU and
    TER records that name a chain outside ``keep_chains`` are dropped as
    well, so no orphaned per-atom or chain-terminator records are left
    behind for chains whose atoms were removed.  Every other record
    (HEADER, TITLE, REMARK, END, ...) is passed through untouched.

    Args:
        pdb_text: Full text of a PDB-format file.
        keep_chains: Collection of single-character chain identifiers to
            retain; anything supporting ``in`` works.

    Returns:
        The filtered PDB text, lines joined with ``"\\n"`` (no trailing
        newline).  An empty input yields an empty string.
    """
    kept = []
    for line in pdb_text.splitlines():
        record = line[:6].strip()
        if record in ("ATOM", "HETATM"):
            # Column 22 (0-indexed 21) holds the chain ID.
            if line[21:22] not in keep_chains:
                continue
        elif record in ("ANISOU", "TER"):
            # Same chain column as ATOM.  A bare or blank-chain TER carries
            # no chain information, so it is left in place rather than
            # guessed at.
            chain_id = line[21:22]
            if chain_id.strip() and chain_id not in keep_chains:
                continue
        kept.append(line)
    return "\n".join(kept)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
def get_chains_from_pdb(pdb_id: str):
|
| 229 |
try:
|
| 230 |
pdb_text = _fetch_pdb_cached(pdb_id.strip().upper())
|
|
|
|
| 278 |
if not pdb_content or not chains: return False
|
| 279 |
available = sorted(chains.keys())
|
| 280 |
st.info(f"✅ **{pdb_id.upper()}** — chains: **{', '.join(available)}**")
|
| 281 |
+
|
| 282 |
+
st.session_state.ig_a = st.session_state.ig_b = st.session_state.ig_chain_map = None
|
| 283 |
+
st.session_state.result = None
|
|
|
|
|
|
|
| 284 |
st.session_state.chain_info_a = []
|
| 285 |
st.session_state.chain_info_b = []
|
| 286 |
st.session_state["_pm"] = mode
|
| 287 |
+
|
| 288 |
+
selected_chain_ids: set = set() # ← collect which chains to keep in viewer
|
| 289 |
+
|
| 290 |
if mode == "Protein-Protein":
|
| 291 |
id_a = chain_a or available[0]
|
| 292 |
+
id_b = chain_b or (available[1] if len(available) > 1 else available[0])
|
| 293 |
seq_a, info_a, miss_a = resolve_chains(chains, id_a, available)
|
| 294 |
seq_b, info_b, miss_b = resolve_chains(chains, id_b, available)
|
| 295 |
for m in miss_a: st.warning(f"Chain **{m}** not found (Side A). Available: {', '.join(available)}")
|
|
|
|
| 298 |
st.session_state["_pb"] = seq_b
|
| 299 |
st.session_state.chain_info_a = info_a
|
| 300 |
st.session_state.chain_info_b = info_b
|
| 301 |
+
selected_chain_ids = {c["id"] for c in info_a} | {c["id"] for c in info_b}
|
| 302 |
else:
|
| 303 |
+
id_h = chain_h or "H"
|
| 304 |
+
id_l = chain_l or "L"
|
| 305 |
+
id_ag_str = chain_ag or next((c for c in available if c not in (id_h, id_l)), available[0])
|
| 306 |
def _single(cid):
|
| 307 |
if cid in chains and chains[cid]["seq"]: return cid, chains[cid]
|
| 308 |
f = next((k for k in chains if k.upper()==cid.upper() and chains[k]["seq"]), None)
|
|
|
|
| 318 |
st.session_state["_pag"] = seq_ag
|
| 319 |
st.session_state.chain_info_a = info_ag if info_ag else []
|
| 320 |
st.session_state.chain_info_b = [{"id":real_h,**ch},{"id":real_l,**cl}]
|
| 321 |
+
selected_chain_ids = ({c["id"] for c in info_ag} if info_ag else set())
|
| 322 |
+
if ch["seq"]: selected_chain_ids.add(real_h)
|
| 323 |
+
if cl["seq"]: selected_chain_ids.add(real_l)
|
| 324 |
+
|
| 325 |
+
# ── Filter PDB to selected chains only ───────────────────────────────────
|
| 326 |
+
if selected_chain_ids:
|
| 327 |
+
filtered = filter_pdb_to_chains(pdb_content, selected_chain_ids)
|
| 328 |
+
st.session_state.pdb_content = filtered
|
| 329 |
+
else:
|
| 330 |
+
st.session_state.pdb_content = pdb_content
|
| 331 |
+
|
| 332 |
return True
|
| 333 |
|
| 334 |
|
|
|
|
| 368 |
seq_a = "".join(c["seq"] for c in chain_info_a) if chain_info_a else ""
|
| 369 |
seq_b = "".join(c["seq"] for c in chain_info_b) if chain_info_b else ""
|
| 370 |
if ig_a and chain_info_a:
|
| 371 |
+
parts = split_ig_by_chains(ig_a,[c["seq"] for c in chain_info_a])
|
| 372 |
ig_a_out = [s for sub in parts for s in sub]
|
| 373 |
else:
|
| 374 |
ig_a_out = list(ig_a) if ig_a else []
|
| 375 |
if ig_b and chain_info_b:
|
| 376 |
+
parts = split_ig_by_chains(ig_b,[c["seq"] for c in chain_info_b])
|
| 377 |
ig_b_out = [s for sub in parts for s in sub]
|
| 378 |
else:
|
| 379 |
ig_b_out = list(ig_b) if ig_b else []
|
|
|
|
| 382 |
|
| 383 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 384 |
# MODEL
|
| 385 |
+
# KEY FIX: do NOT use torch_dtype="auto" — on HF Spaces this loads fp16
|
| 386 |
+
# which silently produces zero gradients and breaks IG entirely.
|
| 387 |
+
# Always load in float32; the extra memory is worth correct attributions.
|
| 388 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 389 |
@st.cache_resource(show_spinner="Loading ESM-2 backbone…")
def _load_esm_base():
    """Load the pretrained ESM-2 encoder and its tokenizer.

    The result is memoised by ``st.cache_resource``, so the checkpoint is
    only loaded once per process.

    Returns:
        A ``(model, tokenizer)`` tuple built from the same checkpoint name.
    """
    # Imported lazily so the module can be imported without pulling in
    # transformers up front.
    from transformers import EsmModel, EsmTokenizer

    # torch_dtype="auto" is deliberately NOT passed: fp16 kills gradients on
    # HF Spaces / CPU.  low_cpu_mem_usage is deliberately NOT passed either:
    # it creates meta tensors, and a later deepcopy + .to(device) crashes.
    backbone = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D", use_safetensors=True)
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    return backbone, tokenizer
|
| 400 |
|
| 401 |
+
|
| 402 |
def _download_weights_hf() -> bytes:
    """Fetch the fine-tuned weights file from the Hugging Face Hub.

    Downloads ``HF_FILENAME`` from ``HF_REPO_ID`` and reads the resulting
    local file fully into memory.

    Returns:
        The raw bytes of the weights file.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        resume_download=True,
    )
    with open(local_path, "rb") as handle:
        payload = handle.read()
    return payload
|
| 406 |
|
| 407 |
+
|
| 408 |
def build_and_load_model(weights_bytes, pkd_bounds):
|
| 409 |
import torch.nn as nn
|
| 410 |
import torch.nn.functional as F
|
|
|
|
| 459 |
model.to(device).eval()
|
| 460 |
return model, device
|
| 461 |
|
| 462 |
+
|
| 463 |
def _ensure_model_loaded(pkd_lo=1.0, pkd_hi=16.0) -> bool:
|
| 464 |
if st.session_state.model is not None: return True
|
| 465 |
try:
|
|
|
|
| 476 |
|
| 477 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 478 |
# INTEGRATED GRADIENTS
|
| 479 |
+
# KEY FIXES:
|
| 480 |
+
# 1. Explicit .float() casts — model may internally have mixed dtypes after LoRA
|
| 481 |
+
# 2. Steps reduced 30 → 15 (Riemann IG converges fast; halves compute time)
|
| 482 |
+
# 3. word_embed detach + clone to avoid stale graph issues
|
| 483 |
+
# 4. Proper gradient zeroing between steps
|
| 484 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 485 |
+
def compute_ig(model, seq_a: str, seq_b: str, steps: int = 15):
    """Compute per-token Integrated Gradients attributions for a sequence pair.

    Each sequence is tokenised (``|`` is replaced by two CLS tokens), looked
    up in the ESM word-embedding table, and attributed against an all-zero
    embedding baseline.  One side is attributed at a time while the other
    side's embeddings are held fixed.  Gradients are averaged over ``steps``
    evenly spaced interpolation points between baseline and input.

    Args:
        model: Object exposing ``esm_tokenizer``, ``esm_model`` (wrapped so
            that ``esm_model.base_model.model`` is the underlying encoder
            model) and ``projection_head(emb_a, emb_b)``.
        seq_a: First sequence; ``|`` separates chains.
        seq_b: Second sequence; ``|`` separates chains.
        steps: Number of interpolation points along the integration path.

    Returns:
        ``(scores_a, scores_b)``: two lists of floats, min-max normalised to
        ``[0, 1]`` per side, with the first and last token positions removed.
        A side whose scores are all (nearly) equal yields zeros; a side with
        no tokens left after trimming yields an empty list.  Sequences are
        truncated to 1024 tokens, so a list may be shorter than its input.
    """
    device = next(model.parameters()).device
    tok = model.esm_tokenizer
    cls_tok = tok.cls_token
    esm = model.esm_model
    inner = esm.base_model.model
    word_embed = inner.embeddings.word_embeddings

    def tokenise(seq):
        # Chain breaks are encoded as a double CLS token.
        proc = seq.replace("|", f"{cls_tok}{cls_tok}")
        return tok(proc, return_tensors="pt",
                   padding=False, truncation=True, max_length=1024).to(device)

    enc_a = tokenise(seq_a)
    enc_b = tokenise(seq_b)

    # Force float32: half-precision embeddings produce unusable gradients.
    emb_a = word_embed(enc_a["input_ids"]).detach().float().clone()
    emb_b = word_embed(enc_b["input_ids"]).detach().float().clone()
    mask_a = enc_a["attention_mask"].float()
    mask_b = enc_b["attention_mask"].float()

    def encode(embs, mask):
        # Run the encoder directly on embeddings and mean-pool over the
        # unmasked positions.
        ext = inner.get_extended_attention_mask(mask.long(), embs.shape[:2])
        h = inner.encoder(embs.float(), attention_mask=ext).last_hidden_state.float()
        m = mask.unsqueeze(-1).expand(h.size()).float()
        return torch.sum(h * m, 1) / torch.clamp(m.sum(1), min=1e-9)

    def fwd(e_a, e_b):
        return model.projection_head(encode(e_a, mask_a), encode(e_b, mask_b))

    def riemann(tgt, baseline, fixed, tgt_is_b):
        delta = tgt - baseline  # loop-invariant
        grads = []
        # enable_grad: keeps attribution working even if the caller invokes
        # this from inside a torch.no_grad() inference block.
        with torch.enable_grad():
            for alpha in torch.linspace(0, 1, steps, device=device):
                interp = (baseline + alpha * delta).detach().requires_grad_(True)
                out = fwd(fixed, interp) if tgt_is_b else fwd(interp, fixed)
                # autograd.grad differentiates w.r.t. the interpolated input
                # only.  Unlike .backward(), it does not populate .grad on
                # the model's parameters, so nothing accumulates across
                # steps or across calls on the long-lived model.
                (grad,) = torch.autograd.grad(out.sum(), interp)
                grads.append(grad.detach().float())
        avg_grad = torch.stack(grads).mean(0)
        ig = (avg_grad * delta).abs().sum(-1).squeeze(0)
        return ig.cpu().numpy()

    attr_a = riemann(emb_a, torch.zeros_like(emb_a), emb_b, False)
    attr_b = riemann(emb_b, torch.zeros_like(emb_b), emb_a, True)

    def norm(a):
        if a.size == 0:
            return []  # nothing left after trimming special tokens
        lo, hi = a.min(), a.max()
        if hi - lo < 1e-9:
            return [0.0] * len(a)  # flat → return zeros, not NaN
        return ((a - lo) / (hi - lo)).tolist()

    # Slice off the leading and trailing special tokens (indices 0 and -1).
    return norm(attr_a[1:-1]), norm(attr_b[1:-1])
|
| 546 |
|
| 547 |
|
| 548 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 549 |
+
# NGL VIEWER — theme-reactive, shows only filtered PDB chains
|
| 550 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 551 |
def ngl_viewer_html(pdb_content, ig_chain_map=None, height=420) -> str:
|
| 552 |
ig_json = json.dumps(ig_chain_map) if ig_chain_map else "null"
|
|
|
|
| 558 |
*{{box-sizing:border-box;margin:0;padding:0}}
|
| 559 |
body{{overflow:hidden;font-family:'JetBrains Mono',monospace;transition:background .35s}}
|
| 560 |
#vp{{width:100%;height:{height}px}}
|
| 561 |
+
#hint{{position:absolute;top:10px;left:12px;font-size:10px;pointer-events:none;letter-spacing:.04em;transition:color .3s}}
|
|
|
|
| 562 |
#ctrl{{position:absolute;bottom:12px;right:12px;display:flex;flex-direction:column;gap:4px}}
|
| 563 |
.cb{{padding:4px 11px;font-family:'JetBrains Mono',monospace;font-size:9px;font-weight:500;
|
| 564 |
border-radius:5px;cursor:pointer;letter-spacing:.1em;text-transform:uppercase;transition:all .15s}}
|
|
|
|
| 575 |
</div>
|
| 576 |
<div id="leg"></div>
|
| 577 |
<script>
|
| 578 |
+
var T={{
|
| 579 |
+
light:{{bg:'#f0f4fa',hint:'rgba(51,65,85,.5)',
|
| 580 |
+
btn:'rgba(240,244,250,.93)',btn_b:'rgba(37,99,235,.16)',btn_c:'rgba(71,85,105,.6)',
|
|
|
|
|
|
|
|
|
|
| 581 |
btn_hc:'#2563eb',btn_hb:'rgba(37,99,235,.55)',btn_hbg:'rgba(37,99,235,.06)',
|
| 582 |
leg:'rgba(240,244,250,.93)',leg_b:'rgba(37,99,235,.13)',leg_c:'rgba(71,85,105,.65)',
|
| 583 |
+
ig_hi:'#1d4ed8',ig_lo:'#c7d7f0',null_col:0xe8ecf4}},
|
| 584 |
+
dark:{{bg:'#06090f',hint:'rgba(168,184,216,.4)',
|
| 585 |
+
btn:'rgba(6,9,15,.88)',btn_b:'rgba(96,165,250,.16)',btn_c:'rgba(168,184,216,.55)',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
btn_hc:'#60a5fa',btn_hb:'rgba(96,165,250,.55)',btn_hbg:'rgba(96,165,250,.06)',
|
| 587 |
leg:'rgba(6,9,15,.88)',leg_b:'rgba(96,165,250,.13)',leg_c:'rgba(168,184,216,.55)',
|
| 588 |
+
ig_hi:'#60a5fa',ig_lo:'#172135',null_col:0x172135}}
|
|
|
|
|
|
|
| 589 |
}};
|
|
|
|
| 590 |
function isDark(){{return window.matchMedia('(prefers-color-scheme:dark)').matches;}}
|
| 591 |
+
var theme=isDark()?T.dark:T.light,stage=null,comp=null,curRep='cartoon',igData={ig_json};
|
|
|
|
|
|
|
| 592 |
|
| 593 |
+
function igColor(s,dark){{
|
|
|
|
|
|
|
| 594 |
s=Math.max(0,Math.min(1,s||0));
|
| 595 |
+
if(dark) return(Math.round(23+s*73)<<16)|(Math.round(33+s*132)<<8)|Math.round(53+s*197);
|
| 596 |
+
return(Math.round(199-s*170)<<16)|(Math.round(215-s*137)<<8)|Math.round(240-s*24);
|
| 597 |
}}
|
|
|
|
| 598 |
function styleUI(dark){{
|
| 599 |
+
var Th=dark?T.dark:T.light; theme=Th;
|
| 600 |
+
document.body.style.background=Th.bg;
|
| 601 |
+
document.getElementById('hint').style.color=Th.hint;
|
|
|
|
| 602 |
document.querySelectorAll('.cb').forEach(function(b){{
|
| 603 |
b.style.background=Th.btn;b.style.border='1px solid '+Th.btn_b;b.style.color=Th.btn_c;
|
| 604 |
b.onmouseenter=function(){{b.style.borderColor=Th.btn_hb;b.style.color=Th.btn_hc;b.style.background=Th.btn_hbg;}};
|
| 605 |
+
b.onmouseleave=function(){{b.style.borderColor=Th.btn_b;b.style.color=Th.btn_c;b.style.background=Th.btn;}};
|
| 606 |
}});
|
| 607 |
var leg=document.getElementById('leg');
|
| 608 |
leg.style.background=Th.leg;leg.style.border='1px solid '+Th.leg_b;leg.style.color=Th.leg_c;
|
| 609 |
+
if(stage)stage.setParameters({{backgroundColor:Th.bg}});
|
| 610 |
}}
|
|
|
|
| 611 |
function addRepr(c){{
|
| 612 |
comp=c;comp.removeAllRepresentations();
|
| 613 |
var dark=isDark(),Th=dark?T.dark:T.light,leg=document.getElementById('leg');
|
|
|
|
| 617 |
var cd=igData[atom.chainname];
|
| 618 |
if(!cd||!cd.scores||!cd.scores.length) return Th.null_col;
|
| 619 |
var i=atom.resno-cd.start;
|
| 620 |
+
return(i<0||i>=cd.scores.length)?Th.null_col:igColor(cd.scores[i],dark);
|
| 621 |
}};
|
| 622 |
}});
|
| 623 |
comp.addRepresentation(curRep,{{color:sid}});
|
| 624 |
leg.innerHTML='<span style="color:'+Th.ig_hi+'">■</span> High IG <span style="color:'+Th.ig_lo+'">■</span> Low IG';
|
| 625 |
+
}}else{{
|
| 626 |
comp.addRepresentation(curRep,{{colorScheme:'chainname'}});
|
| 627 |
leg.innerHTML='Coloured by chain';
|
| 628 |
}}
|
| 629 |
stage.autoView();
|
| 630 |
}}
|
|
|
|
| 631 |
// Remember the requested representation name and redraw the loaded structure, if any.
function setRep(r){{
  curRep=r;
  if(!comp) return;
  comp.removeAllRepresentations();
  addRepr(comp);
}}
|
| 632 |
+
// Apply the requested theme to the page chrome, then rebuild the structure
// representation so its colours are recomputed; does nothing more when no structure is loaded.
function themeSwitch(dark){{
  styleUI(dark);
  if(!comp) return;
  comp.removeAllRepresentations();
  addRepr(comp);
}}
|
| 633 |
+
window.matchMedia('(prefers-color-scheme:dark)').addEventListener('change',function(e){{themeSwitch(e.matches);}});
|
| 634 |
+
window.addEventListener('message',function(e){{if(e.data&&e.data.balmTheme)themeSwitch(e.data.balmTheme==='dark');}});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
window.addEventListener('resize',function(){{if(stage)stage.handleResize();}});
|
| 636 |
+
stage=new NGL.Stage('vp',{{backgroundColor:theme.bg,quality:'medium',tooltip:true}});
|
|
|
|
| 637 |
styleUI(isDark());
|
|
|
|
| 638 |
var blob=new Blob([`{escaped}`],{{type:'text/plain'}});
|
| 639 |
stage.loadFile(URL.createObjectURL(blob),{{ext:'pdb',name:'s'}}).then(addRepr);
|
| 640 |
</script></body></html>"""
|
| 641 |
|
| 642 |
|
| 643 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 644 |
+
# PLOTLY
|
| 645 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 646 |
# Shared Plotly background colours (paper and plot area); unpacked with ** into
# fig.update_layout() so the figures built below share the same light backdrop.
_PL = dict(paper_bgcolor="rgba(255,255,255,.97)",plot_bgcolor="rgba(241,244,249,1)")
|
| 647 |
# Shared Plotly font settings (monospace, slate grey, 9 pt); copied via dict(**_FONT)
# into per-figure font arguments such as the heatmap colour-bar tick and title fonts.
_FONT = dict(family="JetBrains Mono, monospace",color="#475569",size=9)
|
|
|
|
| 657 |
fig = go.Figure(go.Heatmap(z=rows,text=hover,hoverinfo="text",
|
| 658 |
colorscale=[[0,"#e8f0fe"],[.4,"#93c5fd"],[.75,"#3b82f6"],[1,"#1d4ed8"]],
|
| 659 |
zmin=0,zmax=1,xgap=1,ygap=1,
|
| 660 |
+
colorbar=dict(thickness=10,len=.9,tickfont=dict(**_FONT),title=dict(text="IG",font=dict(**_FONT)))))
|
|
|
|
| 661 |
fig.update_layout(
|
| 662 |
title=dict(text=title,font=dict(family="JetBrains Mono, monospace",size=10,color="#1d4ed8")),
|
| 663 |
**_PL,height=200,margin=dict(t=30,b=24,l=55,r=45),
|
|
|
|
| 699 |
'<div style="line-height:2.2;word-break:break-all">'+"".join(cells)+"</div>")
|
| 700 |
|
| 701 |
|
|
|
|
|
|
|
|
|
|
| 702 |
# Example CSV text for protein-protein input: a `seq_a,seq_b` header plus one sample row.
# NOTE(review): presumably offered as a downloadable template in the batch tab — confirm at the use site.
PPI_TEMPLATE = "seq_a,seq_b\nACDEFGHIKLMNPQRSTVWY,QWERTYIPASDFGKLCVBNM\n"
|
| 703 |
# Example CSV text for antibody-antigen input: a `heavy_chain,light_chain,antigen` header
# plus one row of elided ("...") placeholder sequences that must be replaced with real ones.
# NOTE(review): presumably offered as a downloadable template in the batch tab — confirm at the use site.
ABAG_TEMPLATE = "heavy_chain,light_chain,antigen\nEVQLVESGGG...,DIQMTQ...,IYSPT...\n"
|
| 704 |
|
|
|
|
| 714 |
if _src in st.session_state:
|
| 715 |
st.session_state[_dst] = st.session_state.pop(_src)
|
| 716 |
|
|
|
|
| 717 |
if st.session_state.get("_auto_load_model") and st.session_state.model is None:
|
| 718 |
st.session_state["_auto_load_model"] = False
|
| 719 |
+
_ensure_model_loaded(st.session_state.get("pkd_lo",1.0),st.session_state.get("pkd_hi",16.0))
|
|
|
|
|
|
|
| 720 |
|
|
|
|
| 721 |
st.markdown("""
|
| 722 |
<div class="app-header">
|
| 723 |
<div class="app-logo">🧬</div>
|
|
|
|
| 725 |
<div class="app-title">BALM-PPI Pro</div>
|
| 726 |
<div class="app-subtitle">ESM-2 · LoRA · Integrated Gradients · Protein Binding Affinity Prediction</div>
|
| 727 |
</div>
|
| 728 |
+
</div>""", unsafe_allow_html=True)
|
|
|
|
| 729 |
st.divider()
|
| 730 |
|
| 731 |
+
# SIDEBAR
|
| 732 |
with st.sidebar:
|
| 733 |
st.markdown("### ⚙️ Model")
|
| 734 |
+
st.markdown('<div class="model-section">',unsafe_allow_html=True)
|
| 735 |
+
st.markdown('<div class="model-section-title">Weights Source</div>',unsafe_allow_html=True)
|
| 736 |
model_src = st.radio("src",["HuggingFace (auto-cached)","Upload .pth"],
|
| 737 |
key="model_src",label_visibility="collapsed")
|
| 738 |
custom_w = None
|
|
|
|
| 749 |
|
| 750 |
if st.button("⚡ Load Model",use_container_width=True,type="primary"):
|
| 751 |
try:
|
| 752 |
+
if model_src=="HuggingFace (auto-cached)":
|
| 753 |
with st.spinner("📥 Fetching weights (cached after first run)…"):
|
| 754 |
wb = _download_weights_hf()
|
| 755 |
else:
|
| 756 |
+
if custom_w is None: st.error("Upload a .pth file first."); st.stop()
|
|
|
|
| 757 |
wb = custom_w.read()
|
| 758 |
with st.spinner("🔧 Building ESM-2 + LoRA…"):
|
| 759 |
m,dev = build_and_load_model(wb,(pkd_lo,pkd_hi))
|
|
|
|
| 764 |
st.error(f"❌ {e}"); st.code(traceback.format_exc())
|
| 765 |
|
| 766 |
if st.session_state.model:
|
| 767 |
+
st.markdown(f'<div class="ready-badge"><span class="ready-dot"></span>READY · {st.session_state.device}</div>',unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
| 768 |
else:
|
| 769 |
st.markdown('<div class="idle-badge">○ NOT LOADED</div>',unsafe_allow_html=True)
|
| 770 |
|
|
|
|
| 773 |
res = st.session_state.result
|
| 774 |
st.markdown("**Latest Result**")
|
| 775 |
st.markdown(
|
| 776 |
+
f'<div class="pkd-card"><div class="pkd-lbl">Predicted pKd</div>'
|
|
|
|
| 777 |
f'<div class="pkd-val">{res["pkd"]:.3f}</div>'
|
| 778 |
f'<div class="pkd-lbl" style="margin-top:10px">Cosine Similarity</div>'
|
| 779 |
+
f'<div style="font-family:var(--mono);font-size:1rem;font-weight:700;color:var(--text0)">{res["cosine"]:.4f}</div></div>',
|
|
|
|
| 780 |
unsafe_allow_html=True)
|
| 781 |
|
| 782 |
st.divider()
|
|
|
|
| 786 |
'💻 <a href="https://github.com/rgorantla04/BALM-PPI" style="color:var(--accent);text-decoration:none">rgorantla04/BALM-PPI</a>'
|
| 787 |
'</div>',unsafe_allow_html=True)
|
| 788 |
|
| 789 |
+
# TABS
|
| 790 |
+
tab_single,tab_batch,tab_info = st.tabs(["🎯 Single Prediction","📂 Batch Prediction","📚 Model Info"])
|
|
|
|
| 791 |
|
| 792 |
with tab_single:
|
| 793 |
mode_opts = ["Protein-Protein","Antibody-Antigen"]
|
| 794 |
+
mode = st.radio("im_r",mode_opts,horizontal=True,label_visibility="collapsed",key="interaction_mode")
|
|
|
|
| 795 |
st.markdown("---")
|
| 796 |
|
| 797 |
+
if mode=="Protein-Protein":
|
| 798 |
sc1,sc2 = st.columns(2)
|
| 799 |
with sc1:
|
| 800 |
st.markdown("**Target Protein — Seq A**")
|
|
|
|
| 804 |
st.markdown("**Binder Protein — Seq B** `|` = chain separator")
|
| 805 |
seq_b_raw = st.text_area("Binder",label_visibility="collapsed",key="ta_ppi_b",
|
| 806 |
height=140,placeholder="Paste FASTA or raw sequence…")
|
| 807 |
+
seq_a = clean_multi(seq_a_raw); seq_b = clean_multi(seq_b_raw)
|
|
|
|
| 808 |
if not st.session_state.pdb_content:
|
| 809 |
if seq_a:
|
| 810 |
st.session_state.chain_info_a = [
|
|
|
|
| 818 |
sc1,sc2 = st.columns([2,3])
|
| 819 |
with sc1:
|
| 820 |
st.markdown("**Antigen / Target — Seq A**")
|
| 821 |
+
ag_raw = st.text_area("Antigen",label_visibility="collapsed",key="ta_ab_ag",height=140,placeholder="Antigen sequence…")
|
|
|
|
| 822 |
with sc2:
|
| 823 |
ah,al = st.columns(2)
|
| 824 |
with ah:
|
| 825 |
st.markdown("**Heavy Chain (H)**")
|
| 826 |
+
h_raw = st.text_area("Heavy",label_visibility="collapsed",key="ta_ab_h",height=140,placeholder="VH sequence…")
|
|
|
|
| 827 |
with al:
|
| 828 |
st.markdown("**Light Chain (L)**")
|
| 829 |
+
l_raw = st.text_area("Light",label_visibility="collapsed",key="ta_ab_l",height=140,placeholder="VL sequence…")
|
| 830 |
+
seq_a = clean_multi(ag_raw); h_seq = clean_multi(h_raw); l_seq = clean_multi(l_raw)
|
|
|
|
|
|
|
|
|
|
| 831 |
seq_b = f"{h_seq}|{l_seq}" if (h_seq or l_seq) else ""
|
| 832 |
if not st.session_state.pdb_content:
|
| 833 |
if seq_a:
|
| 834 |
+
st.session_state.chain_info_a = [{"id":"Ag","seq":seq_a,"resnos":list(range(1,len(seq_a)+1))}]
|
|
|
|
| 835 |
if h_seq or l_seq:
|
| 836 |
st.session_state.chain_info_b = [
|
| 837 |
{"id":"H","seq":h_seq,"resnos":list(range(1,len(h_seq)+1))},
|
|
|
|
| 844 |
if st.session_state.model is None:
|
| 845 |
st.caption("⬅ Load model in sidebar first")
|
| 846 |
with rc2:
|
| 847 |
+
run_ig = st.checkbox("Run Integrated Gradients",value=True,help="~45-90s CPU · <5s GPU")
|
|
|
|
| 848 |
with rc3:
|
| 849 |
if st.session_state.model:
|
| 850 |
+
st.markdown('<div class="ready-badge" style="margin-top:6px;font-size:.7rem"><span class="ready-dot"></span>READY</div>',unsafe_allow_html=True)
|
|
|
|
| 851 |
|
| 852 |
if run_btn:
|
| 853 |
if not seq_a or not seq_b:
|
|
|
|
| 864 |
except Exception as e:
|
| 865 |
st.error(f"Prediction failed: {e}"); st.code(traceback.format_exc())
|
| 866 |
if run_ig and st.session_state.result:
|
| 867 |
+
with st.spinner("Computing Integrated Gradients (15 steps)…"):
|
| 868 |
try:
|
| 869 |
+
ig_a,ig_b = compute_ig(st.session_state.model,seq_a,seq_b,steps=15)
|
| 870 |
st.session_state.ig_a = ig_a
|
| 871 |
st.session_state.ig_b = ig_b
|
| 872 |
if st.session_state.chain_info_a or st.session_state.chain_info_b:
|
| 873 |
st.session_state.ig_chain_map = build_ig_chain_map(
|
| 874 |
st.session_state.chain_info_a,st.session_state.chain_info_b,ig_a,ig_b)
|
| 875 |
except Exception as e:
|
| 876 |
+
st.warning(f"IG failed: {e}\n{traceback.format_exc()}")
|
| 877 |
st.rerun()
|
| 878 |
|
| 879 |
if st.session_state.result:
|
| 880 |
res = st.session_state.result
|
| 881 |
pkd,cos = res["pkd"],res["cosine"]
|
| 882 |
pct = max(0.0,min(100.0,((pkd-pkd_lo)/max(pkd_hi-pkd_lo,1e-9))*100))
|
| 883 |
+
strength = "Weak" if pct<30 else "Moderate" if pct<55 else "Strong" if pct<75 else "Very Strong"
|
| 884 |
+
badge_cls = "badge-weak" if pct<30 else "badge-moderate" if pct<55 else "badge-strong"
|
| 885 |
st.markdown("---")
|
| 886 |
m1,m2,m3 = st.columns([1,1,2])
|
| 887 |
m1.metric("Predicted pKd",f"{pkd:.3f}")
|
| 888 |
m2.metric("Cosine Similarity",f"{cos:.4f}")
|
| 889 |
with m3:
|
| 890 |
st.markdown(
|
| 891 |
+
f'<div class="pkd-card"><div class="str-labels">'
|
| 892 |
+
f'<span>Weak ({pkd_lo:.0f})</span>'
|
| 893 |
f'<span style="color:var(--text0);font-weight:600">{strength}</span>'
|
| 894 |
f'<span>Strong ({pkd_hi:.0f})</span></div>'
|
| 895 |
f'<div class="str-bar"><div class="str-fill" style="width:{pct:.1f}%"></div></div>'
|
| 896 |
+
f'<span class="pkd-badge {badge_cls}">{strength}</span></div>',
|
| 897 |
+
unsafe_allow_html=True)
|
| 898 |
|
| 899 |
st.markdown("---")
|
| 900 |
mode_now = st.session_state.get("interaction_mode","Protein-Protein")
|
|
|
|
| 919 |
st.markdown(
|
| 920 |
f'<div class="ex-card"><div class="ex-pdb">{ex["label"]}</div>'
|
| 921 |
f'<div class="ex-sub">{ex["subtitle"]}</div>'
|
| 922 |
+
f'<div class="ex-desc">{ex["desc"]}</div></div>',unsafe_allow_html=True)
|
|
|
|
| 923 |
if st.button(f"⬇ Load {ex['pdb']}",key=f"ex_{ex['pdb']}",use_container_width=True):
|
| 924 |
ok = populate_from_pdb(ex["pdb"],ex["mode"],
|
| 925 |
chain_a=ex.get("chain_a"),chain_b=ex.get("chain_b"),
|
| 926 |
+
chain_h=ex.get("chain_h"),chain_l=ex.get("chain_l"),chain_ag=ex.get("chain_ag"))
|
|
|
|
| 927 |
if ok:
|
| 928 |
if st.session_state.model is None:
|
| 929 |
st.session_state["_auto_load_model"] = True
|
|
|
|
| 933 |
st.markdown('<div class="sec-hdr">Custom PDB Fetch</div>',unsafe_allow_html=True)
|
| 934 |
pdb_in = st.text_input("PDB ID",placeholder="6M0J, 1BRS, 2VXQ…",key="pdb_id_input")
|
| 935 |
|
| 936 |
+
if mode=="Protein-Protein":
|
| 937 |
st.caption("Comma-separate chains for multi-chain, e.g. `B,C`")
|
| 938 |
fc1,fc2 = st.columns(2)
|
| 939 |
ca_in = fc1.text_input("Side A chain(s)",placeholder="A or A,B",key="ca_in")
|
|
|
|
| 959 |
if ok: st.rerun()
|
| 960 |
|
| 961 |
with right_col:
|
| 962 |
+
vt = st.tabs(["🧬 Structure",
|
| 963 |
+
f"📊 {lbl_a} Strip",f"🗺 {lbl_a} Heatmap",
|
| 964 |
+
f"📊 {lbl_b} Strip",f"🗺 {lbl_b} Heatmap",
|
| 965 |
+
"📥 Download"])
|
|
|
|
|
|
|
| 966 |
|
| 967 |
with vt[0]:
|
| 968 |
if st.session_state.pdb_content:
|
| 969 |
+
st.caption("🎨 IG attribution (selected chains only)" if st.session_state.ig_chain_map
|
| 970 |
+
else "🔗 Selected chains — run prediction with IG for attribution colouring")
|
| 971 |
st.components.v1.html(
|
| 972 |
+
ngl_viewer_html(st.session_state.pdb_content,st.session_state.ig_chain_map,height=420),
|
|
|
|
| 973 |
height=425,scrolling=False)
|
| 974 |
else:
|
| 975 |
st.markdown(
|
| 976 |
+
'<div class="ngl-placeholder"><div style="font-size:2.2rem">🧬</div>'
|
|
|
|
| 977 |
'<div style="margin-top:12px;font-weight:600">Load an example or fetch a PDB</div>'
|
| 978 |
+
'<div style="font-size:.72rem;margin-top:4px">Only selected chains will be shown</div>'
|
| 979 |
'</div>',unsafe_allow_html=True)
|
| 980 |
|
| 981 |
with vt[1]:
|
| 982 |
if ig_a_d and seq_a_d:
|
| 983 |
st.markdown(residue_strip_html(seq_a_d,ig_a_d),unsafe_allow_html=True)
|
| 984 |
+
st.plotly_chart(top10_bar(seq_a_d,ig_a_d,"#2563eb",f"Top 10 · {lbl_a}"),use_container_width=True)
|
|
|
|
| 985 |
else:
|
| 986 |
st.info("Run prediction with **Integrated Gradients** to see attribution.")
|
| 987 |
|
|
|
|
| 995 |
with vt[3]:
|
| 996 |
if ig_b_d and seq_b_d:
|
| 997 |
st.markdown(residue_strip_html(seq_b_d,ig_b_d),unsafe_allow_html=True)
|
| 998 |
+
st.plotly_chart(top10_bar(seq_b_d,ig_b_d,"#7c3aed",f"Top 10 · {lbl_b}"),use_container_width=True)
|
|
|
|
| 999 |
else:
|
| 1000 |
st.info("Run prediction with **Integrated Gradients** to see attribution.")
|
| 1001 |
|
|
|
|
| 1010 |
if st.session_state.result:
|
| 1011 |
if ig_a_d or ig_b_d:
|
| 1012 |
dl1,dl2 = st.columns(2)
|
| 1013 |
+
buf = sio.StringIO(); wc = csv.writer(buf)
|
|
|
|
| 1014 |
wc.writerow(["chain","position","residue","ig_score"])
|
| 1015 |
+
for i,(aa,sc) in enumerate(zip(seq_a_d or "",ig_a_d or [])): wc.writerow(["Target",i+1,aa,f"{sc:.6f}"])
|
| 1016 |
+
for i,(aa,sc) in enumerate(zip(seq_b_d or "",ig_b_d or [])): wc.writerow(["proteina",i+1,aa,f"{sc:.6f}"])
|
| 1017 |
+
dl1.download_button("📥 IG Scores CSV",buf.getvalue().encode(),"ig_scores.csv","text/csv",use_container_width=True)
|
| 1018 |
+
res_dl = st.session_state.result
|
|
|
|
|
|
|
|
|
|
| 1019 |
summary = {
|
| 1020 |
"pkd":res_dl["pkd"],"cosine":res_dl["cosine"],
|
| 1021 |
+
"Target_length":len(seq_a_d or ""),"proteina_length":len(seq_b_d or ""),
|
|
|
|
| 1022 |
"top10_Target":sorted([{"res":f"{(seq_a_d or '')[i]}{i+1}","ig":(ig_a_d or [])[i]}
|
| 1023 |
+
for i in range(min(len(seq_a_d or ""),len(ig_a_d or [])))],key=lambda x:-x["ig"])[:10],
|
|
|
|
| 1024 |
"top10_proteina":sorted([{"res":f"{(seq_b_d or '')[i]}{i+1}","ig":(ig_b_d or [])[i]}
|
| 1025 |
+
for i in range(min(len(seq_b_d or ""),len(ig_b_d or [])))],key=lambda x:-x["ig"])[:10],
|
|
|
|
| 1026 |
}
|
| 1027 |
+
dl2.download_button("📥 Summary JSON",json.dumps(summary,indent=2).encode(),"balm_result.json","application/json",use_container_width=True)
|
|
|
|
|
|
|
|
|
|
| 1028 |
else:
|
| 1029 |
res_dl = st.session_state.result
|
| 1030 |
st.info("Enable IG and re-run for per-residue CSV.")
|
| 1031 |
st.download_button("📥 Basic Result JSON",
|
| 1032 |
+
json.dumps({"pkd":res_dl["pkd"],"cosine":res_dl["cosine"]},indent=2).encode(),
|
| 1033 |
+
"balm_result.json","application/json")
|
| 1034 |
else:
|
| 1035 |
st.info("Run a prediction to enable download.")
|
| 1036 |
|
|
|
|
| 1049 |
df = pd.read_csv(bf)
|
| 1050 |
st.dataframe(df.head(),use_container_width=True)
|
| 1051 |
if st.button("🏃 Run Batch",type="primary",use_container_width=True):
|
| 1052 |
+
if st.session_state.model is None: st.error("Load the model first.")
|
|
|
|
| 1053 |
else:
|
| 1054 |
model,results = st.session_state.model,[]
|
| 1055 |
prog,stat,n = st.progress(0.0),st.empty(),len(df)
|
|
|
|
| 1067 |
pkd_t,cos_t = model(sa,sb)
|
| 1068 |
rr={"pKd":round(float(pkd_t.item()),4),"Cosine":round(float(cos_t.item()),6)}
|
| 1069 |
if run_ig_b:
|
| 1070 |
+
ia,ib=compute_ig(model,sa,sb,steps=15)
|
| 1071 |
sap,sbp=sa.replace("|",""),sb.replace("|","")
|
| 1072 |
rr["top5_Target"]="|".join(f"{sap[j]}{j+1}:{ia[j]:.3f}"
|
| 1073 |
for j in sorted(range(min(len(sap),len(ia))),key=lambda x:-ia[x])[:5])
|
|
|
|
| 1080 |
df_res = pd.concat([df,pd.DataFrame(results)],axis=1)
|
| 1081 |
st.success(f"✅ {n} predictions complete")
|
| 1082 |
st.dataframe(df_res,use_container_width=True)
|
| 1083 |
+
st.download_button("📥 Download Results CSV",df_res.to_csv(index=False).encode(),"balm_batch_results.csv","text/csv",use_container_width=True)
|
|
|
|
| 1084 |
|
| 1085 |
with tab_info:
|
| 1086 |
st.markdown("### 📚 About BALM-PPI")
|
| 1087 |
st.markdown("""
|
| 1088 |
+
**BALM-PPI** predicts protein–protein and antibody–antigen binding affinity using ESM-2 + LoRA
|
| 1089 |
+
with Integrated Gradients explainability.
|
| 1090 |
|
| 1091 |
| Component | Detail |
|
| 1092 |
|-----------|--------|
|
|
|
|
| 1095 |
| Column mapping | `seq_a` → **Target** · `seq_b` → **proteina** |
|
| 1096 |
| Multi-chain | `\\|` separator → `<cls><cls>` double CLS token |
|
| 1097 |
| Affinity output | pKd = ((cos+1)/2) × (pKd_max − pKd_min) + pKd_min |
|
| 1098 |
+
| IG | Integrated Gradients, **15-step** Riemann (float32-safe, HF Spaces compatible) |
|
| 1099 |
+
| Viewer | PDB filtered to **selected chains only** |
|
| 1100 |
+
| Speed | hf_transfer · use_safetensors · HF disk cache · PDB cache (1 h TTL) |
|
| 1101 |
|
| 1102 |
**Examples** · **6M0J** — ACE2 ↔ SARS-CoV-2 RBD · **1BRS** — Barnase ↔ Barstar
|
| 1103 |
|