Question Answering
sanjudebnath committed (verified)
Commit bc565b0
1 Parent(s): 22b42b9

Upload load_data.ipynb

Files changed (1)
  1. load_data.ipynb +755 -0
load_data.ipynb ADDED
@@ -0,0 +1,755 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "12d87b30",
+ "metadata": {},
+ "source": [
+ "# Load Data\n",
+ "This notebook loads and preprocesses all necessary data, namely the following.\n",
+ "* OpenWebTextCorpus: for the base DistilBERT model\n",
+ "* SQuAD dataset: for Q&A\n",
+ "* Natural Questions (needs to be downloaded externally but is preprocessed here): for Q&A\n",
+ "* HotPotQA: for Q&A"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "7c82d7fa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tqdm.auto import tqdm\n",
+ "from datasets import load_dataset\n",
+ "import os\n",
+ "import pandas as pd\n",
+ "import random"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1737f219",
+ "metadata": {},
+ "source": [
+ "## DistilBERT Data\n",
+ "In the following, we download the English OpenWebText dataset from Hugging Face (https://huggingface.co/datasets/openwebtext). The dataset is provided by Aaron Gokaslan and Vanya Cohen from Brown University (https://skylion007.github.io/OpenWebTextCorpus/).\n",
+ "\n",
+ "We first load the data, investigate the structure and write the dataset into files of 10,000 texts each."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cce7623c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds = load_dataset(\"openwebtext\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "678a5e86",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ "    train: Dataset({\n",
+ "        features: ['text'],\n",
+ "        num_rows: 8013769\n",
+ "    })\n",
+ "})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# we have a text-only training dataset with 8 million entries\n",
+ "ds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b141bce7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create necessary folders\n",
+ "os.mkdir('data')\n",
+ "os.mkdir('data/original')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ca94f995",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save text in chunks of 10000 samples\n",
+ "text = []\n",
+ "i = 0\n",
+ "\n",
+ "for sample in tqdm(ds['train']):\n",
+ "    # replace all newlines\n",
+ "    sample = sample['text'].replace('\\n','')\n",
+ "\n",
+ "    # append cleaned sample to all texts\n",
+ "    text.append(sample)\n",
+ "\n",
+ "    # if we processed 10000 samples, write them to a file and start over\n",
+ "    if len(text) == 10000:\n",
+ "        with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "            f.write('\\n'.join(text))\n",
+ "        text = []\n",
+ "        i += 1\n",
+ "\n",
+ "# write remaining samples to a file\n",
+ "with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "    f.write('\\n'.join(text))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f131dcfc",
+ "metadata": {},
+ "source": [
+ "### Testing\n",
+ "If we load the first file, we should get a file that is 10,000 lines long and has one column.\n",
+ "\n",
+ "As we do not preprocess the data in any way but simply write the raw text to the file, this is all the testing necessary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "df50af74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"data/original/text_0.txt\", 'r', encoding='utf-8') as f:\n",
+ "    lines = f.read().split('\\n')\n",
+ "lines = pd.DataFrame(lines)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "8ddb0085",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Passed\n"
+ ]
+ }
+ ],
+ "source": [
+ "assert lines.shape==(10000,1)\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1a65b268",
+ "metadata": {},
+ "source": [
+ "## SQuAD Data\n",
+ "In the following, we download the SQuAD dataset from Hugging Face (https://huggingface.co/datasets/squad). It was initially provided by Rajpurkar et al. from Stanford University.\n",
+ "\n",
+ "We again load the dataset and store it in files of 1,000 samples each."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6750ce6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"squad\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "65a7ee23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdir(\"data/training_squad\")\n",
+ "os.mkdir(\"data/test_squad\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6ebf63e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we already have a training and test split. Each sample has an id, title, context, question and answers.\n",
+ "dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f67ae448",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# answers are provided like this - we need to derive answer_end for the model\n",
+ "dataset['train']['answers'][0]"
+ ]
+ },
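+ {
+ "cell_type": "markdown",
+ "id": "answer-end-note",
+ "metadata": {},
+ "source": [
+ "The files we write below only store `answer_start`. As a minimal sketch (assuming the answer text always matches the context span exactly, which the tests below check), `answer_end` can later be derived from the start index and the answer length:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "answer-end-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# illustrative only: derive the exclusive end index from the start index and the answer text\n",
+ "example = dataset['train'][0]\n",
+ "ans_text = example['answers']['text'][0]\n",
+ "ans_start = example['answers']['answer_start'][0]\n",
+ "ans_end = ans_start + len(ans_text)\n",
+ "assert example['context'][ans_start:ans_end] == ans_text"
+ ]
+ },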
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "101cd650",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# column contains the split (either train or validation), save_dir is the directory\n",
+ "def save_samples(column, save_dir):\n",
+ "    text = []\n",
+ "    i = 0\n",
+ "\n",
+ "    for sample in tqdm(dataset[column]):\n",
+ "\n",
+ "        # preprocess the context and question by removing the newlines\n",
+ "        context = sample['context'].replace('\\n','')\n",
+ "        question = sample['question'].replace('\\n','')\n",
+ "\n",
+ "        # get the answer as text and start character index\n",
+ "        answer_text = sample['answers']['text'][0]\n",
+ "        answer_start = str(sample['answers']['answer_start'][0])\n",
+ "\n",
+ "        text.append([context, question, answer_text, answer_start])\n",
+ "\n",
+ "        # we choose chunks of 1000\n",
+ "        if len(text) == 1000:\n",
+ "            with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "                f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
+ "            text = []\n",
+ "            i += 1\n",
+ "\n",
+ "    # save remaining\n",
+ "    with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "        f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
+ "\n",
+ "save_samples(\"train\", \"training_squad\")\n",
+ "save_samples(\"validation\", \"test_squad\")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "67044d13",
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "source": [
+ "### Testing\n",
+ "If we load a file, we should get a file with 1,000 lines and 4 columns.\n",
+ "\n",
+ "Also, we want to ensure that the answer interval is correct. Hence the second test."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "446281cf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"data/training_squad/text_0.txt\", 'r', encoding='utf-8') as f:\n",
+ "    lines = f.read().split('\\n')\n",
+ "\n",
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ccd5c650",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert lines.shape==(1000,4)\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2c9e4b70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we assert that we have the right interval\n",
+ "for ind, line in lines.iterrows():\n",
+ "    sample = line\n",
+ "    answer_start = int(sample['answer_start'])\n",
+ "    assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02265ace",
+ "metadata": {},
+ "source": [
+ "## Natural Questions Dataset\n",
+ "* Download from https://ai.google.com/research/NaturalQuestions via gsutil (the version on Hugging Face is 134.92 GB; the one from Google Cloud comes in archives) - see the sketch below\n",
+ "* Use gunzip to get some samples - we then get `.jsonl` files\n",
+ "* The dataset is much messier, as it is just Wikipedia articles with all web artifacts\n",
+ "  * I cleaned the HTML tags\n",
+ "  * Also, I chose a random interval (containing the answer) from each document\n",
+ "  * We can't send the whole text into the model anyway"
+ ]
+ },
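+ {
+ "cell_type": "markdown",
+ "id": "nq-download-note",
+ "metadata": {},
+ "source": [
+ "A minimal download sketch for one training shard. The bucket path and shard name below are assumptions based on the Natural Questions download instructions - verify them at https://ai.google.com/research/NaturalQuestions before running, and adjust the number of shards to your disk budget:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "nq-download-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hypothetical paths - uncomment to download and unpack a single shard\n",
+ "# !gsutil -m cp gs://natural_questions/v1.0/train/nq-train-00.jsonl.gz data/natural_questions/v1.0/train/\n",
+ "# !gunzip data/natural_questions/v1.0/train/nq-train-00.jsonl.gz"
+ ]
+ },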
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f3bce0c1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "paths = [str(x) for x in Path('data/natural_questions/v1.0/train/').glob('**/*.jsonl')]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e9c58c00",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdir(\"data/natural_questions_train\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0ed7ba6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "# clean html tags\n",
+ "CLEANR = re.compile('<.+?>')\n",
+ "# clean multiple spaces\n",
+ "CLEANMULTSPACE = re.compile('(\\s)+')\n",
+ "\n",
+ "# the function takes an html document and removes artifacts\n",
+ "def cleanhtml(raw_html):\n",
+ "    # tags\n",
+ "    cleantext = re.sub(CLEANR, '', raw_html)\n",
+ "    # newlines\n",
+ "    cleantext = cleantext.replace(\"\\n\", '')\n",
+ "    # tabs\n",
+ "    cleantext = cleantext.replace(\"\\t\", '')\n",
+ "    # character encodings\n",
+ "    cleantext = cleantext.replace(\"&#39;\", \"'\")\n",
+ "    cleantext = cleantext.replace(\"&amp;\", \"&\")\n",
+ "    cleantext = cleantext.replace(\"&quot;\", '\"')\n",
+ "    # multiple spaces\n",
+ "    cleantext = re.sub(CLEANMULTSPACE, ' ', cleantext)\n",
+ "    # documents end with this tag; if it is present in the string, cut it off\n",
+ "    idx = cleantext.find(\"<!-- NewPP limit\")\n",
+ "    if idx > -1:\n",
+ "        cleantext = cleantext[:idx]\n",
+ "    return cleantext.strip()"
+ ]
+ },
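+ {
+ "cell_type": "markdown",
+ "id": "cleanhtml-demo-note",
+ "metadata": {},
+ "source": [
+ "A quick sanity check of `cleanhtml` on a small made-up snippet (the input is illustrative and not taken from the dataset):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cleanhtml-demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample_html = \"<p> Brown University&#39;s  &quot;OpenWebText&quot; corpus </p>\"\n",
+ "# expected: Brown University's \"OpenWebText\" corpus\n",
+ "cleanhtml(sample_html)"
+ ]
+ },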
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "66ca19ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# file count\n",
+ "i = 0\n",
+ "data = []\n",
+ "\n",
+ "# iterate over all json files\n",
+ "for path in paths:\n",
+ "    print(path)\n",
+ "    # read the file and store it as a list (this requires a lot of memory, as the files are huge)\n",
+ "    with open(path, 'r') as json_file:\n",
+ "        json_list = list(json_file)\n",
+ "\n",
+ "    # process every context, question, answer pair\n",
+ "    for json_str in json_list:\n",
+ "        result = json.loads(json_str)\n",
+ "\n",
+ "        # append a question mark - SQuAD questions end with one too\n",
+ "        question = result['question_text'] + \"?\"\n",
+ "\n",
+ "        # some questions do not contain an answer - we do not need them\n",
+ "        if(len(result['annotations'][0]['short_answers'])==0):\n",
+ "            continue\n",
+ "\n",
+ "        # get true start/end byte\n",
+ "        true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
+ "        true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
+ "\n",
+ "        # convert to bytes\n",
+ "        byte_encoding = bytes(result['document_html'], encoding='utf-8')\n",
+ "\n",
+ "        # the document is the whole Wikipedia article, we randomly choose an appropriate part (containing the\n",
+ "        # answer): we have 512 tokens as the input for the model - 4000 bytes lead to a good length\n",
+ "        max_back = 3500 if true_start >= 3500 else true_start\n",
+ "        first = random.randint(int(true_start)-max_back, int(true_start))\n",
+ "        end = first + 3500 + true_end - true_start\n",
+ "\n",
+ "        # get chosen context\n",
+ "        cleanbytes = byte_encoding[first:end]\n",
+ "        # decode back to text - if our end byte is in the middle of a word, we ignore it and cut it off\n",
+ "        cleantext = bytes.decode(cleanbytes, errors='ignore')\n",
+ "        # clean html tags\n",
+ "        cleantext = cleanhtml(cleantext)\n",
+ "\n",
+ "        # find the true answer\n",
+ "        answer_start = cleanbytes.find(byte_encoding[true_start:true_end])\n",
+ "        true_answer = bytes.decode(cleanbytes[answer_start:answer_start+(true_end-true_start)])\n",
+ "\n",
+ "        # clean html tags\n",
+ "        true_answer = cleanhtml(true_answer)\n",
+ "\n",
+ "        start_ind = cleantext.find(true_answer)\n",
+ "\n",
+ "        # if cleaning the string makes the answer unfindable, skip it\n",
+ "        # this hardly ever happens, except if there is an immense amount of web artifacts\n",
+ "        if start_ind == -1:\n",
+ "            continue\n",
+ "\n",
+ "        data.append([cleantext, question, true_answer, str(start_ind)])\n",
+ "\n",
+ "        if len(data) == 1000:\n",
+ "            with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "                f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))\n",
+ "            i += 1\n",
+ "            data = []\n",
+ "with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "    f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "30f26b4e",
+ "metadata": {},
+ "source": [
+ "### Testing\n",
+ "In the following, we first check if the shape of the file is correct.\n",
+ "\n",
+ "Then we iterate over the file and check that the answers it contains match those in the original data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "490ac0db",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"data/natural_questions_train/text_0.txt\", 'r', encoding='utf-8') as f:\n",
+ "    lines = f.read().split('\\n')\n",
+ "\n",
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0d7cc3ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert lines.shape == (1000, 4)\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0fd8a854",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"data/natural_questions/v1.0/train/nq-train-00.jsonl\", 'r') as json_file:\n",
+ "    json_list = list(json_file)[:500]\n",
+ "del json_file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "170bff30",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lines_index = 0\n",
+ "for i in range(len(json_list)):\n",
+ "    result = json.loads(json_list[i])\n",
+ "\n",
+ "    if(len(result['annotations'][0]['short_answers'])==0):\n",
+ "        pass\n",
+ "    else:\n",
+ "        # assert that the question text is the same\n",
+ "        assert result['question_text'] + \"?\" == lines.loc[lines_index, 'question']\n",
+ "        true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
+ "        true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
+ "        true_answer = bytes.decode(bytes(result['document_html'], encoding='utf-8')[true_start:true_end])\n",
+ "\n",
+ "        processed_answer = lines.loc[lines_index, 'answer']\n",
+ "        # assert that the answer is the same\n",
+ "        assert cleanhtml(true_answer) == processed_answer\n",
+ "\n",
+ "        start_ind = int(lines.loc[lines_index, 'answer_start'])\n",
+ "        # assert that the answer (according to the index) is the same\n",
+ "        assert cleanhtml(true_answer) == lines.loc[lines_index, 'context'][start_ind:start_ind+len(processed_answer)]\n",
+ "\n",
+ "        lines_index += 1\n",
+ "\n",
+ "    if lines_index == len(lines):\n",
+ "        break\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "78e6e737",
+ "metadata": {},
+ "source": [
+ "## Hotpot QA\n",
+ "In the following, we download the HotpotQA dataset (fullwiki setting) from Hugging Face (https://huggingface.co/datasets/hotpot_qa) and store it in the same tab-separated format as SQuAD."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "27efcc8c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds = load_dataset(\"hotpot_qa\", 'fullwiki')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1493f21f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a047946",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdir('data/hotpotqa_training')\n",
+ "os.mkdir('data/hotpotqa_test')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e65b6485",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# column contains the split (either train or validation), save_dir is the directory\n",
+ "def save_samples(column, save_dir):\n",
+ "    text = []\n",
+ "    i = 0\n",
+ "\n",
+ "    for sample in tqdm(ds[column]):\n",
+ "\n",
+ "        # preprocess the context and question by removing the newlines\n",
+ "        context = sample['context']['sentences']\n",
+ "        context = \" \".join([\"\".join(sentence) for sentence in context])\n",
+ "        question = sample['question'].replace('\\n','')\n",
+ "\n",
+ "        # get the answer as text and start character index\n",
+ "        answer_text = sample['answer']\n",
+ "        answer_start = context.find(answer_text)\n",
+ "        if answer_start == -1:\n",
+ "            continue\n",
+ "\n",
+ "        # if the answer sits deep in the concatenated context, cut out a window of roughly 1500 characters that still contains it\n",
+ "        if answer_start > 1500:\n",
+ "            first = random.randint(answer_start-1500, answer_start)\n",
+ "            end = first + 1500 + len(answer_text)\n",
+ "\n",
+ "            context = context[first:end+1]\n",
+ "            answer_start = context.find(answer_text)\n",
+ "\n",
+ "            if answer_start == -1: continue\n",
+ "\n",
+ "        text.append([context, question, answer_text, str(answer_start)])\n",
+ "\n",
+ "        # we choose chunks of 1000\n",
+ "        if len(text) == 1000:\n",
+ "            with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "                f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
+ "            text = []\n",
+ "            i += 1\n",
+ "\n",
+ "    # save remaining\n",
+ "    with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
+ "        f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
+ "\n",
+ "save_samples(\"train\", \"hotpotqa_training\")\n",
+ "save_samples(\"validation\", \"hotpotqa_test\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97cc358f",
+ "metadata": {},
+ "source": [
+ "### Testing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f321483c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"data/hotpotqa_training/text_0.txt\", 'r', encoding='utf-8') as f:\n",
+ "    lines = f.read().split('\\n')\n",
+ "\n",
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "72a96e78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert lines.shape == (1000, 4)\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c32c2f16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we assert that we have the right interval\n",
+ "for ind, line in lines.iterrows():\n",
+ "    sample = line\n",
+ "    answer_start = int(sample['answer_start'])\n",
+ "    assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
+ "print(\"Passed\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc36fe7d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.16"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }