nospace

- norwegian_byt5_ns_base.gin +29 -0
- tasks.py +47 -0
- train_byt5_ns_base.sh +9 -0
norwegian_byt5_ns_base.gin
ADDED
@@ -0,0 +1,29 @@
+include 't5x/examples/t5/byt5/base.gin'
+include 'pretrain_cont.gin'
+#include 't5x/configs/runs/pretrain.gin'
+#include 't5x/configs/runs/finetune.gin'
+
+
+# Register necessary SeqIO Tasks/Mixtures.
+import t5.data.mixtures
+import tasks
+
+MIXTURE_OR_TASK_NAME = "byt5_ns_ncc_english_span_corruption_stream"
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+TRAIN_STEPS = 1_500_000
+DROPOUT_RATE = 0.0  # Changed from the default since T5-1.1 recommends this.
+INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/byt5/base/model.ckpt-1000000"
+PjitPartitioner.num_partitions = 1
+
+# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
+# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
+# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
+# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
+
+# The instructions above are from T5X. Since we are continuing from the Mesh TensorFlow ByT5 checkpoint, the factor must be set accordingly.
+LOSS_NORMALIZING_FACTOR = 193536
+
+
+
+
+
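For reference, the factor follows directly from the ByT5 numbers quoted in the comment above, pretraining batch size 1024 times mean target length 189 bytes; a one-line check (not part of the commit):

# LOSS_NORMALIZING_FACTOR for ByT5 = pretraining batch_size * target_token_length
assert 1024 * 189 == 193536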
tasks.py
CHANGED
@@ -50,6 +50,25 @@ def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
     )


+def gen_dataset_ns(split, shuffle=False, seed=None, column="text", dataset_params=None):
+    dataset = load_dataset(**dataset_params)
+    if shuffle:
+        if seed:
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            dataset = dataset.shuffle()
+    while True:
+        for item in dataset[str(split)]:
+            yield item[column].replace(" ", "")
+
+
+def dataset_fn_ns(split, shuffle_files, seed=None, dataset_params=None):
+    return tf.data.Dataset.from_generator(
+        functools.partial(gen_dataset_ns, split, shuffle_files, seed, dataset_params=dataset_params),
+        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
+    )
+
+
 @utils.map_over_dataset
 def target_to_key(x, key_map, target_key):
     """Assign the value from the dataset to target_key in key_map"""
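The only functional difference from the existing `dataset_fn` pipeline above is the `.replace(" ", "")` in `gen_dataset_ns`: all spaces are stripped before tokenization, so the byte-level model has to learn word segmentation on its own. A minimal sketch of the effect (hypothetical input text, not from the dataset):

# What gen_dataset_ns yields for a record whose "text" column holds this string:
text = "Dette er en norsk setning."
print(text.replace(" ", ""))  # -> Detteerennorsksetning.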
@@ -192,6 +211,34 @@ TaskRegistry.add(
     metric_fns=[]
 )

+# Final pretraining task used in Raffel et al., 2019, adapted to NCC.
+# No-space training.
+dataset_name = 'NbAiLab/NCC_plus_english'
+dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
+dataset_shapes = None
+TaskRegistry.add(
+    "byt5_ns_ncc_english_span_corruption_stream",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn_ns, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        # seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": BYT5_DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[]
+)
+
 # Final pretraining task used in Raffel et al., 2019, adapted to NCC.
 dataset_name = 'NbAiLab/NCC_plus_english'
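Once tasks.py registers the task, it can be smoke-tested outside of training. A minimal sketch, assuming tasks.py is on PYTHONPATH and the NCC dataset is accessible with your Hugging Face token:

import seqio
import tasks  # noqa: F401  # registers "byt5_ns_ncc_english_span_corruption_stream"

task = seqio.get_mixture_or_task("byt5_ns_ncc_english_span_corruption_stream")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="validation",
    shuffle=False,
)
print(next(iter(ds.as_numpy_iterator())))  # one span-corrupted example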
train_byt5_ns_base.sh
ADDED
@@ -0,0 +1,9 @@
+PROJECT_DIR=${HOME}"/models/pk-nb-t5x"
+T5X_DIR="../../t5x"  # directory where the t5x repo is cloned
+MODEL_DIR="gs://t5x-training/pretrained_models/norwegian_NCC_plus_English_byt5x_ns_base"
+export PYTHONPATH=${PROJECT_DIR}
+
+python3 ${T5X_DIR}/t5x/train.py \
+  --gin_search_paths=${PROJECT_DIR} \
+  --gin_file="norwegian_byt5_ns_base.gin" \
+  --gin.MODEL_DIR="'${MODEL_DIR}'" \
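Note the nested quoting in the final flag: --gin.MODEL_DIR sets a gin string binding, so after shell expansion the value must still carry its own single quotes for gin's parser; without them gin would try to evaluate the bucket path as a gin expression.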
|