hemantn committed on
Commit
0c337fe
·
1 Parent(s): 3d3c39b

updated notebook with correct usage

Browse files
Files changed (1) hide show
  1. test_ablang2_HF_implementation.ipynb +32 -63
test_ablang2_HF_implementation.ipynb CHANGED
@@ -10,14 +10,16 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 1,
14
  "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
18
- "import numpy as np \n",
19
- "from transformers import AutoTokenizer, AutoModel\n",
20
- "from ablang2.adapter import AbLang2PairedHuggingFaceAdapter"
 
 
21
  ]
22
  },
23
  {
@@ -38,7 +40,7 @@
38
  },
39
  {
40
  "cell_type": "code",
41
- "execution_count": 2,
42
  "id": "99192978-a008-4a32-a80e-bba238e0ec7c",
43
  "metadata": {},
44
  "outputs": [],
@@ -85,8 +87,17 @@
85
  "metadata": {},
86
  "outputs": [],
87
  "source": [
88
- "model = AutoModel.from_pretrained(\"/hemantn/ablang2/\", trust_remote_code=True)\n",
89
- "tokenizer = AutoTokenizer.from_pretrained(\"/hemantn/ablang2/\", trust_remote_code=True)\n",
 
 
 
 
 
 
 
 
 
90
  "ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)"
91
  ]
92
  },
@@ -120,7 +131,7 @@
120
  },
121
  {
122
  "cell_type": "code",
123
- "execution_count": 5,
124
  "id": "ceae4a88-0679-4704-8bad-c06a4569c497",
125
  "metadata": {},
126
  "outputs": [],
@@ -145,7 +156,7 @@
145
  },
146
  {
147
  "cell_type": "code",
148
- "execution_count": 6,
149
  "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
150
  "metadata": {},
151
  "outputs": [
@@ -164,7 +175,7 @@
164
  " -0.16615383, -0.15569784]], shape=(5, 480))"
165
  ]
166
  },
167
- "execution_count": 6,
168
  "metadata": {},
169
  "output_type": "execute_result"
170
  }
@@ -189,7 +200,7 @@
189
  },
190
  {
191
  "cell_type": "code",
192
- "execution_count": 7,
193
  "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
194
  "metadata": {},
195
  "outputs": [
@@ -263,7 +274,7 @@
263
  " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]"
264
  ]
265
  },
266
- "execution_count": 7,
267
  "metadata": {},
268
  "output_type": "execute_result"
269
  }
@@ -288,7 +299,7 @@
288
  },
289
  {
290
  "cell_type": "code",
291
- "execution_count": 8,
292
  "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
293
  "metadata": {},
294
  "outputs": [
@@ -450,7 +461,7 @@
450
  },
451
  {
452
  "cell_type": "code",
453
- "execution_count": 10,
454
  "id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
455
  "metadata": {},
456
  "outputs": [
@@ -460,7 +471,7 @@
460
  "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])"
461
  ]
462
  },
463
- "execution_count": 10,
464
  "metadata": {},
465
  "output_type": "execute_result"
466
  }
@@ -472,7 +483,7 @@
472
  },
473
  {
474
  "cell_type": "code",
475
- "execution_count": 11,
476
  "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
477
  "metadata": {},
478
  "outputs": [
@@ -483,7 +494,7 @@
483
  " dtype=float32)"
484
  ]
485
  },
486
- "execution_count": 11,
487
  "metadata": {},
488
  "output_type": "execute_result"
489
  }
@@ -505,7 +516,7 @@
505
  },
506
  {
507
  "cell_type": "code",
508
- "execution_count": 12,
509
  "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
510
  "metadata": {},
511
  "outputs": [
@@ -518,7 +529,7 @@
518
  " dtype='<U238')"
519
  ]
520
  },
521
- "execution_count": 12,
522
  "metadata": {},
523
  "output_type": "execute_result"
524
  }
@@ -530,7 +541,7 @@
530
  },
531
  {
532
  "cell_type": "code",
533
- "execution_count": 13,
534
  "id": "0e9615f7-c490-4947-96f4-7617266c686e",
535
  "metadata": {},
536
  "outputs": [
@@ -543,7 +554,7 @@
543
  " dtype='<U238')"
544
  ]
545
  },
546
- "execution_count": 13,
547
  "metadata": {},
548
  "output_type": "execute_result"
549
  }
@@ -560,48 +571,6 @@
560
  "metadata": {},
561
  "outputs": [],
562
  "source": []
