updated notebook with correct usage
Browse files
test_ablang2_HF_implementation.ipynb
CHANGED
|
@@ -10,14 +10,16 @@
|
|
| 10 |
},
|
| 11 |
{
|
| 12 |
"cell_type": "code",
|
| 13 |
-
"execution_count":
|
| 14 |
"id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
|
| 15 |
"metadata": {},
|
| 16 |
"outputs": [],
|
| 17 |
"source": [
|
| 18 |
-
"import
|
| 19 |
-
"
|
| 20 |
-
"
|
|
|
|
|
|
|
| 21 |
]
|
| 22 |
},
|
| 23 |
{
|
|
@@ -38,7 +40,7 @@
|
|
| 38 |
},
|
| 39 |
{
|
| 40 |
"cell_type": "code",
|
| 41 |
-
"execution_count":
|
| 42 |
"id": "99192978-a008-4a32-a80e-bba238e0ec7c",
|
| 43 |
"metadata": {},
|
| 44 |
"outputs": [],
|
|
@@ -85,8 +87,17 @@
|
|
| 85 |
"metadata": {},
|
| 86 |
"outputs": [],
|
| 87 |
"source": [
|
| 88 |
-
"model
|
| 89 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
"ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)"
|
| 91 |
]
|
| 92 |
},
|
|
@@ -120,7 +131,7 @@
|
|
| 120 |
},
|
| 121 |
{
|
| 122 |
"cell_type": "code",
|
| 123 |
-
"execution_count":
|
| 124 |
"id": "ceae4a88-0679-4704-8bad-c06a4569c497",
|
| 125 |
"metadata": {},
|
| 126 |
"outputs": [],
|
|
@@ -145,7 +156,7 @@
|
|
| 145 |
},
|
| 146 |
{
|
| 147 |
"cell_type": "code",
|
| 148 |
-
"execution_count":
|
| 149 |
"id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
|
| 150 |
"metadata": {},
|
| 151 |
"outputs": [
|
|
@@ -164,7 +175,7 @@
|
|
| 164 |
" -0.16615383, -0.15569784]], shape=(5, 480))"
|
| 165 |
]
|
| 166 |
},
|
| 167 |
-
"execution_count":
|
| 168 |
"metadata": {},
|
| 169 |
"output_type": "execute_result"
|
| 170 |
}
|
|
@@ -189,7 +200,7 @@
|
|
| 189 |
},
|
| 190 |
{
|
| 191 |
"cell_type": "code",
|
| 192 |
-
"execution_count":
|
| 193 |
"id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
|
| 194 |
"metadata": {},
|
| 195 |
"outputs": [
|
|
@@ -263,7 +274,7 @@
|
|
| 263 |
" 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]"
|
| 264 |
]
|
| 265 |
},
|
| 266 |
-
"execution_count":
|
| 267 |
"metadata": {},
|
| 268 |
"output_type": "execute_result"
|
| 269 |
}
|
|
@@ -288,7 +299,7 @@
|
|
| 288 |
},
|
| 289 |
{
|
| 290 |
"cell_type": "code",
|
| 291 |
-
"execution_count":
|
| 292 |
"id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
|
| 293 |
"metadata": {},
|
| 294 |
"outputs": [
|
|
@@ -450,7 +461,7 @@
|
|
| 450 |
},
|
| 451 |
{
|
| 452 |
"cell_type": "code",
|
| 453 |
-
"execution_count":
|
| 454 |
"id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
|
| 455 |
"metadata": {},
|
| 456 |
"outputs": [
|
|
@@ -460,7 +471,7 @@
|
|
| 460 |
"array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])"
|
| 461 |
]
|
| 462 |
},
|
| 463 |
-
"execution_count":
|
| 464 |
"metadata": {},
|
| 465 |
"output_type": "execute_result"
|
| 466 |
}
|
|
@@ -472,7 +483,7 @@
|
|
| 472 |
},
|
| 473 |
{
|
| 474 |
"cell_type": "code",
|
| 475 |
-
"execution_count":
|
| 476 |
"id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
|
| 477 |
"metadata": {},
|
| 478 |
"outputs": [
|
|
@@ -483,7 +494,7 @@
|
|
| 483 |
" dtype=float32)"
|
| 484 |
]
|
| 485 |
},
|
| 486 |
-
"execution_count":
|
| 487 |
"metadata": {},
|
| 488 |
"output_type": "execute_result"
|
| 489 |
}
|
|
@@ -505,7 +516,7 @@
|
|
| 505 |
},
|
| 506 |
{
|
| 507 |
"cell_type": "code",
|
| 508 |
-
"execution_count":
|
| 509 |
"id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
|
| 510 |
"metadata": {},
|
| 511 |
"outputs": [
|
|
@@ -518,7 +529,7 @@
|
|
| 518 |
" dtype='<U238')"
|
| 519 |
]
|
| 520 |
},
|
| 521 |
-
"execution_count":
|
| 522 |
"metadata": {},
|
| 523 |
"output_type": "execute_result"
|
| 524 |
}
|
|
@@ -530,7 +541,7 @@
|
|
| 530 |
},
|
| 531 |
{
|
| 532 |
"cell_type": "code",
|
| 533 |
-
"execution_count":
|
| 534 |
"id": "0e9615f7-c490-4947-96f4-7617266c686e",
|
| 535 |
"metadata": {},
|
| 536 |
"outputs": [
|
|
@@ -543,7 +554,7 @@
|
|
| 543 |
" dtype='<U238')"
|
| 544 |
]
|
| 545 |
},
|
| 546 |
-
"execution_count":
|
| 547 |
"metadata": {},
|
| 548 |
"output_type": "execute_result"
|
| 549 |
}
|
|
@@ -560,48 +571,6 @@
|
|
| 560 |
"metadata": {},
|
| 561 |
"outputs": [],
|
| 562 |
"source": []
|
| 563 |
-
},
|
| 564 |
-
{
|
| 565 |
-
"cell_type": "markdown",
|
| 566 |
-
"id": "98956ca9",
|
| 567 |
-
"metadata": {},
|
| 568 |
-
"source": [
|
| 569 |
-
"## **rescoding / likelihood / probability**\n",
|
| 570 |
-
"\n",
|
| 571 |
-
"The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n",
|
| 572 |
-
"\n",
|
| 573 |
-
"**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the VH and VL.\n",
|
| 574 |
-
"\n",
|
| 575 |
-
"**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
|
| 576 |
-
]
|
| 577 |
-
},
|
| 578 |
-
{
|
| 579 |
-
"cell_type": "markdown",
|
| 580 |
-
"id": "b046ae57",
|
| 581 |
-
"metadata": {},
|
| 582 |
-
"source": [
|
| 583 |
-
"## **rescoding / likelihood / probability**\n",
|
| 584 |
-
"\n",
|
| 585 |
-
"The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n",
|
| 586 |
-
"\n",
|
| 587 |
-
"**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the VH and VL.\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
|
| 590 |
-
]
|
| 591 |
-
},
|
| 592 |
-
{
|
| 593 |
-
"cell_type": "markdown",
|
| 594 |
-
"id": "78ccf7d8",
|
| 595 |
-
"metadata": {},
|
| 596 |
-
"source": [
|
| 597 |
-
"## **rescoding / likelihood / probability**\n",
|
| 598 |
-
"\n",
|
| 599 |
-
"The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n",
|
| 600 |
-
"\n",
|
| 601 |
-
"**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the VH and VL.\n",
|
| 602 |
-
"\n",
|
| 603 |
-
"**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
|
| 604 |
-
]
|
| 605 |
}
|
| 606 |
],
|
| 607 |
"metadata": {
|
|
|
|
| 10 |
},
|
| 11 |
{
|
| 12 |
"cell_type": "code",
|
| 13 |
+
"execution_count": 11,
|
| 14 |
"id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
|
| 15 |
"metadata": {},
|
| 16 |
"outputs": [],
|
| 17 |
"source": [
|
| 18 |
+
"import sys\n",
|
| 19 |
+
"import os\n",
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"from transformers import AutoModel, AutoTokenizer\n",
|
| 22 |
+
"from transformers.utils import cached_file"
|
| 23 |
]
|
| 24 |
},
|
| 25 |
{
|
|
|
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
+
"execution_count": 6,
|
| 44 |
"id": "99192978-a008-4a32-a80e-bba238e0ec7c",
|
| 45 |
"metadata": {},
|
| 46 |
"outputs": [],
|
|
|
|
| 87 |
"metadata": {},
|
| 88 |
"outputs": [],
|
| 89 |
"source": [
|
| 90 |
+
"# Load model and tokenizer from Hugging Face Hub\n",
|
| 91 |
+
"model = AutoModel.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
|
| 92 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# Find the cached model directory and import adapter\n",
|
| 95 |
+
"adapter_path = cached_file(\"hemantn/ablang2\", \"adapter.py\")\n",
|
| 96 |
+
"cached_model_dir = os.path.dirname(adapter_path)\n",
|
| 97 |
+
"sys.path.insert(0, cached_model_dir)\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# Import and create the adapter\n",
|
| 100 |
+
"from adapter import AbLang2PairedHuggingFaceAdapter\n",
|
| 101 |
"ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)"
|
| 102 |
]
|
| 103 |
},
|
|
|
|
| 131 |
},
|
| 132 |
{
|
| 133 |
"cell_type": "code",
|
| 134 |
+
"execution_count": 7,
|
| 135 |
"id": "ceae4a88-0679-4704-8bad-c06a4569c497",
|
| 136 |
"metadata": {},
|
| 137 |
"outputs": [],
|
|
|
|
| 156 |
},
|
| 157 |
{
|
| 158 |
"cell_type": "code",
|
| 159 |
+
"execution_count": 8,
|
| 160 |
"id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
|
| 161 |
"metadata": {},
|
| 162 |
"outputs": [
|
|
|
|
| 175 |
" -0.16615383, -0.15569784]], shape=(5, 480))"
|
| 176 |
]
|
| 177 |
},
|
| 178 |
+
"execution_count": 8,
|
| 179 |
"metadata": {},
|
| 180 |
"output_type": "execute_result"
|
| 181 |
}
|
|
|
|
| 200 |
},
|
| 201 |
{
|
| 202 |
"cell_type": "code",
|
| 203 |
+
"execution_count": 9,
|
| 204 |
"id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
|
| 205 |
"metadata": {},
|
| 206 |
"outputs": [
|
|
|
|
| 274 |
" 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]"
|
| 275 |
]
|
| 276 |
},
|
| 277 |
+
"execution_count": 9,
|
| 278 |
"metadata": {},
|
| 279 |
"output_type": "execute_result"
|
| 280 |
}
|
|
|
|
| 299 |
},
|
| 300 |
{
|
| 301 |
"cell_type": "code",
|
| 302 |
+
"execution_count": 10,
|
| 303 |
"id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
|
| 304 |
"metadata": {},
|
| 305 |
"outputs": [
|
|
|
|
| 461 |
},
|
| 462 |
{
|
| 463 |
"cell_type": "code",
|
| 464 |
+
"execution_count": 12,
|
| 465 |
"id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
|
| 466 |
"metadata": {},
|
| 467 |
"outputs": [
|
|
|
|
| 471 |
"array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])"
|
| 472 |
]
|
| 473 |
},
|
| 474 |
+
"execution_count": 12,
|
| 475 |
"metadata": {},
|
| 476 |
"output_type": "execute_result"
|
| 477 |
}
|
|
|
|
| 483 |
},
|
| 484 |
{
|
| 485 |
"cell_type": "code",
|
| 486 |
+
"execution_count": 13,
|
| 487 |
"id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
|
| 488 |
"metadata": {},
|
| 489 |
"outputs": [
|
|
|
|
| 494 |
" dtype=float32)"
|
| 495 |
]
|
| 496 |
},
|
| 497 |
+
"execution_count": 13,
|
| 498 |
"metadata": {},
|
| 499 |
"output_type": "execute_result"
|
| 500 |
}
|
|
|
|
| 516 |
},
|
| 517 |
{
|
| 518 |
"cell_type": "code",
|
| 519 |
+
"execution_count": 14,
|
| 520 |
"id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
|
| 521 |
"metadata": {},
|
| 522 |
"outputs": [
|
|
|
|
| 529 |
" dtype='<U238')"
|
| 530 |
]
|
| 531 |
},
|
| 532 |
+
"execution_count": 14,
|
| 533 |
"metadata": {},
|
| 534 |
"output_type": "execute_result"
|
| 535 |
}
|
|
|
|
| 541 |
},
|
| 542 |
{
|
| 543 |
"cell_type": "code",
|
| 544 |
+
"execution_count": 15,
|
| 545 |
"id": "0e9615f7-c490-4947-96f4-7617266c686e",
|
| 546 |
"metadata": {},
|
| 547 |
"outputs": [
|
|
|
|
| 554 |
" dtype='<U238')"
|
| 555 |
]
|
| 556 |
},
|
| 557 |
+
"execution_count": 15,
|
| 558 |
"metadata": {},
|
| 559 |
"output_type": "execute_result"
|
| 560 |
}
|
|
|
|
| 571 |
"metadata": {},
|
| 572 |
"outputs": [],
|
| 573 |
"source": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
}
|
| 575 |
],
|
| 576 |
"metadata": {
|