pretrain_cont

Files changed:
- __pycache__/tasks.cpython-38.pyc +0 -0
- continue_pretraining_base.gin +6 -6
- pretrain_cont.gin +111 -0
- tasks.py +47 -16
- tasks_old.py +102 -0
- train_base.sh +1 -1
__pycache__/tasks.cpython-38.pyc CHANGED

Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
continue_pretraining_base.gin CHANGED

@@ -1,15 +1,15 @@
 include 't5x/examples/t5/mt5/base.gin'
 include 't5x/configs/runs/pretrain.gin'
+include 't5x/configs/runs/finetune.gin'

 # Register necessary SeqIO Tasks/Mixtures.
 import t5.data.mixtures
 import tasks

-MIXTURE_OR_TASK_NAME = "
-TASK_FEATURE_LENGTHS = {"inputs":
-TRAIN_STEPS =
-DROPOUT_RATE = 0.0
+MIXTURE_OR_TASK_NAME = "ncc_span_corruption_stream"
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+TRAIN_STEPS = 1_100_000
+DROPOUT_RATE = 0.0  # Changed from the default since T5-1.1 recommends this.
 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000"
+PjitPartitioner.num_partitions = 2

-#Batch size should be the default for pretraining
-#BATCH_SIZE = 256
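Note on the new values: TRAIN_STEPS in this setup is cumulative ("include pretrain steps", per pretrain_cont.gin below), and the restored mT5 checkpoint is already at step 1,000,000, so TRAIN_STEPS = 1_100_000 continues training for roughly 100k additional steps. A minimal Python sketch of that arithmetic, assuming the BATCH_SIZE = 128 default from pretrain_cont.gin (config values, not measurements):

# Sketch: TRAIN_STEPS counts total steps, including those already in the checkpoint.
CHECKPOINT_STEP = 1_000_000   # step baked into INITIAL_CHECKPOINT_PATH
TRAIN_STEPS = 1_100_000
BATCH_SIZE = 128              # pretrain_cont.gin default; commonly overridden
INPUT_LEN = 512               # TASK_FEATURE_LENGTHS["inputs"]

extra_steps = TRAIN_STEPS - CHECKPOINT_STEP          # 100_000 additional steps
input_tokens = extra_steps * BATCH_SIZE * INPUT_LEN
print(f"{extra_steps:,} steps, ~{input_tokens / 1e9:.1f}B input tokens")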
pretrain_cont.gin ADDED

@@ -0,0 +1,111 @@
# Defaults for pretraining with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS - include pretrain steps
# - MODEL_DIR: automatically set when using xm_launch
#
# Commonly overridden options:
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer

MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
BATCH_SIZE = 128
USE_CACHED_TASKS = True
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None
SHUFFLE_TRAIN_EXAMPLES = True

# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different submeshes.
USE_HARDWARE_RNG = False
# None always uses the faster, hardware RNG.
RANDOM_SEED = None

# Can be overridden with `train.*`.
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE

utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()
utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'
utils.SaveCheckpointConfig:
  period = 1000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state

trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()

utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 0.5  # Set to half of the original rate, since this is continued training.
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
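For intuition about the schedule: in t5x, utils.create_learning_rate_scheduler with factors = 'constant * rsqrt_decay' yields base_learning_rate / sqrt(max(step, warmup_steps)), i.e. a rate held flat through warmup and then decaying with the inverse square root of the step. A rough Python sketch for illustration, not the t5x source:

import math

def lr(step, base_learning_rate=0.5, warmup_steps=10_000):
    # 'constant' contributes the base rate; 'rsqrt_decay' contributes
    # 1 / sqrt(max(step, warmup_steps)).
    return base_learning_rate / math.sqrt(max(step, warmup_steps))

print(lr(10_000))     # 0.005 at the end of warmup
print(lr(1_000_000))  # 0.0005 around the step where continued training resumes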
tasks.py CHANGED

@@ -11,15 +11,18 @@ from seqio import FunctionDataSource, utils

 TaskRegistry = seqio.TaskRegistry

-vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+
+

 DEFAULT_OUTPUT_FEATURES = {
     "inputs": seqio.Feature(
-        vocabulary=vocabulary, add_eos=True,
+        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,
         required=False),
     "targets": seqio.Feature(
-        vocabulary=vocabulary, add_eos=True)
+        vocabulary=t5.data.get_default_vocabulary(), add_eos=True)
 }
+# Custom vocabs can also be defined and loaded
+# vocabulary = seqio.SentencePieceVocabulary("gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model")


 def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):

@@ -48,15 +51,15 @@ def target_to_key(x, key_map, target_key):


 # Final pretraining task used in Raffel et al., 2019 adapted to NCC
-dataset_name = 'NbAiLab/NCC_small'
-dataset_params = {"path": dataset_name, "use_auth_token": True}
+dataset_name = 'NbAiLab/NCC'
+dataset_params = {"path": dataset_name}
 dataset_shapes = {'train': 20830348, 'validation': 473079}
 TaskRegistry.add(
-    'span_corruption',
+    "ncc_span_corruption",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
-        #caching_permitted=True,
+        caching_permitted=False,
         num_input_examples=dataset_shapes,
     ),
     preprocessors=[

@@ -66,24 +69,52 @@ TaskRegistry.add(
             "targets": None,
         }, target_key="targets"),
         seqio.preprocessors.tokenize,
-        #seqio.CacheDatasetPlaceholder(),
+        # seqio.CacheDatasetPlaceholder(),
         preprocessors.span_corruption,
         seqio.preprocessors.append_eos_after_trim,
     ],
-    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    output_features={"targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)},
     metric_fns=[]
 )

-# Final pretraining task used in Raffel et al., 2019 adapted to nbailab_extended
-dataset_name = 'NbAiLab/nbailab_extended'
+# Final pretraining task used in Raffel et al., 2019 adapted to NCC
+dataset_name = 'NbAiLab/NCC_amall'
+dataset_params = {"path": dataset_name}
+dataset_shapes = {'train': 20830348, 'validation': 473079}
+TaskRegistry.add(
+    "ncc_small_span_corruption",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        # seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)},
+    metric_fns=[]
+)
+
+
+# Final pretraining task used in Raffel et al., 2019 adapted to NCC
+dataset_name = 'NbAiLab/NCC'
 dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
 dataset_shapes = None
 TaskRegistry.add(
-    'span_corrpution_stream',
+    "NCC_span_corruption_stream",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
-        caching_permitted=True,
+        caching_permitted=False,
         num_input_examples=dataset_shapes,
     ),
     preprocessors=[

@@ -93,10 +124,10 @@ TaskRegistry.add(
             "targets": None,
         }, target_key="targets"),
         seqio.preprocessors.tokenize,
-        seqio.CacheDatasetPlaceholder(),
+        # seqio.CacheDatasetPlaceholder(),
         preprocessors.span_corruption,
         seqio.preprocessors.append_eos_after_trim,
     ],
-    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    output_features={"targets": seqio.Feature(vocabulary=vocabulary, add_eos=True)},
     metric_fns=[]
 )
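The diff swaps the explicit SentencePiece vocabulary for t5.data.get_default_vocabulary() in DEFAULT_OUTPUT_FEATURES, while the tasks' output_features still refer to a module-level vocabulary. A minimal sketch of both constructions, assuming access to the public t5-data bucket; binding vocabulary explicitly at the end is an illustration, not part of the commit:

import seqio
import t5.data

# Default T5 vocabulary, as now used in DEFAULT_OUTPUT_FEATURES.
default_vocab = t5.data.get_default_vocabulary()

# Explicit SentencePiece vocabulary, as in the commented-out custom option;
# this mC4 model is the one the mT5 checkpoints were trained with.
mc4_vocab = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model")

vocabulary = mc4_vocab  # the name the tasks' output_features refer to
feature = seqio.Feature(vocabulary=vocabulary, add_eos=True)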
tasks_old.py ADDED

@@ -0,0 +1,102 @@
import functools

import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

TaskRegistry = seqio.TaskRegistry

vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)

DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    dataset = load_dataset(**dataset_params)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    return {**key_map, target_key: x}


# Final pretraining task used in Raffel et al., 2019 adapted to NCC
dataset_name = 'NbAiLab/NCC_small'
dataset_params = {"path": dataset_name, "use_auth_token": True}
dataset_shapes = {'train': 20830348, 'validation': 473079}
TaskRegistry.add(
    'span_corruption',
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
        splits=("train", "validation"),
        #caching_permitted=True,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        #seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[]
)

# Final pretraining task used in Raffel et al., 2019 adapted to nbailab_extended
dataset_name = 'NbAiLab/nbailab_extended'
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
dataset_shapes = None
TaskRegistry.add(
    'span_corrpution_stream',
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
        splits=("train", "validation"),
        caching_permitted=True,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[]
)
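In tasks_old.py (as in the new tasks.py), gen_dataset wraps a Hugging Face datasets object in an endless text generator, and dataset_fn exposes it to SeqIO as a tf.data.Dataset of raw strings; the registered task then tokenizes and span-corrupts that stream. A usage sketch, assuming the module is importable on PYTHONPATH and the NbAiLab datasets are accessible:

import seqio
import tasks_old  # registering the tasks is a side effect of the import

ds = seqio.get_mixture_or_task("span_corruption").get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="train",
    shuffle=True,
)
for example in ds.take(1):
    # Token IDs after span corruption, ready for packing and batching.
    print({k: v.shape for k, v in example.items()})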
train_base.sh CHANGED

@@ -1,6 +1,6 @@
 PROJECT_DIR=${HOME}"/models/pk-nb-t5x"
 T5X_DIR="../../t5x"  # directory where t5x is cloned
-MODEL_DIR="gs://nb-t5x/
+MODEL_DIR="gs://nb-t5x-us-central2/pk_nb_t5x_base_test1"
 export PYTHONPATH=${PROJECT_DIR}

 python3 ${T5X_DIR}/t5x/train.py \