563
- },
564
- {
565
- "cell_type": "markdown",
566
- "id": "98956ca9",
567
- "metadata": {},
568
- "source": [
569
- "## **rescoding / likelihood / probability**\n",
570
- "\n",
571
- "The rescodings represent each residue as a 480-dimensional embedding. The likelihoods represent each residue as the predicted logits for each character in the vocabulary. The probabilities represent the normalised likelihoods.\n",
572
- "\n",
573
- "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the combined length of the VH and VL sequences.\n",
574
- "\n",
575
- "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability from being affected by the input residue at each position, set the \"stepwise_masking\" argument to True. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
576
- ]
577
- },
578
- {
579
- "cell_type": "markdown",
580
- "id": "b046ae57",
581
- "metadata": {},
582
- "source": [
583
- "## **rescoding / likelihood / probability**\n",
584
- "\n",
585
- "The rescodings represent each residue as a 480-dimensional embedding. The likelihoods represent each residue as the predicted logits for each character in the vocabulary. The probabilities represent the normalised likelihoods.\n",
586
- "\n",
587
- "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the combined length of the VH and VL sequences.\n",
588
- "\n",
589
- "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability from being affected by the input residue at each position, set the \"stepwise_masking\" argument to True. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
590
- ]
591
- },
592
- {
593
- "cell_type": "markdown",
594
- "id": "78ccf7d8",
595
- "metadata": {},
596
- "source": [
597
- "## **rescoding / likelihood / probability**\n",
598
- "\n",
599
- "The rescodings represent each residue as a 480-dimensional embedding. The likelihoods represent each residue as the predicted logits for each character in the vocabulary. The probabilities represent the normalised likelihoods.\n",
600
- "\n",
601
- "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"<VH_seq>|<VL_seq>\". The length of the output is therefore 5 longer than the combined length of the VH and VL sequences.\n",
602
- "\n",
603
- "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability from being affected by the input residue at each position, set the \"stepwise_masking\" argument to True. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass."
604
- ]
605
  }
606
  ],
607
  "metadata": {
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 11,
14
  "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
18
+ "import sys\n",
19
+ "import os\n",
20
+ "import numpy as np\n",
21
+ "from transformers import AutoModel, AutoTokenizer\n",
22
+ "from transformers.utils import cached_file"
23
  ]
24
  },
25
  {
 
40
  },
41
  {
42
  "cell_type": "code",
43
+ "execution_count": 6,
44
  "id": "99192978-a008-4a32-a80e-bba238e0ec7c",
45
  "metadata": {},
46
  "outputs": [],
 
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
90
+ "# Load model and tokenizer from Hugging Face Hub\n",
91
+ "model = AutoModel.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
92
+ "tokenizer = AutoTokenizer.from_pretrained(\"hemantn/ablang2\", trust_remote_code=True)\n",
93
+ "\n",
94
+ "# Find the cached model directory and import adapter\n",
95
+ "adapter_path = cached_file(\"hemantn/ablang2\", \"adapter.py\")\n",
96
+ "cached_model_dir = os.path.dirname(adapter_path)\n",
97
+ "sys.path.insert(0, cached_model_dir)\n",
98
+ "\n",
99
+ "# Import and create the adapter\n",
100
+ "from adapter import AbLang2PairedHuggingFaceAdapter\n",
101
  "ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)"
102
  ]
103
  },
 
131
  },
132
  {
133
  "cell_type": "code",
134
+ "execution_count": 7,
135
  "id": "ceae4a88-0679-4704-8bad-c06a4569c497",
136
  "metadata": {},
137
  "outputs": [],
 
156
  },
157
  {
158
  "cell_type": "code",
159
+ "execution_count": 8,
160
  "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
161
  "metadata": {},
162
  "outputs": [
 
175
  " -0.16615383, -0.15569784]], shape=(5, 480))"
176
  ]
177
  },
178
+ "execution_count": 8,
179
  "metadata": {},
180
  "output_type": "execute_result"
181
  }
 
200
  },
201
  {
202
  "cell_type": "code",
203
+ "execution_count": 9,
204
  "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
205
  "metadata": {},
206
  "outputs": [
 
274
  " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]"
275
  ]
276
  },
277
+ "execution_count": 9,
278
  "metadata": {},
279
  "output_type": "execute_result"
280
  }
 
299
  },
300
  {
301
  "cell_type": "code",
302
+ "execution_count": 10,
303
  "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
304
  "metadata": {},
305
  "outputs": [
 
461
  },
462
  {
463
  "cell_type": "code",
464
+ "execution_count": 12,
465
  "id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
466
  "metadata": {},
467
  "outputs": [
 
471
  "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])"
472
  ]
473
  },
474
+ "execution_count": 12,
475
  "metadata": {},
476
  "output_type": "execute_result"
477
  }
 
483
  },
484
  {
485
  "cell_type": "code",
486
+ "execution_count": 13,
487
  "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
488
  "metadata": {},
489
  "outputs": [
 
494
  " dtype=float32)"
495
  ]
496
  },
497
+ "execution_count": 13,
498
  "metadata": {},
499
  "output_type": "execute_result"
500
  }
 
516
  },
517
  {
518
  "cell_type": "code",
519
+ "execution_count": 14,
520
  "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
521
  "metadata": {},
522
  "outputs": [
 
529
  " dtype='<U238')"
530
  ]
531
  },
532
+ "execution_count": 14,
533
  "metadata": {},
534
  "output_type": "execute_result"
535
  }
 
541
  },
542
  {
543
  "cell_type": "code",
544
+ "execution_count": 15,
545
  "id": "0e9615f7-c490-4947-96f4-7617266c686e",
546
  "metadata": {},
547
  "outputs": [
 
554
  " dtype='<U238')"
555
  ]
556
  },
557
+ "execution_count": 15,
558
  "metadata": {},
559
  "output_type": "execute_result"
560
  }
 
571
  "metadata": {},
572
  "outputs": [],
573
  "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  }
575
  ],
576
  "metadata": {