{ "cells": [ { "cell_type": "markdown", "id": "458aed0f", "metadata": {}, "source": [ "Note: This notebook is adapted from the [AbLang2](https://github.com/TobiasHeOl/AbLang2) model's GitHub repository. It is used to verify that the Hugging Face implementation functions correctly and produces the same output as the original model." ] }, { "cell_type": "code", "execution_count": 1, "id": "a51e7ed2", "metadata": {}, "outputs": [], "source": [ "!rm -rf ~/.cache/huggingface/hub/models--hemantn--ablang2" ] }, { "cell_type": "code", "execution_count": 2, "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "import numpy as np\n", "from transformers import AutoModel, AutoTokenizer\n", "from huggingface_hub import hf_hub_download" ] }, { "cell_type": "markdown", "id": "10801511-770d-46ac-a15d-a02d4ef9ec87", "metadata": {}, "source": [ "# **0. Sequence input and its format**\n", "\n", "AbLang2 takes as input either the heavy variable domain (VH) alone, the light variable domain (VL) alone, or the full variable domain (Fv).\n", "\n", "Each record (antibody) needs to be a list with the VH as the first element and the VL as the second. If either the VH or the VL is unknown, use an empty string in its place.\n", "\n", "An asterisk (\\*) is used for masking. It is recommended to mask the residues you are interested in mutating.\n", "\n", "**NB:** It is important that the VH and VL sequences are given in the correct order (VH first, VL second)."
] }, { "cell_type": "code", "execution_count": 3, "id": "99192978-a008-4a32-a80e-bba238e0ec7c", "metadata": {}, "outputs": [], "source": [ "seq1 = [\n", " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS', # VH sequence\n", " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK' # VL sequence\n", "]\n", "seq2 = [\n", " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT',\n", " 'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", "]\n", "seq3 = [\n", " 'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',\n", " '' # The VL sequence is not known, so an empty string is left instead. \n", "]\n", "seq4 = [\n", " '',\n", " 'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", "]\n", "seq5 = [\n", " 'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', # (*) is used to mask certain residues\n", " 'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", "]\n", "\n", "all_seqs = [seq1, seq2, seq3, seq4, seq5]\n", "only_both_chains_seqs = [seq1, seq2, seq5]" ] }, { "cell_type": "markdown", "id": "dffbacfa-8642-4d94-9572-2205a05c18f9", "metadata": {}, "source": [ "# **1. How to use AbLang2**\n", "\n", "AbLang2 can be downloaded and used in its raw form as seen below. 
For convenience, we have also developed different \"modes\" for specific use cases (see Section 2). " ] }, { "cell_type": "code", "execution_count": 4, "id": "6d66ad84", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ed2d5574bd21463c9244070ab762c31e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/763 [00:00|\". The length of the output is therefore 5 tokens longer than the VH and VL combined.\n", "\n", "**NB:** By default, the representations are derived using a single forward pass. To prevent the predicted likelihood and probability from being affected by the input residue at each position, set the \"stepwise_masking\" argument to True. This runs a separate forward pass for each position, with the residue at that position masked, which is much slower than a single forward pass." ] }, { "cell_type": "code", "execution_count": 7, "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([[-0.40741208, -0.5118987 , 0.06096708, ..., 0.3268144 ,\n", " 0.03920235, -0.36715826],\n", " [-0.5768883 , 0.38245413, -0.21791998, ..., 0.01250262,\n", " -0.08844463, -0.32367525],\n", " [-0.1475935 , 0.39639047, -0.38226923, ..., -0.10119921,\n", " -0.41469565, -0.00319315],\n", " ...,\n", " [-0.14358369, 0.3124389 , -0.30157998, ..., -0.13289244,\n", " -0.45353398, -0.07878865],\n", " [ 0.17538925, 0.24394299, 0.20141171, ..., 0.14587352,\n", " -0.38479003, 0.07409196],\n", " [-0.23031706, -0.35487285, 0.1960684 , ..., -0.1283362 ,\n", " 0.31107333, -0.3265108 ]], shape=(238, 480), dtype=float32),\n", " array([[-0.41981837, -0.3666375 , 0.10595217, ..., 0.3903574 ,\n", " 0.0382378 , -0.36337993],\n", " [-0.5054137 , 0.38347068, -0.10992069, ..., -0.05231472,\n", " -0.13636623, -0.34830108],\n", " [-0.06784609, 0.69349885, -0.4212398 , ..., -0.24805346,\n", " -0.39583805, -0.10972726],\n", " ...,\n", " [-0.02212614, 0.26338235, 
-0.5558968 , ..., -0.24067189,\n", " -0.11965694, 0.07879876],\n", " [-0.20650092, 0.43451664, -0.09650223, ..., -0.05296766,\n", " -0.04297376, 0.41854134],\n", " [-0.02653179, 0.03729444, 0.13194172, ..., -0.4554279 ,\n", " 0.03723941, 0.17769177]], shape=(238, 480), dtype=float32),\n", " array([[-0.40043733, -0.48596814, 0.0886725 , ..., 0.38941646,\n", " 0.06195956, -0.40999672],\n", " [-0.54576075, 0.4312959 , -0.3451486 , ..., -0.09285564,\n", " 0.03116508, -0.45269737],\n", " [ 0.0221165 , 0.53196615, -0.30137214, ..., -0.1889072 ,\n", " -0.32587305, 0.05078396],\n", " ...,\n", " [-0.03700298, 0.7739084 , 0.3454928 , ..., -0.03060072,\n", " 0.02420983, -0.48005292],\n", " [-0.03366657, 0.74771184, -0.35423476, ..., -0.08759108,\n", " -0.17898935, -0.4540483 ],\n", " [-0.16625853, 0.2701079 , -0.19761363, ..., 0.10313392,\n", " 0.44890267, -0.64840287]], shape=(238, 480), dtype=float32),\n", " array([[-0.26863217, 0.32259187, 0.10813517, ..., 0.03953876,\n", " 0.18312076, -0.00498045],\n", " [-0.2165424 , -0.38562432, -0.02696264, ..., 0.20541488,\n", " 0.18698391, -0.22639504],\n", " [-0.41950518, 0.04743317, 0.0048816 , ..., 0.11408642,\n", " -0.05384652, 0.1025871 ],\n", " ...,\n", " [-0.14095458, 0.5860325 , -0.44657114, ..., -0.39150292,\n", " -0.22395667, -0.42516366],\n", " [ 0.29816052, 0.40440455, -0.52062094, ..., 0.08969188,\n", " -0.20792632, -0.2045222 ],\n", " [-0.21370608, 0.23035707, -0.355185 , ..., -0.36726946,\n", " -0.05693531, -0.37847823]], shape=(238, 480), dtype=float32),\n", " array([[-0.42062947, -0.44009134, 0.00152371, ..., 0.27141467,\n", " 0.03798106, -0.397461 ],\n", " [-0.57318133, 0.5258899 , -0.17001636, ..., -0.23864633,\n", " 0.2088059 , -0.57877594],\n", " [-0.38988614, 0.46168196, -0.3429413 , ..., -0.14872643,\n", " -0.46576905, -0.21224979],\n", " ...,\n", " [-0.21528634, 0.30046722, -0.25216463, ..., -0.11576828,\n", " -0.4704907 , -0.0740136 ],\n", " [ 0.0633081 , 0.22700705, 0.28184187, ..., 0.15967266,\n", " 
-0.377182 , 0.06188517],\n", " [-0.27826303, -0.37297496, 0.21229912, ..., -0.14886017,\n", " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ablang(all_seqs, mode='rescoding', stepwise_masking = False)" ] }, { "cell_type": "markdown", "id": "6da2183b-4306-49bd-a7fc-23e78a23f305", "metadata": {}, "source": [ "## **Align rescoding/likelihood/probability output**\n", "\n", "For the 'rescoding', 'likelihood', and 'probability' modes, the output can also be aligned using the argument \"align=True\".\n", "\n", "This is done using the antibody numbering tool ANARCI and requires manually installing **Pandas** and **[ANARCI](https://github.com/oxpig/ANARCI)**.\n", "\n", "**NB**: Align can only be used on inputs that share the same format, i.e., either all heavy, all light, or all paired heavy and light." ] }, { "cell_type": "code", "execution_count": 8, "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '\n", " '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '\n", " '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '\n", " '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '\n", " '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '\n", " '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '\n", " '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '\n", " '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '\n", " '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '\n", " '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '\n", " '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '\n", " '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' 
'16 ' '17 ' '18 '\n", " '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '\n", " '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 ' '43 '\n", " '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 ' '55 '\n", " '56 ' '57 ' '64 ' '65 ' '66 ' '67 ' '68 ' '69 ' '70 ' '71 ' '72 ' '74 '\n", " '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 '\n", " '89 ' '90 ' '91 ' '92 ' '93 ' '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 '\n", " '101 ' '102 ' '103 ' '104 ' '105 ' '106 ' '107 ' '108 ' '109 ' '114 '\n", " '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 ' '124 '\n", " '125 ' '126 ' '127 ' '>']\n", "['|', '|<-----------PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<------SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS>|']\n", "[[[ 9.31621838 -3.42184329 -3.59397745 ... -14.73707485 -6.8935833\n", " -0.23662776]\n", " [ -3.54718232 -5.84866619 -4.02423859 ... -12.93966579 -9.5614481\n", " -4.48473835]\n", " [-11.94997597 -2.245543 -5.69481373 ... -15.19639015 -17.97454071\n", " -12.56952095]\n", " ...\n", " [ -8.94504833 -0.42261261 -4.95588207 ... -16.66817474 -15.2224741\n", " -10.37267494]\n", " [-11.65150356 -5.44477606 -2.95585775 ... -16.25555801 -9.75158596\n", " -11.75897026]\n", " [ 1.79469728 -1.95846701 -3.59784532 ... -14.95585823 -7.47080708\n", " -0.95226753]]\n", "\n", " [[ 8.55518723 -3.83663297 -2.33595967 ... -13.87456799 -8.14840603\n", " -0.42472434]\n", " [ -4.40701294 -5.53201008 -3.69397402 ... -12.97877789 -9.86258411\n", " -4.95414352]\n", " [-11.95642853 -3.86210871 -5.80935192 ... -14.89213085 -16.94556236\n", " -11.36959839]\n", " ...\n", " [ -7.75924015 -0.66524202 -4.08643246 ... -16.16580772 -14.76507473\n", " -8.3507061 ]\n", " [-11.91039753 -4.86995983 -2.74777436 ... 
-16.07694817 -8.44974899\n", " -10.45223904]\n", " [ 0.86006832 -2.37964034 -3.58130741 ... -15.35423565 -7.73035526\n", " -1.11989737]]\n", "\n", " [[ -4.37902737 -7.55587149 1.21958363 ... -15.48622513 -6.021842\n", " -3.79647374]\n", " [ 0. 0. 0. ... 0. 0.\n", " 0. ]\n", " [ 0. 0. 0. ... 0. 0.\n", " 0. ]\n", " ...\n", " [ -8.94207573 -0.51090252 -5.09760332 ... -16.69521713 -15.45450687\n", " -10.50823212]\n", " [-11.92354965 -5.55152607 -2.87666893 ... -16.40607834 -10.19431686\n", " -12.1328764 ]\n", " [ 2.42200375 -2.01573253 -3.61701298 ... -14.9590435 -7.19029331\n", " -0.89830256]]]\n" ] } ], "source": [ "results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n", "\n", "print(results.number_alignment)\n", "print(results.aligned_seqs)\n", "print(results.aligned_embeds)" ] }, { "cell_type": "code", "execution_count": 9, "id": "56be8cad", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([[9.9955505e-01, 2.9358694e-06, 2.4716087e-06, ..., 3.5776201e-11,\n", " 9.1196831e-08, 7.0967326e-05],\n", " [4.1573694e-06, 4.1619489e-07, 2.5800944e-06, ..., 3.4650952e-10,\n", " 1.0159109e-08, 1.6279575e-06],\n", " [7.8059600e-08, 1.2794037e-03, 4.0645118e-05, ..., 3.0375720e-09,\n", " 1.8879491e-10, 4.2010839e-08],\n", " ...,\n", " [3.4210879e-07, 1.7195340e-03, 1.8477240e-05, ..., 1.5137445e-10,\n", " 6.4255873e-10, 8.2064140e-08],\n", " [9.1038084e-09, 4.5161755e-06, 5.4411950e-05, ..., 9.1139631e-11,\n", " 6.0862085e-08, 8.1761966e-09],\n", " [8.5759175e-04, 2.0104915e-05, 3.9023766e-06, ..., 4.5562460e-11,\n", " 8.1156479e-08, 5.4990651e-05]], shape=(238, 26), dtype=float32),\n", " array([[9.9939799e-01, 4.1499175e-06, 1.8611167e-05, ..., 1.8139243e-10,\n", " 5.5649299e-08, 1.2583815e-04],\n", " [1.6735513e-06, 5.4332406e-07, 3.4143472e-06, ..., 3.1693398e-10,\n", " 7.1501400e-09, 9.6832969e-07],\n", " [3.7784993e-08, 1.2377645e-04, 1.7658784e-05, ..., 2.0061326e-09,\n", " 2.5737484e-10, 6.7947965e-08],\n", " ...,\n", " [1.1050455e-06, 
1.3312638e-03, 4.3497097e-05, ..., 2.4686178e-10,\n", " 1.0018089e-09, 6.1165900e-07],\n", " [5.7270397e-09, 6.5396339e-06, 5.4601755e-05, ..., 8.8801404e-11,\n", " 1.8233513e-07, 2.4615032e-08],\n", " [7.3952030e-04, 2.8970928e-05, 8.7113440e-06, ..., 6.7168833e-11,\n", " 1.3746008e-07, 1.0210846e-04]], shape=(222, 26), dtype=float32),\n", " array([[9.99685407e-01, 3.35662639e-06, 1.14241482e-06, ...,\n", " 2.32460891e-11, 6.88188067e-08, 5.69467156e-05],\n", " [6.38133372e-07, 1.01300586e-07, 5.64459742e-06, ...,\n", " 4.09234556e-11, 2.53804799e-09, 4.31722100e-07],\n", " [1.49096788e-08, 2.04515047e-04, 9.23794141e-06, ...,\n", " 7.46306961e-10, 2.92107380e-11, 2.21786500e-08],\n", " ...,\n", " [2.15093763e-07, 1.06453872e-03, 1.62486140e-05, ...,\n", " 1.12102910e-10, 1.47300866e-10, 4.73037538e-08],\n", " [4.30136682e-09, 3.09317988e-06, 3.96632568e-05, ...,\n", " 5.24226877e-11, 2.39579450e-08, 3.86403221e-09],\n", " [9.77773685e-04, 1.29533228e-05, 2.78623725e-06, ...,\n", " 2.73364300e-11, 3.96418649e-08, 4.04014427e-05]],\n", " shape=(238, 26), dtype=float32)]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ablang(only_both_chains_seqs, mode='probability')" ] }, { "cell_type": "markdown", "id": "8f0a71ec-e916-4330-90d0-13a4b1121a89", "metadata": {}, "source": [ "## **Pseudo log likelihood and Confidence scores**\n", "\n", "The pseudo log likelihood and the confidence represent two methods for calculating the uncertainty of the input sequence.\n", "\n", "- pseudo_log_likelihood: For each position, the residue is masked and the log likelihood of predicting it is calculated. The final score is an average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).\n", "\n", "- confidence: For each position, the log likelihood is calculated without masking the residue. 
The final score is an average across the whole input. \n", "\n", "**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. **Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.\n", "\n", "**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**." ] }, { "cell_type": "code", "execution_count": 10, "id": "83f3064b-48a7-42fb-ba82-ec153ea946da", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results = ablang(all_seqs, mode='pseudo_log_likelihood')\n", "np.exp(-results) # convert to pseudo perplexity" ] }, { "cell_type": "code", "execution_count": 11, "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.2636038, 1.126463 , 1.3123759, 1.2140924, 1.1805094],\n", " dtype=float32)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results = ablang(all_seqs, mode='confidence')\n", "np.exp(-results)" ] }, { "cell_type": "markdown", "id": "e0b63e48-b2a1-4a8e-8ecb-449748a2cb25", "metadata": {}, "source": [ "## **restore**\n", "\n", "This mode can be used to restore masked residues and, with \"align=True\", fragmented regions." ] }, { "cell_type": "code", "execution_count": 12, "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['|',\n", " '|',\n", " '|'],\n", " dtype='|',\n", " '|',\n", " '|'],\n", " dtype='