pere committed
Commit ecfe1a8 · 1 Parent(s): edfac96

pretrain_cont
__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
continue_pretraining_base.gin CHANGED
@@ -1,15 +1,15 @@
 include 't5x/examples/t5/mt5/base.gin'
 include 't5x/configs/runs/pretrain.gin'
+include 't5x/configs/runs/finetune.gin'
 
 # Register necessary SeqIO Tasks/Mixtures.
 import t5.data.mixtures
 import tasks
 
-MIXTURE_OR_TASK_NAME = "span_corruption"
-TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
-TRAIN_STEPS = 1_001_000
-DROPOUT_RATE = 0.0
+MIXTURE_OR_TASK_NAME = "ncc_span_corruption_stream"
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+TRAIN_STEPS = 1_100_000
+DROPOUT_RATE = 0.0  # Changed from the default since T5-1.1 recommends this.
 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000"
+PjitPartitioner.num_partitions = 2
 
-#Batch size should be the default for pretraining
-#BATCH_SIZE = 256
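The renamed mixture only works if the string bound in gin exactly matches a task registered in tasks.py (SeqIO names are case-sensitive). A minimal pre-flight sketch, not part of the commit, assuming tasks.py is importable on PYTHONPATH:

# Hypothetical check: confirm the gin binding resolves to a registered SeqIO
# task before launching a full t5x run.
import seqio
import tasks  # noqa: F401 -- registering the NCC tasks is a side effect

MIXTURE_OR_TASK_NAME = "ncc_span_corruption_stream"  # value bound in the gin file
assert MIXTURE_OR_TASK_NAME in seqio.TaskRegistry.names(), (
    f"{MIXTURE_OR_TASK_NAME} is not registered; check tasks.py"
)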
pretrain_cont.gin ADDED
@@ -0,0 +1,111 @@
+# Defaults for pretraining with train.py.
+#
+#
+# You must also include a binding for MODEL.
+#
+# Required to be set:
+#
+# - MIXTURE_OR_TASK_NAME
+# - TASK_FEATURE_LENGTHS
+# - TRAIN_STEPS - include pretrain steps
+# - MODEL_DIR: # automatically set when using xm_launch
+#
+# Commonly overridden options:
+#
+# - train/DatasetConfig.batch_size
+# - train_eval/DatasetConfig.batch_size
+# - PjitPartitioner.num_partitions
+# - Trainer.num_microbatches
+# - DROPOUT_RATE
+from __gin__ import dynamic_registration
+
+import __main__ as train_script
+from t5x import gin_utils
+from t5x import partitioning
+from t5x import utils
+from t5x import trainer
+
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = %gin.REQUIRED
+TRAIN_STEPS = %gin.REQUIRED
+MODEL_DIR = %gin.REQUIRED
+BATCH_SIZE = 128
+USE_CACHED_TASKS = True
+INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
+
+# DEPRECATED: Import this module in your gin file.
+MIXTURE_OR_TASK_MODULE = None
+SHUFFLE_TRAIN_EXAMPLES = True
+
+# HW RNG is faster than SW, but has limited determinism.
+# Most notably it is not deterministic across different
+# submeshes.
+USE_HARDWARE_RNG = False
+# None always uses the faster hardware RNG.
+RANDOM_SEED = None
+
+# Can be overridden with `train.*`.
+train_script.train:
+  model = %MODEL  # imported from separate gin file
+  model_dir = %MODEL_DIR
+  train_dataset_cfg = @train/utils.DatasetConfig()
+  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+  infer_eval_dataset_cfg = None
+  checkpoint_cfg = @utils.CheckpointConfig()
+  partitioner = @partitioning.PjitPartitioner()
+  trainer_cls = @trainer.Trainer
+  total_steps = %TRAIN_STEPS
+  eval_steps = 20
+  eval_period = 1000
+  random_seed = %RANDOM_SEED
+  use_hardware_rng = %USE_HARDWARE_RNG
+  summarize_config_fn = @gin_utils.summarize_gin_config
+
+partitioning.PjitPartitioner:
+  num_partitions = 1
+  model_parallel_submesh = None
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+train/utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  split = 'train'
+  batch_size = %BATCH_SIZE
+  shuffle = %SHUFFLE_TRAIN_EXAMPLES
+  seed = None  # use a new seed each run/restart
+  use_cached = %USE_CACHED_TASKS
+  pack = True
+  module = %MIXTURE_OR_TASK_MODULE
+
+train_eval/utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  split = 'validation'
+  batch_size = %BATCH_SIZE
+  shuffle = False
+  seed = 42
+  use_cached = %USE_CACHED_TASKS
+  pack = True
+  module = %MIXTURE_OR_TASK_MODULE
+
+utils.CheckpointConfig:
+  restore = @utils.RestoreCheckpointConfig()
+  save = @utils.SaveCheckpointConfig()
+utils.RestoreCheckpointConfig:
+  path = %INITIAL_CHECKPOINT_PATH
+  mode = 'specific'
+  dtype = 'float32'
+utils.SaveCheckpointConfig:
+  period = 1000
+  dtype = 'float32'
+  keep = None  # keep all checkpoints
+  save_dataset = False  # don't checkpoint dataset state
+
+trainer.Trainer:
+  num_microbatches = None
+  learning_rate_fn = @utils.create_learning_rate_scheduler()
+
+utils.create_learning_rate_scheduler:
+  factors = 'constant * rsqrt_decay'
+  base_learning_rate = 0.5  # Set to half of the original value since this is continued training.
+  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
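The schedule above pairs a constant factor with reciprocal square-root decay. A small worked sketch of what that yields when resuming from the 1,000,000-step checkpoint, assuming the usual T5X semantics where rsqrt_decay is 1/sqrt(max(step, warmup_steps)):

import math

BASE_LEARNING_RATE = 0.5   # from the gin file above
WARMUP_STEPS = 10_000      # from the gin file above

def learning_rate(step: int) -> float:
    # 'constant * rsqrt_decay': flat during warmup, then decays as 1/sqrt(step)
    return BASE_LEARNING_RATE / math.sqrt(max(step, WARMUP_STEPS))

print(learning_rate(1_000_000))  # 0.0005 at the restored checkpoint
print(learning_rate(1_100_000))  # ~0.000477 at the end of continued training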
tasks.py CHANGED
@@ -11,15 +11,18 @@ from seqio import FunctionDataSource, utils
 
 TaskRegistry = seqio.TaskRegistry
 
-vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+
+
 
 DEFAULT_OUTPUT_FEATURES = {
     "inputs": seqio.Feature(
-        vocabulary=vocabulary, add_eos=True,
+        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,
         required=False),
     "targets": seqio.Feature(
-        vocabulary=vocabulary, add_eos=True)
+        vocabulary=t5.data.get_default_vocabulary(), add_eos=True)
 }
+# Custom vocabs can also be defined and loaded
+# vocabulary = seqio.SentencePieceVocabulary("gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model")
 
 
 def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
@@ -48,15 +51,15 @@ def target_to_key(x, key_map, target_key):
 
 
 # Final pretraining task used in Raffel et al., 2019 adapted to NCC
-dataset_name = 'NbAiLab/NCC_small'
-dataset_params = {"path": dataset_name, "use_auth_token": True}
+dataset_name = 'NbAiLab/NCC'
+dataset_params = {"path": dataset_name}
 dataset_shapes = {'train': 20830348, 'validation': 473079}
 TaskRegistry.add(
-    'span_corruption',
+    "ncc_span_corruption",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
-        #caching_permitted=True,
+        caching_permitted=False,
         num_input_examples=dataset_shapes,
     ),
     preprocessors=[
@@ -66,24 +69,52 @@ TaskRegistry.add(
         "targets": None,
     }, target_key="targets"),
     seqio.preprocessors.tokenize,
-    #seqio.CacheDatasetPlaceholder(),
-    preprocessors.span_corruption,
+    # seqio.CacheDatasetPlaceholder(),
+    preprocessors.span_corruption,
     seqio.preprocessors.append_eos_after_trim,
     ],
-    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    output_features={"targets": seqio.Feature(vocabulary=t5.data.get_default_vocabulary(), add_eos=True)},
     metric_fns=[]
 )
 
-# Final pretraining task used in Raffel et al., 2019 adaptated to nbailab_extended
-dataset_name = 'NbAiLab/nbailab_extended'
+# Final pretraining task used in Raffel et al., 2019 adapted to NCC
+dataset_name = 'NbAiLab/NCC_small'
+dataset_params = {"path": dataset_name}
+dataset_shapes = {'train': 20830348, 'validation': 473079}
+TaskRegistry.add(
+    "ncc_small_span_corruption",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        # seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": seqio.Feature(vocabulary=t5.data.get_default_vocabulary(), add_eos=True)},
+    metric_fns=[]
+)
+
+
+# Final pretraining task used in Raffel et al., 2019 adapted to NCC
+dataset_name = 'NbAiLab/NCC'
 dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
 dataset_shapes = None
 TaskRegistry.add(
-    'span_corrpution_stream',
+    "ncc_span_corruption_stream",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
+        caching_permitted=False,
         num_input_examples=dataset_shapes,
     ),
     preprocessors=[
@@ -93,10 +124,10 @@ TaskRegistry.add(
         "targets": None,
     }, target_key="targets"),
     seqio.preprocessors.tokenize,
-    seqio.CacheDatasetPlaceholder(),
+    # seqio.CacheDatasetPlaceholder(),
     preprocessors.span_corruption,
     seqio.preprocessors.append_eos_after_trim,
     ],
-    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    output_features={"targets": seqio.Feature(vocabulary=t5.data.get_default_vocabulary(), add_eos=True)},
     metric_fns=[]
 )
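To see what the registered tasks emit, a short usage sketch (assumed workflow via the standard SeqIO API, requires access to the NbAiLab/NCC dataset) that pulls one preprocessed example with the same feature lengths as the gin config:

import seqio
import tasks  # noqa: F401 -- registers the NCC tasks above

task = seqio.get_mixture_or_task("ncc_span_corruption_stream")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},  # matches TASK_FEATURE_LENGTHS
    split="validation",
    shuffle=False,
)
for example in ds.take(1):
    # span corruption yields tokenized inputs/targets with sentinel spans
    print({k: v.shape for k, v in example.items()})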
tasks_old.py ADDED
@@ -0,0 +1,102 @@
+import functools
+
+import seqio
+import tensorflow as tf
+import t5.data
+from datasets import load_dataset
+from t5.data import postprocessors
+from t5.data import preprocessors
+from t5.evaluation import metrics
+from seqio import FunctionDataSource, utils
+
+TaskRegistry = seqio.TaskRegistry
+
+vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
+
+DEFAULT_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True,
+        required=False),
+    "targets": seqio.Feature(
+        vocabulary=vocabulary, add_eos=True)
+}
+
+
+def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
+    dataset = load_dataset(**dataset_params)
+    if shuffle:
+        if seed:
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            dataset = dataset.shuffle()
+    while True:
+        for item in dataset[str(split)]:
+            yield item[column]
+
+
+def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
+    return tf.data.Dataset.from_generator(
+        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params),
+        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
+    )
+
+
+@utils.map_over_dataset
+def target_to_key(x, key_map, target_key):
+    """Assign the value from the dataset to target_key in key_map"""
+    return {**key_map, target_key: x}
+
+
+# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
+dataset_name = 'NbAiLab/NCC_small'
+dataset_params = {"path": dataset_name, "use_auth_token": True}
+dataset_shapes = {'train': 20830348, 'validation': 473079}
+TaskRegistry.add(
+    'span_corruption',
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        #caching_permitted=True,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        #seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[]
+)
+
+# Final pretraining task used in Raffel et al., 2019 adaptated to nbailab_extended
+dataset_name = 'NbAiLab/nbailab_extended'
+dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
+dataset_shapes = None
+TaskRegistry.add(
+    'span_corrpution_stream',
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=True,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        seqio.CacheDatasetPlaceholder(),
+        preprocessors.span_corruption,
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
+    metric_fns=[]
+)
train_base.sh CHANGED
@@ -1,6 +1,6 @@
 PROJECT_DIR=${HOME}"/models/pk-nb-t5x"
 T5X_DIR="../../t5x"  # directory where the t5x is cloned.
-MODEL_DIR="gs://nb-t5x/pk_nb_t5x_base"
+MODEL_DIR="gs://nb-t5x-us-central2/pk_nb_t5x_base_test1"
 export PYTHONPATH=${PROJECT_DIR}
 
 python3 ${T5X_DIR}/t5x/train.py \