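"""seqio task registrations for T5-style span-corruption pretraining.

Streams text from the Hugging Face datasets NbAiLab/NCC and
NbAiLab/nbailab_extended and applies the span-corruption objective
from Raffel et al., 2019.
"""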
import functools

import seqio
import t5.data
import tensorflow as tf
from datasets import load_dataset
from seqio import utils
from t5.data import preprocessors
TaskRegistry = seqio.TaskRegistry

# Standard T5 features; the tasks below only declare "targets", since
# span corruption derives the model inputs from the target text.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True),
}
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    """Yield raw text from a Hugging Face dataset, repeating indefinitely."""
    dataset = load_dataset(**dataset_params)
    if shuffle:
        # seed=None lets `datasets` pick its own seed (also covers seed=0,
        # which the original truthiness check would have silently dropped).
        dataset = dataset.shuffle(seed=seed)
    # Repeat forever: pretraining runs for a fixed number of steps,
    # not a fixed number of epochs.
    while True:
        for item in dataset[str(split)]:
            yield item[column]
def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    """Wrap the Python generator as a tf.data.Dataset of scalar strings."""
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed,
                          dataset_params=dataset_params),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string),
    )
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map."""
    return {**key_map, target_key: x}
# Final pretraining task from Raffel et al., 2019, adapted to the
# Norwegian Colossal Corpus (NCC).
dataset_name = 'NbAiLab/NCC'
dataset_params = {"path": dataset_name}
# Per-split example counts, reported to seqio for bookkeeping.
dataset_shapes = {'train': 20830348, 'validation': 473079}
TaskRegistry.add(
    'span_corruption',
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
        splits=("train", "validation"),
        caching_permitted=True,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Wrap each raw string as {"inputs": None, "targets": <text>}.
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        # Mask random spans of "targets" and build the corresponding
        # model inputs (the T5 denoising objective).
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[],
)
# The same pretraining task, adapted to nbailab_extended, which is
# accessed with authentication and read in streaming mode.
dataset_name = 'NbAiLab/nbailab_extended'
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
# Streaming datasets do not expose example counts up front.
dataset_shapes = None
TaskRegistry.add(
    'span_corruption_extended',  # registry names must be unique
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
        splits=("train", "validation"),
        caching_permitted=True,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[],
)
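
# Minimal usage sketch (illustrative only, not part of the training setup):
# pull the registered NCC task back out of the registry and materialize one
# preprocessed example. The sequence lengths below are placeholder values,
# not the ones used in any actual pretraining run.
if __name__ == "__main__":
    task = seqio.get_mixture_or_task("span_corruption")
    ds = task.get_dataset(
        sequence_length={"inputs": 512, "targets": 114},
        split="validation",
        shuffle=False,
    )
    for example in ds.take(1):
        print({k: v.shape for k, v in example.items()})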