cadalano committed on
Commit c2b7da4 · verified · 1 Parent(s): d6ba08c

Upload 21 files

Files changed (21)
  1. .gitignore +116 -0
  2. CITATION.cff +30 -0
  3. CODEOWNERS +1 -0
  4. Dockerfile +15 -0
  5. GPTNeo_example_notebook.ipynb +0 -0
  6. LICENSE +21 -0
  7. README.md +383 -0
  8. configs.py +47 -0
  9. docker-compose.yml +67 -0
  10. encoders.py +28 -0
  11. export.py +14 -0
  12. inputs.py +384 -0
  13. main.py +257 -0
  14. model_fns.py +305 -0
  15. optimizers.py +176 -0
  16. requirements.txt +18 -0
  17. run_experiment.py +265 -0
  18. sample.py +218 -0
  19. tasks.py +116 -0
  20. test_models.py +180 -0
  21. utils.py +292 -0
.gitignore ADDED
@@ -0,0 +1,116 @@
1
+ # testing
2
+ .test/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+
62
+ # Flask stuff:
63
+ instance/
64
+ .webassets-cache
65
+
66
+ # Scrapy stuff:
67
+ .scrapy
68
+
69
+ # Sphinx documentation
70
+ docs/_build/
71
+
72
+ # PyBuilder
73
+ target/
74
+
75
+ # Jupyter Notebook
76
+ .ipynb_checkpoints
77
+
78
+ # pyenv
79
+ .python-version
80
+
81
+ # celery beat schedule file
82
+ celerybeat-schedule
83
+
84
+ # SageMath parsed files
85
+ *.sage.py
86
+
87
+ # Environments
88
+ .env
89
+ .venv
90
+ env/
91
+ venv/
92
+ ENV/
93
+ env.bak/
94
+ venv.bak/
95
+
96
+ # Spyder project settings
97
+ .spyderproject
98
+ .spyproject
99
+
100
+ # Rope project settings
101
+ .ropeproject
102
+
103
+ # mkdocs documentation
104
+ /site
105
+
106
+ # mypy
107
+ .mypy_cache/
108
+
109
+ logs/
110
+ *.log
111
+ test_*
112
+ test/
113
+ .vscode
114
+
115
+
116
+ run_configs/
CITATION.cff ADDED
@@ -0,0 +1,30 @@
1
+ # YAML 1.2
2
+ ---
3
+ authors:
4
+ - affiliation: EleutherAI
5
+ family-names: Black
6
+ given-names: Sid
7
+ - affiliation: EleutherAI
8
+ family-names: Gao
9
+ given-names: Leo
10
+ - affiliation: EleutherAI
11
+ family-names: Wang
12
+ given-names: Phil
13
+ - affiliation: EleutherAI
14
+ family-names: Leahy
15
+ given-names: Connor
16
+ - affiliation: EleutherAI
17
+ family-names: Biderman
18
+ given-names: Stella
19
+ cff-version: "1.1.0"
20
+ keywords:
21
+ - Transformers
22
+ - "Massive language model"
23
+ - "Autoregressive language model"
24
+ license: "Apache-2.0"
25
+ message: "If you use this software, please cite it using these metadata."
26
+ repository-code: "https://www.github.com/eleutherai/gpt-neo"
27
+ title: "GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow"
28
+ version: "1.0"
29
+ date-released: 2021-03-21
30
+ ...
CODEOWNERS ADDED
@@ -0,0 +1 @@
1
+ * EleutherAI/pm-gptneo
Dockerfile ADDED
@@ -0,0 +1,15 @@
1
+ FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15
2
+
3
+ WORKDIR /neogpt
4
+
5
+ # Make RUN commands use `bash --login`:
6
+ SHELL ["/bin/bash", "--login", "-c"]
7
+ ENV DEBIAN_FRONTEND=noninteractive
8
+ RUN apt-get update -y && apt-get install tmux -y
9
+ RUN conda install gcc_linux-64 gxx_linux-64 -y
10
+ ADD requirements.txt .
11
+ RUN pip install -r requirements.txt
12
+ RUN apt-get install screen htop -y
13
+ RUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15
14
+
15
+ CMD tmux
GPTNeo_example_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 EleutherAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,383 @@
1
+ # GPT Neo
2
+
3
+ 🎉 1T or bust my dudes 🎉
4
+
5
+ An implementation of model & data parallel [GPT3](https://arxiv.org/abs/2005.14165)-like models using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library.
6
+
7
+ **If you're just here to play with our pre-trained models, we strongly recommend you try out the [HuggingFace Transformer integration](https://huggingface.co/EleutherAI).**
8
+
9
+ Training and inference are officially supported on TPU and should work on GPU as well. This repository will be (mostly) archived as we move focus to our GPU-specific repo, [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).
10
+
11
+ In addition to the functionality offered by GPT-3, we also offer the following:
12
+ * [Local attention](https://arxiv.org/abs/2004.05150)
13
+ * [Linear attention](https://arxiv.org/abs/1812.01243)
14
+ * [Mixture of Experts](https://arxiv.org/abs/1701.06538)
15
+ * [Axial Positional embedding](https://arxiv.org/abs/1912.12180)
16
+
17
+ NB: while Neo can *technically* run a training step at 200B+ parameters, it is very inefficient at those scales. This, together with the fact that many GPUs became available to us, prompted us to move development over to [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).
18
+
19
+ # Pretrained Models
20
+
21
+ **Update 21/03/2021:**
22
+
23
+ We're proud to release two pretrained GPT-Neo models trained on The Pile; the weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/public/AI/gptneo-release/).
24
+
25
+ 1.3B: https://the-eye.eu/public/AI/gptneo-release/GPT3_XL/
26
+
27
+ 2.7B: https://the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/
28
+
29
+ For more information on how to get these set up, see the colab notebook, or read through the rest of the readme.
30
+
31
+ ## Model Evaluations
32
+
33
+ #### Linguistic Reasoning
34
+
35
+ | Model and Size | Pile BPB | Pile PPL | Wikitext PPL | Lambada PPL | Lambada Acc | Winogrande | Hellaswag |
36
+ | ---------------- | ---------- | ---------- | ------------- | ----------- | ----------- | ---------- | ----------- |
37
+ | **GPT-Neo 1.3B** | **0.7527** | **6.159** | **13.10** | **7.498** | **57.23%** | **55.01%** | **38.66%** |
38
+ | GPT-2 1.5B | 1.0468 | ----- | 17.48 | 10.634 | 51.21% | 59.40% | 40.03% |
39
+ | **GPT-Neo 2.7B** | **0.7165** | **5.646** | **11.39** | **5.626** | **62.22%** | **56.50%** | **42.73%** |
40
+ | GPT-3 Ada | 0.9631 | ----- | ----- | 9.954 | 51.60% | 52.90% | 35.93% |
41
+
42
+ #### Physical and Scientific Reasoning
43
+
44
+ | Model and Size | MathQA | PubMedQA | Piqa |
45
+ | ---------------- | ---------- | ---------- | ----------- |
46
+ | **GPT-Neo 1.3B** | **24.05%** | **54.40%** | **71.11%** |
47
+ | GPT-2 1.5B | 23.64% | 58.33% | 70.78% |
48
+ | **GPT-Neo 2.7B** | **24.72%** | **57.54%** | **72.14%** |
49
+ | GPT-3 Ada | 24.29% | 52.80% | 68.88% |
50
+
51
+ **Note:** All evaluations were done using our [evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). Some results for GPT-2 and GPT-3 are inconsistent with the values reported in the respective papers. We are currently looking into why, and would greatly appreciate feedback and further testing of our eval harness.
52
+
53
+ # Setup
54
+
55
+ ```bash
56
+ git clone https://github.com/EleutherAI/GPTNeo
57
+ cd GPTNeo
58
+ pip3 install -r requirements.txt
59
+ ```
60
+ # Training Setup
61
+
62
+ ## TPUs:
63
+
64
+ Sign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage).
65
+
66
+ Create your VM through a Google shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs, then install the requirements with pip (see above).
67
+
68
+ Google Colab provides TPU v2-8s for free, which should be enough to finetune our models up to GPT3XL (1.3B parameter) sizes.
69
+ Click [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EleutherAI/GPTNeo/blob/master/GPTNeo_example_notebook.ipynb) to run through our example colab notebook.
70
+
71
+ For more detailed instructions, run through our [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below.
72
+
73
+ ## GPUs:
74
+
75
+ You can also choose to train GPT-Neo locally on your GPUs. To do so, omit the Google Cloud setup steps above and clone the repo locally. Run through the [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below; then, when running main.py, simply omit the `tpu` flag and pass in GPU ids instead.
76
+
77
+ Note: Some users have reported having difficulty getting MTF to recognize their GPUs. See [here](https://github.com/EleutherAI/gpt-neo/issues/150) for details and instructions on how to fix it.
78
+
79
+ # Generating Text
80
+
81
+ Once you have a trained model, or you've downloaded one of our pre-trained models, generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so:
82
+
83
+ ```bash
84
+ python3 main.py --predict --prompt <example_prompt.txt> --tpu <tpu_name> --model <config_name>
85
+ ```
86
+
87
+ or, if using GPUs:
88
+
89
+ ```bash
90
+ python3 main.py --predict --prompt <example_prompt.txt> --gpu_ids <device:GPU:0 device:GPU:1> --model <config_name>
91
+ ```
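+
+ The prompt file is just a plain text file containing the text you want the model to continue. If no `--prompt` file is passed, the model defaults to the "unicorns" prompt hard-coded in `inputs.py`.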
92
+
93
+ # Training Guide
94
+
95
+ ## 1. Create your Tokenizer (OPTIONAL)
96
+
97
+ We recommend you use [Huggingface's pretrained GPT2 tokenizer](https://huggingface.co/transformers/model_doc/gpt2.html#transformers.GPT2Tokenizer) with our repo (instructions provided below), but if you want to train a model with a different vocabulary size, we provide facilities to train your own tokenizer like so:
98
+
99
+ ```bash
100
+ python data/train_tokenizer.py \
101
+ --base_dir ./path/to/your/txt/files \
102
+ --output_dir ./output/path \
103
+ --file_type txt \
104
+ --vocab_size 50257
105
+
106
+ # if it succeeded, you should see the message
107
+ # 'tokenizer saved at ./output/path/byte-level-bpe.tokenizer.json'
108
+ ```
109
+
110
+ ## 2. Tokenizing your Dataset
111
+
112
+ If you just want to test training, you can skip this step and download some dummy data like so:
113
+
114
+ ```
115
+ wget https://storage.googleapis.com/connors-datasets/bundestag/bundestag_0.tfrecords
116
+ ```
117
+
118
+ Then copy the data to your bucket, or if using GPUs, a local directory:
119
+
120
+ ```
121
+ gsutil cp bundestag_0.tfrecords gs://<your bucket>/
122
+ ```
123
+
124
+ If using your own data to train, you can use the `data/create_tfrecords.py` script to encode your text data into tfrecords.
125
+
126
+ Your data must either be in the form of lots of normal .txt files (one document per file), or in any format supported by [lm_dataformat](https://github.com/leogao2/lm_dataformat).
127
+
128
+ You can run the script without parameters to see help for all options.
129
+
130
+ In **document mode**, each example in the tfrecords is one (variably sized) document. This mode is to be used with the `documents_fixed` and `documents_random` sampling modes (for more details, see the parameter reference section).
131
+ Document mode is the default mode.
132
+
133
+ The command below will tokenize all files in acceptable formats in *input_dir* using the GPT2 tokenizer and save them to *output_dir* (a concrete example follows the flag descriptions below):
134
+ ```
135
+ python3 create_tfrecords.py --mode documents --input_dir <base> --name <name> --output_dir <output> --use_gpt2_tokenizer --minimum_size <min>
136
+ ```
137
+
138
+ - `input_dir`: Defines the folder where your data is located. The script will encode all files present in this folder.
139
+ - `name`: Name of the output files, which will be `name_i.tfrecords` where `i` is the number of the file.
140
+ - `output_dir`: Where to save the tfrecords to
141
+ - `use_gpt2_tokenizer`: Whether to use the pretrained HuggingFace GPT2 tokenizer, in which case the separator will be set to [50256].
142
+ - `encoder_path`: if not using the pretrained gpt2 tokenizer, use this flag to provide a path to your generated tokenizer json.
143
+ - `separator`: Written in list format, the separator token(s) to insert between documents (e.g. "[0]"). Will depend on your encoder.
144
+ - `minimum_size`: The minimum size (in tokens) a document must have, otherwise it is discarded. This is what will later determine your `stitch` parameter: `stitch * minimum_size` must always be greater than or equal to `n_ctx` (for more details, see the parameter reference section).
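+
+ For example (with hypothetical paths and values), tokenizing a folder of plain .txt files with the pretrained GPT2 tokenizer could look like: `python3 create_tfrecords.py --mode documents --input_dir ./my_txt_files --name mydata --output_dir ./tfrecords --use_gpt2_tokenizer --minimum_size 100`.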
145
+
146
+ ## 3. Using a Dataset in a Model
147
+
148
+ To use a dataset in a model, you must first register it in the `./configs/dataset_configs` folder. Choose a filename with a `.json` extension; that filename (minus the extension) will serve as the dataset id. The config should be filled out in the following manner.
149
+
150
+ If you have a dataset encoded using the pretrained gpt2 tokenizer, you can specify that like so:
151
+
152
+ ```json
153
+ {
154
+ "n_vocab": 50257,
155
+ "path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
156
+ "eval_path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
157
+ "tokenizer_is_pretrained": true,
158
+ "tokenizer_path": "gpt2"
159
+ }
160
+ ```
161
+
162
+ or if you've trained a custom tokenizer, like so:
163
+
164
+ ```json
165
+ {
166
+ "n_vocab": 32768,
167
+ "path": "./path/to/your/*.tfrecords",
168
+ "eval_path": "./path/to/your/eval/*.tfrecords",
169
+ "tokenizer_path": "./path/to/your/byte-level-bpe.tokenizer.json"
170
+ }
171
+ ```
172
+
173
+ Finally, in your model config, add the filename that you created above to the `datasets` array.
174
+
175
+ The `<dataset id>` is the filename, excluding the `.json` extension, that you created above:
176
+
177
+ ```
178
+ "datasets": [[<dataset id>, <stitch>, <datatype>, <weight>]] # datasets key defines at run time how each dataset is processed for training
179
+ ```
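+
+ For example, if you registered the dummy Bundestag data from step 2 as `bundestag.json`, the entry might look like `"datasets": [["bundestag", 10, "documents_random", 1.0]]`.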
180
+
181
+ ## 4. Choose a model configuration
182
+
183
+ Once you have your datasets set up, find a suitable config in `/configs`.
184
+
185
+ Here we use a GPT3-XL sized model as an example, but there are many more in `./configs`, all of which have short summaries in the Available Configs section.
186
+
187
+ All you need to do is edit the dataset id as described above, and edit `model_path` (where logs and checkpoints will be saved) to point to a cloud bucket you have write access to (or local path, if using GPUs).
188
+
189
+ ```json
190
+ {
191
+ "n_head": 32,
192
+ "n_vocab": 50257,
193
+ "embed_dropout": 0.1,
194
+ "lr": 0.0002,
195
+ "lr_decay": "cosine",
196
+ "warmup_steps": 3000,
197
+ "beta1": 0.9,
198
+ "beta2": 0.95,
199
+ "epsilon": 1e-8,
200
+ "opt_name": "adam",
201
+ "weight_decay": 0.1,
202
+ "train_batch_size": 512,
203
+ "attn_dropout": 0.1,
204
+ "train_steps": 286150,
205
+ "eval_steps": 0,
206
+ "predict_steps": 1,
207
+ "res_dropout": 0.1,
208
+ "eval_batch_size": 128,
209
+ "predict_batch_size": 1,
210
+ "iterations": 2500,
211
+ "n_embd": 2048,
212
+ "datasets": [["your_dataset_name", 25, "documents_random", 1.0]],
213
+ "model_path": "gs://neo-models/GPT3_XL",
214
+ "n_ctx": 2048,
215
+ "n_layer": 24,
216
+ "scale_by_depth": true,
217
+ "scale_by_in": false,
218
+ "attention_types" : [[["global"],24]],
219
+ "mesh_shape": "x:128,y:2",
220
+ "layout": "batch:x,memory_length:y,embd:y",
221
+ "activation_function": "gelu",
222
+ "recompute_grad": true,
223
+ "gradient_clipping": 1.0,
224
+ "tokens_per_mb_per_replica": 2048
225
+ }
226
+ ```
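+
+ Note that the `mesh_shape` in this example (`x:128,y:2`) assumes 256 cores; if you are running on fewer cores, scale `mesh_shape` (and possibly `layout`) down so that the product of its dimensions matches your number of processors (see the Parameter Reference below).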
227
+
228
+
229
+ ## 5. Run Training
230
+
231
+ ```
232
+ python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --tpu <tpu-name>
233
+ ```
234
+
235
+ - `tpu`: Name of the TPU to use.
236
+ - `steps_per_checkpoint`: The frequency in steps at which to save checkpoints.
237
+ - `--auto_layout` and `--auto_layout_and_mesh_shape` (Optional): Disable training and instead auto generate a memory efficient `layout` (and `mesh_shape`)
238
+ - `gpu_ids`: If training using GPUs, omit the `tpu` flag and pass in the ids of your GPUs. In the example below, we train on two GPUs, specifying their device ids delimited by spaces:
239
+
240
+ ```
241
+ python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --gpu_ids <device:GPU:0 device:GPU:1>
242
+ ```
243
+
244
+ # Available Configs
245
+
246
+ We have several model sizes available, but some of our configs require large TPUs and will need tweaking to run on smaller machines, or GPUs. Below is a short guide to each model in the configs directory:
247
+
248
+ TODO
249
+
250
+ # Extra Features:
251
+
252
+ ## Training (with Sacred)
253
+
254
+ [Sacred](https://github.com/IDSIA/sacred) helps track experiments and is much nicer to work with than tensorboard.
255
+
256
+ To setup:
257
+
258
+ 1. Install Docker and Docker-compose
259
+
260
+ 2. Run `docker-compose up`
261
+
262
+ To use:
263
+
264
+ 1. Ensure `model_dir` doesn't have any metric logs in it (they trip up the metric logging for tensorboard, which assumes it's a continuation of the existing run). You can use `gsutil rm -r ...` to delete the model dir.
265
+
266
+ 2. Run `python3 run_experiment.py --tpu sometpuhere --model someconfig.json`. Options are the same as for `main.py`.
267
+
268
+ 3. You can go to http://server_ip_goes_here:8081/ to see the Omniboard overview. If you prefer to see a tensorboard, the script also spins one up and automatically assigns it a port. The script should print out the tensorboard port near the top of the log.
269
+
270
+ ## Peeking at a Dataset
271
+
272
+ If you are ever confused by the dataset of a particular config file, you can easily check the minimum and maximum token ids with a single command. This is useful for making sure that the vocabulary size of the model is at least as large as the maximum token id. Tensorflow will not error if you try to gather on a matrix with out of bounds indices, so you need to make sure your vocabulary size is sufficiently large.
273
+
274
+ ```bash
275
+ python3 main.py --model <config_name> --check_dataset
276
+ ```
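+
+ If you want to do a similar sanity check yourself outside of main.py, the minimal sketch below (an illustrative helper, not part of this repo) scans one tfrecord and reports the minimum and maximum token ids, assuming the records store an int64 "text" feature as produced by `data/create_tfrecords.py`:
+
+ ```python
+ # Illustrative helper (not part of this repo): report the min/max token id in one tfrecord.
+ import tensorflow.compat.v1 as tf
+
+ def token_id_range(tfrecord_path):
+     # Iterate over raw records and track the global min/max token id.
+     min_id, max_id = None, None
+     for record in tf.io.tf_record_iterator(tfrecord_path):
+         example = tf.train.Example.FromString(record)
+         ids = example.features.feature["text"].int64_list.value
+         if not ids:
+             continue
+         lo, hi = min(ids), max(ids)
+         min_id = lo if min_id is None else min(min_id, lo)
+         max_id = hi if max_id is None else max(max_id, hi)
+     return min_id, max_id
+
+ # e.g. check the dummy data from the Training Guide
+ print(token_id_range("bundestag_0.tfrecords"))
+ ```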
277
+
278
+ ## Masked Language Modeling
279
+
280
+ In addition to being able to train large GPTs, this repository also allows you to easily do masked language modeling (BERT, RoBERTa). In order to do so, you must follow two additional steps.
281
+
282
+ 1. When tokenizing your dataset, you must reserve a special id for the `[mask]` token.
283
+
284
+ 2. In the configs, you will have to define two additional fields
285
+
286
+ ```python
287
+ "mlm_training": true, # must be set to true
288
+ "mlm_mask_id": <mask id> # the mask id that you reserved from above
289
+ ```
290
+
291
+ That's all you need to train a model with the MLM objective, good for any type of data that you have encoded properly. If you would like to tweak the other related hyperparameters, please continue reading.
292
+
293
+ ```python
294
+ "mlm_cls_token_id": <cls token id>, # auto append specified CLS token id on the left
295
+ "mlm_mask_prob": 0.15, # the probability of masking a token, defaults to 15%
296
+ "mlm_same_token_prob": 0.10, # probability of keeping the token the same, defaults to 10%
297
+ "mlm_random_token_prob": 0.10, # probability of tokens that are replaced with random tokens, 10% was recommended by the BERT paper
298
+ "mlm_mask_ignore_ids": [<cls token>, <sep token>] # ignore masking other special tokens, if any
299
+ ```
300
+
301
+ ## Parameter Reference
302
+
303
+ Pick a valid config from `/configs` and tweak the parameters as needed:
304
+
305
+ - `n_head`: The number of attention heads.
306
+ - `n_embd`: Size of the hidden layers, must be divisible by `n_head`.
307
+ - `n_vocab`: Vocabulary size.
308
+ - `embed_dropout`, `res_dropout`, `attn_dropout`: Dropout probability for word embedding/residuals/attention
309
+ - `lr`: Learning rate
310
+ - `warmup_steps`: Number of steps before full learning rate is reached (linear ramp from `0` to `lr`).
311
+ - `lr_decay`: `cosine` or `linear`.
312
+ - `opt_name`: `adam` or `adafactor`.
313
+ - `beta1`, `beta2` and `epsilon`: `adam` optimizer params.
314
+ - `beta1`, `ada_epsilon1` and `ada_epsilon2`: `adafactor` optimizer params.
315
+ - `weight_decay`: Weight decay parameter. If not present, no weight decay is used; when set, the AdamW-style weight decay fix is applied (default: 0.01) (optional).
316
+ - `train_batch_size`: Batch size during training.
317
+ - `train_steps`: Number of training steps (batches), set to roughly ~1 epoch for now (total number of tokens in your dataset / number of tokens per batch (= `train_batch_size` * `n_ctx`)). See the worked example after this list.
318
+ - `eval_steps`: Number of steps to run for each evaluation. Set to `0` for no eval, i.e. after every checkpoint, the model is tested for `eval_steps` steps.
319
+ - `iterations`: Number of steps queued to the TPU, must be smaller than `steps_per_checkpoint`. (default: 500)
320
+ - `datasets`: List of tfrecords datasets to use. Each dataset is a list with the following parameters: `[dataset_id, stitch, sampling_mode, weight]`. So, for example, for a single dataset (note the double list): `[["bundestag", 10, "documents_random", 1.0]]`
321
+ + `dataset_id`: The name of a dataset configuration file in `./configs/dataset_configs`
322
+ + `stitch`: If `sampling_mode` `random_sample` is used, the input pipeline samples this amount of texts into one to sample from. You must select stitch so that `stitch * minimum_document_length >= n_ctx`
323
+ + `sampling_mode`: `chunks` (tfrecords are preprocessed into the correct length and are read sequentially) or `documents_random` (`stitch` documents are concatenated and then an `n_ctx`-sized chunk is randomly subsampled)
324
+ + `weight`: How much relative weight this dataset should have compared to others
325
+ - `model`: Which model to train. Currently only `GPT` is supported, and it defaults to this if not present.
326
+ - `model_path`: Google storage bucket location (or local path, if using GPUs) to save model checkpoints and logs.
327
+ - `n_ctx`: Size of context window. Default is 2048
328
+ - `n_layer`: Number of layers (blocks) in the model.
329
+ - `scale_by_depth`: If true, the weight initialization of layers is scaled by their depth as in the GPT2 paper.
330
+ - `scale_by_in`: If true, the weight initialization of layers is scaled by their number of inputs as in the GPT2 paper.
331
+ - `mesh_shape`: A Mesh is an n-dimensional array of processors with named dimensions used for parallelism in the mesh-tensorflow library. Each Tensor is split evenly across mesh dimensions according to the layout (see below). The `mesh_shape` is the shape of this array, and the product of its dimensions must equal the number of processors, e.g. for a v3-128 TPU, "mesh_shape": "x:16,y:8".
332
+ - `layout`: A Tensor is laid out on its mesh with one slice on each processor. A Tensor "layout" is an injective partial map specifying which dimensions of the tensor are (evenly) split across which dimensions of the mesh. No dimension of a tensor may be split across two dimensions of its mesh, and no two dimensions of a tensor may be split across the same dimension of its mesh. The user defines a global set of layout rules in the form of (tensor-dimension-name, mesh-dimension-name) pairs. A dimension of a tensor is split across a dimension of its mesh if there is a matching rule, e.g. (for the above example mesh_shape): "layout": "batch:x,heads:y"
333
+ - `activation_function`: `selu` (self normalizing) or `gelu` (used by OA), activation function used in feed-forward passes. (default: gelu)
334
+ - `attention_types`: the type of attention for each layer in a list of the following format [[["attention_type"], n_layers]]. e.g. for a 12 layer net [[["global"], 12]] or [[["local"], 10], [["global"], 2]].
335
+ + Choose from: `linear`, `global`, `local` or `none`. We have found a 50/50 mix of `global` and `linear` to work well. `none` allows you to create feed-forward only layers for more efficient [PAR Transformer](https://arxiv.org/abs/2009.04534) models.
336
+ - `precision`: `float32` or `bfloat16`.
337
+ - `tokens_per_mb_per_replica`: If not None, will split the batch up into smaller microbatches containing `tokens_per_mb_per_replica` tokens to avoid OOMs. Gradients are accumulated locally and reduced once. IMPORTANT: mb refers to *microbatch*, not megabyte, here.
338
+
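+ As a worked example for `train_steps`: with the example config above (`train_batch_size` 512 and `n_ctx` 2048), each training step consumes 512 * 2048 ≈ 1.05M tokens, so its `train_steps` of 286150 corresponds to roughly 300B tokens, i.e. about one epoch over a dataset of that size.
+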
339
+ **Mixture of Experts**
340
+
341
+ - `moe_layers`: A list of layer numbers to append a [mixture of experts](https://arxiv.org/abs/1701.06538) layer onto, e.g. `[2,4,6,8,10,12]`.
342
+ We have experimentally found a moe layer for every two self-attention layers to work well.
343
+ - `moe_params`: A dictionary of additional kwargs to pass in to the moe layer, e.g.
344
+ `{"moe_dropout_rate": 0.0 }`
345
+
346
+ **Experimental features**
347
+
348
+ - `axial_pos_emb_`: If true, uses [axial positional embedding](https://arxiv.org/abs/1912.12180).
349
+ - `mlp_glu`: If true, uses a gated linear unit variant of feed forward layers.
350
+ - `scalenorm`: If true, uses scalenorm instead of layernorm.
351
+ - `rezero`: If true, uses [rezero](https://www.groundai.com/project/rezero-is-all-you-need-fast-convergence-at-large-depth/1) instead of layernorm.
352
+ - `num_mem_kv`: adds memory / key values from the [all-attention paper](https://arxiv.org/pdf/1907.01470.pdf). Param is an int with the number of desired mem/key values.
353
+ - `macaron`: If true, uses a [macaron transformer](https://arxiv.org/pdf/1906.02762.pdf) for each layer block.
354
+
355
+ ## TODO:
356
+
357
+ - [x] finalize documentation
358
+ - [ ] update configs
359
+
360
+ ## Citing GPT-Neo
361
+
362
+ If you have found GPT-Neo helpful in your work, you can cite this repository as
363
+
364
+ ```
365
+ @software{gpt-neo,
366
+ author = {Black, Sid and Gao, Leo and Wang, Phil and Leahy, Connor and Biderman, Stella},
367
+ title = {{GPT-Neo}: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow},
368
+ url = {http://github.com/eleutherai/gpt-neo},
369
+ version = {1.0},
370
+ year = {2021},
371
+ }
372
+ ```
373
+ The version number should be replaced with the version number you are using, and the year corresponds to the project's open-source release.
374
+
375
+ If you are specifically interested in citing the GPT-Neo models trained on [the Pile](https://arxiv.org/abs/2101.00027), we would appreciate also citing
376
+ ```
377
+ @article{gao2020pile,
378
+ title={The Pile: An 800GB Dataset of Diverse Text for Language Modeling},
379
+ author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others},
380
+ journal={arXiv preprint arXiv:2101.00027},
381
+ year={2020}
382
+ }
383
+ ```
configs.py ADDED
@@ -0,0 +1,47 @@
1
+ import json
2
+ from pathlib import Path
3
+ from collections import defaultdict
4
+
5
+ DATASETS = {}
6
+
7
+ for path in Path("configs/dataset_configs").glob("*.json"):
8
+ dataset_id = path.stem
9
+ DATASETS[dataset_id] = json.loads(path.read_text())
10
+
11
+
12
+ def fetch_model_params(model):
13
+ model_path = model if model.endswith(".json") else f"configs/{model}.json"
14
+ with open(model_path) as f:
15
+ params = json.load(f)
16
+
17
+ dataset_ids = []
18
+ for d in params.get("datasets"):
19
+ if isinstance(d, list):
20
+ dataset_ids.append(d[0])
21
+ else:
22
+ dataset_ids.append(d)
23
+ no_datasets = params.get("no_dataset", False)
24
+ assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"
25
+
26
+ datasets = {}
27
+ last_dataset = None
28
+ for dataset_id in dataset_ids:
29
+ assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
30
+ dataset = DATASETS[dataset_id]
31
+ assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
32
+ datasets[dataset_id] = dataset
33
+ last_dataset = dataset
34
+
35
+ if last_dataset is not None:
36
+ params["padding_id"] = last_dataset.get("padding_id", 0)
37
+ params["eos_id"] = last_dataset.get("eos_id", 1)
38
+
39
+ params["dataset_configs"] = datasets
40
+
41
+ # Set some other parameter defaults
42
+ params["mlm_training"] = params.get("mlm_training") == True
43
+ params["causal"] = not params["mlm_training"]
44
+
45
+ # Set all other parameter values to default to None
46
+ params = defaultdict(lambda: None, params)
47
+ return params
docker-compose.yml ADDED
@@ -0,0 +1,67 @@
1
+ version: '3'
2
+ services:
3
+
4
+ mongo:
5
+ image: mongo
6
+ ports:
7
+ - 127.0.0.1:27017:27017
8
+ environment:
9
+ MONGO_INITDB_ROOT_USERNAME: user
10
+ MONGO_INITDB_ROOT_PASSWORD: password
11
+ MONGO_INITDB_DATABASE: db
12
+ expose:
13
+ - 27017
14
+ networks:
15
+ - omniboard
16
+ volumes:
17
+ - ./data:/data/db
18
+
19
+ mongoClientTemp:
20
+ image: mongo:latest
21
+ container_name: mongoClientTemp
22
+ links:
23
+ - mongo:mongo
24
+ command: mongo --host mongo -u user -p password --eval "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});"
25
+ depends_on:
26
+ - mongo
27
+ networks:
28
+ - omniboard
29
+
30
+ omniboard_readonly:
31
+ #image: vivekratnavel/omniboard:latest
32
+ build: https://github.com/lucidrains/omniboard.git
33
+ command: ["--mu", "mongodb://readonly:password@mongo:27017/db"]
34
+ ports:
35
+ - 0.0.0.0:8081:9000
36
+ networks:
37
+ - omniboard
38
+ depends_on:
39
+ - mongo
40
+
41
+ omniboard:
42
+ #image: vivekratnavel/omniboard:latest
43
+ build: https://github.com/lucidrains/omniboard.git
44
+ command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"]
45
+ expose:
46
+ - 9000
47
+ networks:
48
+ - omniboard
49
+ depends_on:
50
+ - mongo
51
+
52
+ nginx:
53
+ image: dhswt/nginx-basic-auth:1.3
54
+ environment:
55
+ - HTPASSWD=isaac: #put passwd here
56
+ - FORWARD_HOST=omniboard
57
+ - FORWARD_PORT=9000
58
+ networks:
59
+ - omniboard
60
+ depends_on:
61
+ - omniboard
62
+ ports:
63
+ - 0.0.0.0:8080:80
64
+ expose:
65
+ - 8080
66
+ networks:
67
+ omniboard:
encoders.py ADDED
@@ -0,0 +1,28 @@
1
+ from tokenizers import Tokenizer
2
+ from transformers import GPT2Tokenizer, GPT2TokenizerFast
3
+
4
+ def fetch_encoder(params):
5
+ no_dataset = params.get('no_dataset', False)
6
+ if no_dataset:
7
+ return None
8
+
9
+ dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
10
+ path = dataset["tokenizer_path"]
11
+ is_pretrained = dataset.get("tokenizer_is_pretrained", False)
12
+
13
+ if is_pretrained:
14
+ tok = GPT2TokenizerFast.from_pretrained(path)
15
+
16
+ # Will add a padding token id of 50257 at run-time
17
+ tok.add_special_tokens({'pad_token': '<|padding|>'})
18
+ return tok
19
+
20
+ return Tokenizer.from_file(path)
21
+
22
+
23
+ # GPT2Tokenizer and Tokenizer have different ways of fetching token ids
24
+ def encode(encoder, text, gpt=True):
25
+ result = encoder.encode(text)
26
+ if isinstance(result, list):
27
+ return result
28
+ return result.ids
export.py ADDED
@@ -0,0 +1,14 @@
1
+ import tensorflow.compat.v1 as tf
2
+
3
+ def export_model(estimator, export_dir, params,
4
+ checkpoint_path=None):
5
+
6
+
7
+ def serving_input_receiver_fn():
8
+ t = tf.placeholder(dtype=tf.int64,
9
+ shape=[1, params["n_ctx"]],
10
+ name='input_example_tensor')
11
+ return tf.estimator.export.ServingInputReceiver(t, t)
12
+
13
+ return estimator.export_saved_model(
14
+ export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)
inputs.py ADDED
@@ -0,0 +1,384 @@
1
+ import numpy as np
2
+ import tensorflow.compat.v1 as tf
3
+ from functools import partial
4
+ from data.encoders import encode
5
+ import random
6
+ import re
7
+ import logging
8
+ from itertools import cycle
9
+ from utils import natural_sort
10
+
11
+
12
+ ### IN USE ###
13
+
14
+ def _get_number_of_documents(filename):
15
+ # extracts number of files from a filename formatted "<name>_<num_documents>.tfrecords."
16
+ # if no pattern is matched, returns None
17
+ match = re.search("_(\d{1,}).tfrecords$", filename)
18
+ return int(match.group(1)) if match is not None else match
19
+
20
+
21
+ def _get_number_of_documents_by_iteration(filename):
22
+ # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename
23
+ # this could be very slow.
24
+ logging.warning(
25
+ "inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length")
26
+ count = 0
27
+ for item in tf.io.tf_record_iterator(filename):
28
+ count += 1
29
+ return count
30
+
31
+
32
+ def _get_skip_index(all_files, n_batches):
33
+ prev_cumsum = 0
34
+ cumsum = 0
35
+ global_n_documents = None
36
+ for count, f in cycle(enumerate(all_files)):
37
+ prev_cumsum = cumsum
38
+ if _get_number_of_documents(f) is not None:
39
+ cumsum += _get_number_of_documents(f)
40
+ elif global_n_documents is None:
41
+ global_n_documents = _get_number_of_documents_by_iteration(f)
42
+ cumsum += global_n_documents
43
+ else:
44
+ cumsum += global_n_documents
45
+ if cumsum == n_batches:
46
+ remainder = 0
47
+ skip_idx = count + 1
48
+ elif cumsum > n_batches:
49
+ remainder = n_batches - prev_cumsum
50
+ skip_idx = count
51
+ break
52
+ return skip_idx, remainder
53
+
54
+
55
+ def _parse_function(example_proto):
56
+ features = {
57
+ "text": tf.VarLenFeature(tf.int64)
58
+ }
59
+ parsed_features = tf.parse_single_example(example_proto, features)
60
+ return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])
61
+
62
+
63
+ def autoregressive_sample_text(params, x):
64
+ vals1 = x[:params["n_ctx"]]
65
+ vals2 = x[1:params["n_ctx"] + 1]
66
+
67
+ vals1 = tf.reshape(vals1, [params["n_ctx"]])
68
+ vals2 = tf.reshape(vals2, [params["n_ctx"]])
69
+ vals1 = tf.cast(vals1, dtype=tf.int32)
70
+ vals2 = tf.cast(vals2, dtype=tf.int32)
71
+ return vals1, vals2
72
+
73
+
74
+ def sequential_input(params, global_step=None, eval=False):
75
+ """
76
+ Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:
77
+
78
+ - has the number of documents for each tfrecord file encoded in the title in the format
79
+ <name>_<n_documents>.tfrecords.
80
+
81
+ OR
82
+
83
+ - has a fixed number of documents per tfrecord file.
84
+
85
+ If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.
86
+ If this isn't the case, it may result in errors, or some samples being missed.
87
+
88
+ This means we can calculate the number of samples we've seen so far using the global step,
89
+ and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.
90
+
91
+ If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
92
+ performance, as it results in less repeated data.
93
+ """
94
+ if not eval:
95
+ assert global_step is not None
96
+ logging.warning(
97
+ "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
98
+ batch_size = params['eval_batch_size' if eval else 'train_batch_size']
99
+
100
+ filenames = []
101
+ for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params
102
+ path_key = 'path' if not eval else 'eval_path'
103
+ path = dataset_config[path_key]
104
+ filenames.extend(
105
+ tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs
106
+
107
+ filenames = natural_sort(filenames)
108
+ shuffle_filenames = params.get("shuffle_input_filenames", True)
109
+ if shuffle_filenames:
110
+ seed = params.get('seed', 1) # shuffle deterministically
111
+ random.seed(seed)
112
+ random.shuffle(filenames)
113
+
114
+ dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity
115
+
116
+ if not eval:
117
+ # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
118
+ skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
119
+ "train_batch_size"]) # TODO: fix for > 1 epoch
120
+ dataset = dataset.skip(skip_idx) # skip to skip idx
121
+
122
+ # read tfrecord examples and skip remainder
123
+ dataset = dataset.apply(tf.data.TFRecordDataset)
124
+ dataset = dataset.skip(remainder)
125
+ else:
126
+ # shuffle filenames if in eval mode
127
+ dataset = dataset.shuffle(len(filenames))
128
+ dataset = dataset.apply(tf.data.TFRecordDataset)
129
+
130
+ # parse the tokenized data from the tfrecord files and shuffle
131
+ dataset = dataset.map(_parse_function, num_parallel_calls=1)
132
+ dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)
133
+
134
+ # batch data and repeat to infinity
135
+ dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
136
+ return dataset.repeat()
137
+
138
+
139
+ def pred_input(params, logger, enc=None,
140
+ path_to_prompt=""):
141
+ unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
142
+ "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
143
+ "researchers was the fact that the unicorns spoke perfect English."
144
+
145
+ text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read()
146
+ tokens = encode(enc, text)
147
+
148
+ if len(tokens) > params["n_ctx"]:
149
+ logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
150
+ tokens = tokens[len(tokens) - params["n_ctx"]:]
151
+ if len(tokens) < params["n_ctx"]:
152
+ tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
153
+ t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
154
+ dataset = tf.data.Dataset.from_tensors(t)
155
+
156
+ def _dummy_labels(x):
157
+ return x, x
158
+
159
+ dataset = dataset.map(_dummy_labels)
160
+ return dataset
161
+
162
+
163
+ def handle_pred_output(predictions, logger, enc, params, out_name="test"):
164
+ with tf.gfile.Open(f"{out_name}.txt", "w") as f:
165
+ for i, p in enumerate(predictions):
166
+ p = p["outputs"]
167
+
168
+ # remove eos + padding ids from output
169
+ idx = np.argmax(p == params['eos_id'])
170
+ if idx > 0:
171
+ p = p[:idx]
172
+ idx = np.argmax(p == params['padding_id'])
173
+ if idx > 0:
174
+ p = p[:idx]
175
+
176
+ text = enc.decode(p)
177
+ f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
178
+ f.write(text)
179
+ f.write("\n" + "=" * 80 + "\n")
180
+
181
+ logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
182
+ logger.info(text)
183
+ logger.info("\n" + "=" * 80 + "\n")
184
+
185
+
186
+ ### DEPRECATED ###
187
+
188
+ def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
189
+ logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
190
+ i = 0 if not eval else 1
191
+
192
+ weights = []
193
+ datasets = []
194
+
195
+ for dataset in params["datasets"]:
196
+ dataset_id, stitch, datatype, weight = dataset
197
+
198
+ assert dataset_id in params[
199
+ 'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
200
+ dataset_config = params['dataset_configs'][dataset_id]
201
+
202
+ path_key = 'path' if not eval else 'eval_path'
203
+ path = dataset_config[path_key]
204
+
205
+ datasets.append(text_dataset(
206
+ tf.io.gfile.glob(path),
207
+ params,
208
+ stitch=stitch,
209
+ datatype=datatype,
210
+ batch=False,
211
+ sample_text_fn=sample_text_fn
212
+ ))
213
+
214
+ weights.append(weight)
215
+
216
+ batch_size = params['eval_batch_size' if eval else 'train_batch_size']
217
+
218
+ seed = params.get('seed', None)
219
+ dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
220
+ dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
221
+ return dataset
222
+
223
+
224
+ def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
225
+ seed = params.get('seed', None)
226
+ deterministic = seed is not None
227
+ num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE
228
+
229
+ dataset = tf.data.Dataset.from_tensor_slices(files)
230
+
231
+ if deterministic:
232
+ dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
233
+ else:
234
+ dataset = dataset.apply(
235
+ tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
236
+
237
+ if "documents" in datatype:
238
+ def _parse_function(example_proto):
239
+ features = {
240
+ # "hash": tf.VarLenFeature(tf.string),
241
+ "text": tf.VarLenFeature(tf.int64)
242
+ }
243
+ parsed_features = tf.parse_single_example(example_proto, features)
244
+ return parsed_features["text"], parsed_features["text"].dense_shape[0]
245
+ else:
246
+ def _parse_function(example_proto):
247
+ features = {
248
+ "text": tf.VarLenFeature(tf.int64)
249
+ }
250
+ parsed_features = tf.parse_single_example(example_proto, features)
251
+ return parsed_features["text"] # Assuming the text is not sparse
252
+
253
+ dataset = dataset.map(_parse_function, num_parallel_calls=1)
254
+
255
+ # Subsample method
256
+ if "documents" in datatype:
257
+ # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples
258
+ # to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that
259
+ # stitch * min(characters_in_text) >= amount
260
+ def _stitch_text(x, y):
261
+ x = tf.sparse.to_dense(x)
262
+
263
+ def _get_x(i):
264
+ return tf.gather(x[i], tf.range(y[i]))
265
+
266
+ out = _get_x(0)
267
+ eos_id = params['eos_id']
268
+
269
+ for i in range(1, stitch):
270
+ out = tf.concat([out, [eos_id], _get_x(i)], axis=0) # text1<|endoftext|>text2
271
+
272
+ return out
273
+
274
+ # Hack-y way to stitch together multiple texts
275
+
276
+ dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,
277
+ num_parallel_calls=num_parallel_calls)
278
+
279
+ # Sample 1024(+1) tokens from the stitched together text
280
+ is_random_documents = datatype == "documents_random"
281
+ if sample_text_fn is not None:
282
+ _sample_text = partial(sample_text_fn, random_documents=is_random_documents)
283
+ else:
284
+ _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
285
+ _sample_text = partial(_sample_text, params)
286
+
287
+ dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)
288
+
289
+ if batch:
290
+ dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)
291
+
292
+ dataset = dataset.repeat()
293
+
294
+ return dataset
295
+
296
+
297
+ def autoregressive_sample_text_random_documents(params, x):
298
+ seed = params.get('seed', None)
299
+ s = tf.size(x)
300
+ r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
301
+ r1 = tf.range(r, r + params["n_ctx"])
302
+ r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
303
+ r1 = tf.reshape(r1, [params["n_ctx"]]) # Somehow, this makes the compiler happy
304
+ r2 = tf.reshape(r2, [params[
305
+ "n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input
306
+ vals1 = tf.gather(x, r1)
307
+ vals2 = tf.gather(x, r2)
308
+
309
+ vals1 = tf.reshape(vals1, [params["n_ctx"]])
310
+ vals2 = tf.reshape(vals2, [params["n_ctx"]])
311
+ vals1 = tf.cast(vals1, dtype=tf.int32)
312
+ vals2 = tf.cast(vals2, dtype=tf.int32)
313
+ return vals1, vals2
314
+
315
+
316
+ def mlm_sample_text(params, x, random_documents=False):
317
+ seed = params.get('seed', None)
318
+ ctx_len = params["n_ctx"]
319
+ assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'
320
+
321
+ mask_id = params['mlm_mask_id']
322
+ cls_token_id = params.get('mlm_cls_token_id', None)
323
+ num_tokens = params.get('n_vocab', None)
324
+
325
+ mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
326
+ mask_ignore_ids.add(cls_token_id)
327
+
328
+ mask_prob = params.get('mlm_mask_prob', 0.15)
329
+ same_token_prob = params.get('mlm_same_token_prob', 0.10)
330
+ random_token_prob = params.get('mlm_random_token_prob', 0.)
331
+
332
+ seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)
333
+
334
+ if random_documents:
335
+ s = tf.size(x)
336
+ r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
337
+ r1 = tf.range(r, r + seq_len)
338
+ r1 = tf.reshape(r1, [seq_len])
339
+ features = tf.gather(x, r1)
340
+ else:
341
+ features = x[:seq_len]
342
+
343
+ # add cls token id if specified by `mlm_cls_token_id`
344
+ if cls_token_id is not None:
345
+ features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)
346
+
347
+ features = tf.cast(features, dtype=tf.int32)
348
+ shape = features.shape
349
+
350
+ # determine which tokens are mask-able
351
+ can_mask = tf.not_equal(features, 0)
352
+ for ignore_id in mask_ignore_ids:
353
+ can_mask &= tf.not_equal(features, ignore_id)
354
+
355
+ # generate boolean mask for masking ids
356
+ mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
357
+ mask_mask &= can_mask
358
+
359
+ # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
360
+ replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
361
+ 1 - same_token_prob)
362
+
363
+ # randomly replace some tokens with random tokens before masking
364
+ if random_token_prob > 0:
365
+ random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
366
+ random_token_prob)
367
+ random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)
368
+
369
+ # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
370
+ random_can_mask = tf.not_equal(random_tokens, 0)
371
+ for ignore_id in mask_ignore_ids:
372
+ random_can_mask &= tf.not_equal(random_tokens, ignore_id)
373
+
374
+ features = tf.where(random_token_mask & random_can_mask, random_tokens, features)
375
+
376
+ # mask the tokens
377
+ mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
378
+ masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)
379
+
380
+ # labels will be set to 0 for all non-masked tokens
381
+ labels = tf.where(mask_mask, tf.zeros(shape, dtype=tf.int32), features)
382
+
383
+ masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
384
+ return masked_features, labels
main.py ADDED
@@ -0,0 +1,257 @@
1
+ """GPT-like model in Mesh-Tensorflow"""
2
+
3
+ from functools import partial
4
+ import mesh_tensorflow as mtf
5
+ import tensorflow.compat.v1 as tf
6
+ from tensorflow.python.tpu import tpu_config, tpu_estimator
7
+ from tensorflow_estimator.python.estimator import estimator as estimator_lib
8
+ from utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \
9
+ check_dataset
10
+ from inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text
11
+ from export import export_model
12
+ from model_fns import model_fn
13
+ from data.encoders import fetch_encoder
14
+ from configs import fetch_model_params
15
+ from tasks import task_descriptors
16
+ import argparse
17
+ import json
18
+ import numpy
19
+
20
+
21
+ def parse_args():
22
+ # Parse command line arguments
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
25
+ parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
26
+ help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'")
27
+ parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
28
+ parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
29
+ parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
30
+ "efficient layout according to MTF auto layout.")
31
+ parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
32
+ help="If set, generates and prints the most memory efficient layout and mesh shape according to"
33
+ " MTF auto layout.")
34
+ parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
35
+ "starts a new training run")
36
+ parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
37
+ parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
38
+ parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
39
+ "defaults to unicorns.",
40
+ default="")
41
+ parser.add_argument("--check_dataset", action="store_true",
42
+ help="If set, outputs sample from the dataset and quits.")
43
+ parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
44
+ parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
45
+ parser.add_argument("--export", action="store_true", help="If set, will export the model.")
46
+ args = parser.parse_args()
47
+ assert args.model is not None, "Model must be set"
48
+ return args
49
+
50
+
51
+ def main(args):
52
+ # Setup logging
53
+ logger = setup_logging(args)
54
+
55
+ # Read params of model
56
+ params = fetch_model_params(args.model)
57
+
58
+ # Fetch appropriate input functions
59
+ input_fn = params.get("input_fn", "sequential_input")
60
+ if input_fn == "sequential_input":
61
+ input_fn = sequential_input
62
+ elif input_fn == "generic_text":
63
+ input_fn = generic_text
64
+ pred_input_fn = pred_input
65
+ handle_pred_output_fn = handle_pred_output
66
+
67
+ # get current step
68
+ current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
69
+ logger.info(f"Current step {current_step}")
70
+
71
+ if params["mlm_training"]:
72
+ mlm_sample_text_fn = partial(mlm_sample_text, params)
73
+ input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)
74
+ if args.check_dataset:
75
+ check_dataset(input_fn, params)
76
+
77
+
78
+ # Fetch encoder per params
79
+ encoder = fetch_encoder(params)
80
+
81
+ pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)
82
+
83
+ # Sample from Dataset if check dataset flag is on
84
+ if args.check_dataset:
85
+ check_dataset(input_fn, params, global_step=current_step)
86
+
87
+ # Confirm deletion of checkpoint files if --new flag is set
88
+ if args.new:
89
+ if yes_or_no(f"Are you sure you want to remove '{params['model_path']}' to start afresh?"):
90
+ remove_gs_or_filepath(params["model_path"])
91
+ else:
92
+ exit()
93
+
94
+ # Save config to logdir for experiment management
95
+ save_config(params, params["model_path"])
96
+
97
+ # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
98
+ mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
99
+ params["num_cores"] = mesh_shape.size
100
+ params["auto_layout"] = args.auto_layout
101
+ params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
102
+ params["use_tpu"] = True if not args.tpu is None else False
103
+ params["gpu_ids"] = args.gpu_ids
104
+ params["steps_per_checkpoint"] = args.steps_per_checkpoint
105
+ # Expand attention types param
106
+ params["attention_types"] = expand_attention_types_params(params["attention_types"])
107
+ assert len(params["attention_types"]) == params["n_layer"] # Assert that the length of expanded list = num layers
108
+ params["predict_batch_size"] = params.get("predict_batch_size", 1) # Default to 1
109
+ params["predict"] = args.predict
110
+ params['model'] = params.get("model", "GPT") # Default model selection to GPT since it's the only option for now
111
+ params["export"] = args.export
112
+ # Set sampling parameters
113
+ params["sampling_use_entmax"] = args.entmax_sampling
114
+
115
+ # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
116
+ # moe layers are present
117
+ params["slow_sampling"] = params["moe_layers"] is not None
118
+
119
+ logger.info(f"params = {params}")
120
+
121
+ # Get eval tasks from params
122
+ eval_tasks = params.get("eval_tasks", [])
123
+ has_predict_or_eval_steps_or_eval_tasks = params["predict_steps"] > 0 or params["eval_steps"] > 0 or len(
124
+ eval_tasks) > 0
125
+
126
+ for t in eval_tasks:
127
+ assert t in task_descriptors, f"Eval task '{t}' is not known"
128
+ task_descriptors[t]["init_fn"](params)
129
+
130
+ # Set up TPUs and Estimator
131
+ if args.tpu == "colab":
132
+ tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None
133
+ else:
134
+ tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None
135
+
136
+ config = tpu_config.RunConfig(
137
+ cluster=tpu_cluster_resolver,
138
+ model_dir=params["model_path"],
139
+ save_checkpoints_steps=None, # Disable the default saver
140
+ save_checkpoints_secs=None, # Disable the default saver
141
+ log_step_count_steps=params["iterations"],
142
+ save_summary_steps=params["iterations"],
143
+ tpu_config=tpu_config.TPUConfig(
144
+ num_shards=mesh_shape.size,
145
+ iterations_per_loop=params["iterations"],
146
+ num_cores_per_replica=1,
147
+ per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
148
+
149
+ estimator = tpu_estimator.TPUEstimator(
150
+ use_tpu=params["use_tpu"],
151
+ model_fn=model_fn,
152
+ config=config,
153
+ train_batch_size=params["train_batch_size"],
154
+ eval_batch_size=params["train_batch_size"],
155
+ predict_batch_size=params["predict_batch_size"],
156
+ params=params)
157
+
158
+ def _make_task_estimator(task):
159
+ task_params = params.copy()
160
+ task_params["eval_task"] = task
161
+ return tpu_estimator.TPUEstimator(
162
+ use_tpu=params["use_tpu"],
163
+ model_fn=model_fn,
164
+ config=config,
165
+ train_batch_size=params["train_batch_size"],
166
+ eval_batch_size=params["eval_batch_size"],
167
+ predict_batch_size=params["predict_batch_size"],
168
+ params=task_params)
169
+
170
+ eval_task_estimators = {
171
+ task: _make_task_estimator(task)
172
+ for task in eval_tasks
173
+ }
174
+
175
+ if args.export:
176
+ export_model(estimator, "export", params)
177
+ return
178
+
179
+ if args.predict:
180
+ # Predict
181
+ predictions = estimator.predict(input_fn=pred_input_fn)
182
+ logger.info("Predictions generated")
183
+ enc = fetch_encoder(params)
184
+ handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
185
+ return
186
+
187
+ def save_eval_results(task, eval_results):
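+ # Convert numpy scalars to native Python types so the results are JSON-serializable, then append one record per evaluation to eval_<sacred_id>.jsonl.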
188
+ def as_python(x):
189
+ if isinstance(x, numpy.generic):
190
+ return x.item()
191
+ return x
192
+ eval_results = {k: as_python(v) for k, v in eval_results.items()}
193
+ with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:
194
+ json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)
195
+ fh.write('\n')
196
+
197
+ def run_eval():
198
+ logger.info("Running evaluation...")
199
+ eval_results = estimator.evaluate(
200
+ input_fn=partial(input_fn, eval=True),
201
+ steps=params["eval_steps"])
202
+ logger.info(f"Eval results: {eval_results}")
203
+ save_eval_results('validation', eval_results)
204
+
205
+ def run_eval_tasks():
206
+ for task in eval_tasks:
207
+ logger.info(f"Starting evaluation task '{task}'")
208
+ task_info = task_descriptors[task]["get_task_info_fn"](params)
209
+ task_estimator = eval_task_estimators[task]
210
+ task_input_fn = task_descriptors[task]["input_fn"]
211
+ eval_results = task_estimator.evaluate(
212
+ input_fn=task_input_fn,
213
+ steps=task_info["n_steps"],
214
+ name=task)
215
+ logger.info(f"Eval task '{task}' results: {eval_results}")
216
+ save_eval_results(task, eval_results)
217
+
218
+ if args.eval:
219
+ run_eval_tasks()
220
+ if params["eval_steps"] > 0:
221
+ run_eval()
222
+ return
223
+
224
+
225
+ elif has_predict_or_eval_steps_or_eval_tasks:
226
+ # Eval and train - stop and predict and/or eval every checkpoint
227
+ while current_step < params["train_steps"]:
228
+ next_checkpoint = min(current_step + args.steps_per_checkpoint,
229
+ params["train_steps"])
230
+
231
+ estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)
232
+ current_step = next_checkpoint
233
+
234
+ if params["predict_steps"] > 0:
235
+ logger.info("Running prediction...")
236
+ predictions = estimator.predict(input_fn=pred_input_fn)
237
+ enc = fetch_encoder(params)
238
+ handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
239
+
240
+ if params["eval_steps"] > 0:
241
+ run_eval()
242
+
243
+ if eval_tasks:
244
+ run_eval_tasks()
245
+
246
+ return
247
+ else:
248
+ # Else, just train
249
+ while current_step < params["train_steps"]:
250
+ # Else, don't stop and restart
251
+ estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"])
252
+
253
+
254
+ if __name__ == "__main__":
255
+ tf.disable_v2_behavior()
256
+ args = parse_args()
257
+ main(args)
model_fns.py ADDED
@@ -0,0 +1,305 @@
1
+ import mesh_tensorflow as mtf
2
+ import tensorflow.compat.v1 as tf
3
+ from tensorflow.python.tpu import tpu_estimator
4
+ import mesh_tensorflow.transformer as mtf_transformer
5
+ from optimizers import get_optimizer
6
+ from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
7
+ get_batch_size, auto_layout, auto_layout_and_mesh_shape)
8
+ from models.utils import biasmask_attn_weights
9
+ from tensorflow.python.ops import resources
10
+ from sample import sample_autoregressive
11
+ from models.gpt2 import gpt2
12
+ import math
13
+
14
+
15
+ def model_fn(features, labels, mode, params):
16
+ # Get global step
17
+ global_step = tf.train.get_global_step()
18
+
19
+ # Construct mtf graph + mesh from params
20
+ graph = mtf.Graph()
21
+ mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
22
+ layout_rules = mtf.convert_to_layout_rules(params["layout"])
23
+
24
+ # Mesh setup
25
+ if params["use_tpu"]:
26
+ var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
27
+ else:
28
+ var_placer = None
29
+ gpu_ids = params["gpu_ids"]
30
+ mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
31
+ mesh_shape, layout_rules, gpu_ids)
32
+
33
+ # Trainable variable precision
34
+ # Store to checkpoints in master type, train in slice type, compute in activation type
35
+ if params["precision"] == "bfloat16":
36
+ variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
37
+ activation_dtype=tf.bfloat16)
38
+ else:
39
+ variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)
40
+
41
+ # Build mtf mesh object
42
+ mesh = mtf.Mesh(graph, "my_mesh", var_placer)
43
+
44
+ # Build mtf_features & seq length dict for getting number of microbatches
45
+ # We need to pack inputs into a dict to pass into serialize_training_step
46
+ features_dict = {"inputs": features, "labels": labels}
47
+ sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}
48
+
49
+ params = add_mode_to_params(params, mode)
50
+ batch_size = get_batch_size(params)
51
+
52
+ batch_dim = mtf.Dimension("batch", batch_size)
53
+ batch_dims = [batch_dim]
54
+ feature_length = sequence_length_dict["inputs"]
55
+ length_dim = mtf.Dimension("sequence", feature_length)
56
+
57
+ mtf_features = {}
58
+ for key, x in features_dict.items():
59
+ if x is not None:
60
+ feature_shape = mtf.Shape(batch_dims + [length_dim])
61
+ if type(features_dict[key]) == dict:
62
+ features_dict[key] = features_dict[key]["feature"]
63
+ x = tf.cast(features_dict[key], tf.int32)
64
+ x = tf.reshape(x, feature_shape.to_integer_list)
65
+ mtf_features[key] = mtf.import_fully_replicated(
66
+ mesh, x, feature_shape, name=key)
67
+
68
+ # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
69
+ other_features = {}
70
+ memory_length_dim = mtf.Dimension("memory_length", length_dim.size)
71
+
72
+ attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None
73
+
74
+ # Add attn_bias into mtf_features
75
+ other_features["attn_bias"] = attn_bias
76
+
77
+ # Define other Dimensions that we'll need inside the model
78
+ embd_dim = mtf.Dimension("embd", params["n_embd"])
79
+ vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
80
+ # We need this because gathering when both the args have the same dimension in them breaks things
81
+ # This dim is specifically for the weights
82
+ # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
83
+ embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])
84
+
85
+ other_features["embd_dim"] = embd_dim
86
+ other_features["vocab_dim"] = vocab_dim
87
+ other_features["embed_sequence_dim"] = embed_sequence_dim
88
+ other_features["memory_length_dim"] = memory_length_dim
89
+
90
+ if mode == tf.estimator.ModeKeys.PREDICT:
91
+ # Set up the model for prediction
92
+ inputs = mtf_features["inputs"]
93
+ if params["remove_partial_sequences"] is None:
94
+ params["remove_partial_sequences"] = False
95
+
96
+ export = params.get("export", False)
97
+
98
+ if not export:
99
+ mtf_samples = sample_autoregressive(
100
+ inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
101
+ remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
102
+ sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])
103
+
104
+ else:
105
+ with mtf.utils.outside_all_rewrites():
106
+ with tf.variable_scope('gpt2'):
107
+ mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
108
+ variable_dtype=variable_dtype, context=None)
109
+
110
+ mtf_samples = mtf.anonymize(mtf_samples)
111
+ inputs = mtf.anonymize(inputs)
112
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
113
+ inputs = lowering.export_to_tf_tensor(inputs)
114
+ outputs = lowering.export_to_tf_tensor(mtf_samples)
115
+ predictions = {
116
+ "inputs": inputs,
117
+ "outputs": outputs}
118
+
119
+ def scaffold_fn():
120
+ return tf.train.Scaffold(
121
+ local_init_op=tf.group(
122
+ tf.train.Scaffold.default_local_init_op(),
123
+ lowering.copy_masters_to_slices(),
124
+ name="mtf_local_init_op"),
125
+ ready_op=tf.concat(
126
+ [tf.report_uninitialized_variables(),
127
+ resources.report_uninitialized_resources()],
128
+ axis=0,
129
+ name="mtf_ready_op"))
130
+
131
+ return tpu_estimator.TPUEstimatorSpec(
132
+ mode=tf.estimator.ModeKeys.PREDICT,
133
+ predictions=predictions,
134
+ scaffold_fn=scaffold_fn,
135
+ prediction_hooks=[mtf.MtfRestoreHook(lowering)])
136
+
137
+ # We're not predicting, so we better be training or evaluating
138
+ assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)
139
+
140
+ if mode == tf.estimator.ModeKeys.TRAIN:
141
+ # Gets number of microbatches per batch for serialized training
142
+ # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
143
+ num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,
144
+ sequence_length=sequence_length_dict,
145
+ mesh_shape=mesh_shape,
146
+ layout_rules=layout_rules,
147
+ tokens_per_microbatch_per_replica=
148
+ params["tokens_per_mb_per_replica"]))
149
+ else:
150
+ num_microbatches = 1
151
+
152
+ params["num_microbatches"] = num_microbatches # Add num microbatches to params
153
+
154
+ if num_microbatches > 1:
155
+
156
+ # For serialize_training_step we need to modify the model to output results in a dict
157
+ def serialized_fn(mtf_features):
158
+ if params["model"] == "GPT":
159
+ with tf.variable_scope('gpt2'):
160
+ logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
161
+ variable_dtype=variable_dtype)
162
+ return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
163
+ else:
164
+ raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
165
+
166
+ # Serialize the training step - Gradients are accumulated locally and reduced once.
167
+ var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
168
+ loss = output_dict["loss"]
169
+ loss_batch = output_dict["loss_batch"]
170
+ logits = output_dict["logits"]
171
+ else:
172
+ # If we're not splitting into microbatches, return logits & loss as is
173
+ if params["model"] == "GPT":
174
+ with mtf.utils.outside_all_rewrites():
175
+ with tf.variable_scope('gpt2'):
176
+ logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
177
+ variable_dtype=variable_dtype, context=None)
178
+ else:
179
+ raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
180
+
181
+ # Auto layout generation
182
+ if params["auto_layout"]:
183
+ auto_layout(graph, mesh_shape, logits, loss)
184
+ if params["auto_layout_and_mesh_shape"]:
185
+ auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)
186
+
187
+ if mode == tf.estimator.ModeKeys.TRAIN:
188
+ # In TRAIN mode, get optimizer
189
+ if params["num_microbatches"] > 1:
190
+ # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
191
+ # So we pass them in here
192
+ _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
193
+ inp_var_grads=var_grads)
194
+ else:
195
+ # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
196
+ _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
197
+ # Log summaries to tensorboard
198
+ mtf.scalar_summary("loss", loss)
199
+ # Log gradients if in params
200
+ if params["log_grads"] not in [None, False]:
201
+ for g in var_grads:
202
+ grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
203
+ mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
204
+ else:
205
+ # For now, we can only export fully-replicated tensors.
206
+ # This has to be done before lowering or they will not be included in the graph
207
+ mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
208
+ max_logits = mtf.argmax(logits, vocab_dim)
209
+ del logits
210
+ fully_replicated_mean_logits = mtf.anonymize(mean_logits)
211
+ fully_replicated_max_logits = mtf.anonymize(max_logits)
212
+ fully_replicated_loss_batch = mtf.anonymize(loss_batch)
213
+
214
+ # Gets & prints info about no. trainable vars in the model & dimension names
215
+ get_graph_info(graph)
216
+
217
+ # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
218
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
219
+ tf_loss = lowering.export_to_tf_tensor(loss)
220
+ tf_loss = tf.cast(tf_loss, tf.float32)
221
+
222
+ if mode == tf.estimator.ModeKeys.TRAIN:
223
+ # Use our patched version until mtf updates theirs
224
+ host_call = create_host_call(params['model_path'])
225
+ mtf.utils.remove_summaries()
226
+
227
+ # Creates train_op
228
+ tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
229
+ tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step
230
+ tf.logging.info(f"tf_update_ops: {tf_update_ops}")
231
+ train_op = tf.group(tf_update_ops)
232
+ else:
233
+ tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
234
+ tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
235
+ tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))
236
+
237
+ with mtf.utils.outside_all_rewrites():
238
+ # Copy master variables to slices. Must be called first.
239
+ restore_hook = mtf.MtfRestoreHook(lowering)
240
+ if mode == tf.estimator.ModeKeys.TRAIN:
241
+ # Set up the checkpoint server and return the TPUEstimatorSpec
242
+ saver = tf.train.Saver(
243
+ tf.global_variables(),
244
+ sharded=True,
245
+ max_to_keep=10,
246
+ keep_checkpoint_every_n_hours=2,
247
+ defer_build=False,
248
+ save_relative_paths=True)
249
+ tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
250
+ saver_listener = mtf.MtfCheckpointSaverListener(lowering)
251
+ saver_hook = tf.train.CheckpointSaverHook(
252
+ params["model_path"],
253
+ save_steps=params["steps_per_checkpoint"],
254
+ saver=saver,
255
+ listeners=[saver_listener])
256
+
257
+ return tpu_estimator.TPUEstimatorSpec(
258
+ tf.estimator.ModeKeys.TRAIN,
259
+ loss=tf_loss,
260
+ host_call=host_call,
261
+ train_op=train_op,
262
+ training_hooks=[restore_hook, saver_hook])
263
+
264
+ elif mode == tf.estimator.ModeKeys.EVAL:
265
+ # Evaluation metrics
266
+ def _perplexity(loss):
267
+ perplexity = tf.exp(loss)
268
+ return tf.metrics.mean(perplexity)
269
+
270
+ def _bits_per_byte(loss):
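+ # loss is in nats per token; dividing by ln(2) converts nats to bits, and the 0.29335 factor is presumably the tokens-per-byte ratio of the eval data, yielding bits per byte.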
271
+ bpb = loss * (0.29335 / math.log(2))
272
+ return tf.metrics.mean(bpb)
273
+
274
+ def _metric_fn(tf_mean_logits, tf_loss_batch):
275
+ mean_logits = tf.metrics.mean(tf_mean_logits)
276
+ loss = tf.reduce_mean(tf_loss_batch)
277
+ perp = _perplexity(loss)
278
+ bpb = _bits_per_byte(loss)
279
+ return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}
280
+
281
+ def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
282
+ eos_token = params["eos_id"]
283
+ answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
284
+
285
+ correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
286
+ accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
287
+
288
+ # I guess tf_loss_batch has z_loss and maybe other stuff added to it
289
+ # so maybe this should be calculated separately in the future
290
+ answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
291
+ log_perplexity = tf.metrics.mean(answer_loss)
292
+
293
+ return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}
294
+
295
+ eval_task = params["eval_task"]
296
+ if eval_task == "lambada":
297
+ eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
298
+ else:
299
+ eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])
300
+
301
+ return tpu_estimator.TPUEstimatorSpec(
302
+ tf.estimator.ModeKeys.EVAL,
303
+ evaluation_hooks=[restore_hook],
304
+ loss=tf_loss,
305
+ eval_metrics=eval_metrics)
optimizers.py ADDED
@@ -0,0 +1,176 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import re
6
+ import mesh_tensorflow as mtf
7
+ import tensorflow.compat.v1 as tf
8
+
9
+ def clip_by_global_norm(grads, clip_norm):
10
+ """Clip the grads by global norm."""
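+ # If the global L2 norm of all gradients exceeds clip_norm, scale every gradient by clip_norm / global_norm; otherwise the multiplier is 1 and the gradients pass through unchanged.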
11
+ global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
12
+ multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
13
+ clipped_grads = [None if t is None else t * multiplier for t in grads]
14
+ return clipped_grads, global_norm
15
+
16
+ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
17
+ """Creates and returns an optimizer training op."""
18
+ global_step = tf.train.get_or_create_global_step()
19
+
20
+ learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
21
+ clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)
22
+
23
+ if inp_var_grads is None:
24
+ var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
25
+ else:
26
+ var_grads = inp_var_grads
27
+
28
+ # Cast to full precision
29
+ var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]
30
+
31
+ # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
32
+ end_step = params.get("lr_decay_end", params["train_steps"])
33
+
34
+ if params["lr_decay"] == "linear":
35
+ learning_rate = tf.train.polynomial_decay(
36
+ learning_rate,
37
+ global_step,
38
+ end_step,
39
+ end_learning_rate=params["lr"]*0.1, # Decrease to 10% of initial LR according to GPT-3 paper
40
+ power=1.0,
41
+ cycle=False)
42
+ elif params["lr_decay"] == "cosine":
43
+ learning_rate = tf.train.cosine_decay(
44
+ learning_rate,
45
+ global_step,
46
+ end_step,
47
+ alpha=0.1 # Alpha is min lr value as a fraction of init lr.
48
+ )
49
+
50
+ if params["warmup_steps"] > 0:
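+ # Linear warmup: while global_step < warmup_steps, scale the (possibly decayed) learning rate by global_step / warmup_steps; after warmup the learning rate is used unchanged.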
51
+ global_steps_int = tf.cast(global_step, tf.int32)
52
+ warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)
53
+
54
+ dtype = variable_dtype.slice_dtype
55
+
56
+ global_steps_float = tf.cast(global_steps_int, dtype)
57
+ warmup_steps_float = tf.cast(warmup_steps_int, dtype)
58
+
59
+ warmup_percent_done = global_steps_float / warmup_steps_float
60
+ warmup_learning_rate = learning_rate * warmup_percent_done
61
+
62
+ is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
63
+ learning_rate = ((1.0 - is_warmup) * learning_rate +
64
+ is_warmup * warmup_learning_rate)
65
+
66
+ learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
67
+ mtf.scalar_summary("lr", learning_rate)
68
+
69
+ if params["opt_name"].lower() == "adam":
70
+ optimizer = AdamWeightDecayOptimizer(
71
+ learning_rate=learning_rate,
72
+ weight_decay_rate=params["weight_decay"],
73
+ beta_1=params["beta1"],
74
+ beta_2=params["beta2"],
75
+ epsilon=params["epsilon"],
76
+ exclude_from_weight_decay=["norm", "bias"],
77
+ variable_dtype=variable_dtype
78
+ )
79
+ else:
80
+ optimizer = mtf.optimize.AdafactorOptimizer(
81
+ learning_rate=params["lr"],
82
+ decay_rate=params["weight_decay"],
83
+ beta1=params["beta1"],
84
+ epsilon1=params["ada_epsilon1"],
85
+ epsilon2=params["ada_epsilon2"]
86
+ )
87
+
88
+ if params["gradient_clipping"] is not None:
89
+ (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)
90
+
91
+ update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
92
+ return learning_rate, update_ops, var_grads_fp
93
+
94
+
95
+ class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
96
+ """A basic Adam optimizer that includes "correct" L2 weight decay."""
97
+
98
+ def __init__(self,
99
+ learning_rate,
100
+ weight_decay_rate=0.0,
101
+ beta_1=0.9,
102
+ beta_2=0.999,
103
+ epsilon=1e-6,
104
+ exclude_from_weight_decay=None,
105
+ variable_dtype=None):
106
+ """Constructs a AdamWeightDecayOptimizer."""
107
+
108
+ self.learning_rate = learning_rate
109
+ self.weight_decay_rate = weight_decay_rate
110
+ self.beta_1 = beta_1
111
+ self.beta_2 = beta_2
112
+ self.epsilon = epsilon
113
+ self.exclude_from_weight_decay = exclude_from_weight_decay
114
+ self.variable_dtype = variable_dtype
115
+
116
+ def apply_grad(self, grad, var):
117
+ """See base class."""
118
+ if grad is None:
119
+ tf.logging.warning("Gradient is None for variable %s" % var.name)
120
+ return []
121
+
122
+ grad = mtf.to_float(grad)
123
+
124
+ assignments = []
125
+
126
+ m = mtf.get_variable(
127
+ var.mesh, var.name + "/adam_m", var.shape,
128
+ initializer=tf.zeros_initializer(),
129
+ # master_dtype=self.variable_dtype.master_dtype,
130
+ # slice_dtype=self.variable_dtype.slice_dtype,
131
+ # activation_dtype=self.variable_dtype.activation_dtype,
132
+ trainable=False)
133
+
134
+ v = mtf.get_variable(
135
+ var.mesh, var.name + "/adam_v", var.shape,
136
+ initializer=tf.zeros_initializer(),
137
+ # master_dtype=self.variable_dtype.master_dtype,
138
+ # slice_dtype=self.variable_dtype.slice_dtype,
139
+ # activation_dtype=self.variable_dtype.activation_dtype,
140
+ trainable=False)
141
+
142
+ # Standard Adam update.
143
+ next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
144
+ next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)
145
+
146
+ update = next_m / (mtf.sqrt(next_v) + self.epsilon)
147
+
148
+ # Just adding the square of the weights to the loss function is *not*
149
+ # the correct way of using L2 regularization/weight decay with Adam,
150
+ # since that will interact with the m and v parameters in strange ways.
151
+ #
152
+ # Instead we want to decay the weights in a manner that doesn't interact
153
+ # with the m/v parameters. This is equivalent to adding the square
154
+ # of the weights to the loss with plain (non-momentum) SGD.
155
+ if self._do_use_weight_decay(var.name):
156
+ update += mtf.to_float(var.value) * self.weight_decay_rate
157
+
158
+ update_with_lr = self.learning_rate * update
159
+
160
+ var_update = mtf.assign_sub(var, update_with_lr)
161
+
162
+ assignments.extend(
163
+ [var_update,
164
+ mtf.assign(m, next_m),
165
+ mtf.assign(v, next_v)])
166
+ return assignments
167
+
168
+ def _do_use_weight_decay(self, param_name):
169
+ """Whether to use L2 weight decay for `param_name`."""
170
+ if not self.weight_decay_rate:
171
+ return False
172
+ if self.exclude_from_weight_decay:
173
+ for r in self.exclude_from_weight_decay:
174
+ if re.search(r, param_name) is not None:
175
+ return False
176
+ return True
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ google-api-python-client
2
+ jsonlines
3
+ lm_dataformat
4
+ mesh-tensorflow==0.1.18
5
+ numpy
6
+ oauth2client
7
+ ortools
8
+ pytest
9
+ sacred
10
+ tensorflow==2.5.1
11
+ tensorflow-datasets==3.2.1
12
+ tokenizers==0.9.4
13
+ transformers==4.1.1
14
+ tpunicorn
15
+ absl-py
16
+ ftfy
18
+ pymongo
run_experiment.py ADDED
@@ -0,0 +1,265 @@
1
+ import atexit
2
+ import sacred
3
+ import argparse
4
+ import time
5
+ import math
6
+ import subprocess
7
+ import shutil
8
+ import os
9
+ import json
10
+ import threading
11
+ import requests
12
+ import glob
13
+ from configs import fetch_model_params
14
+ import socket
15
+ import subprocess
16
+ import queue
17
+ import sys
18
+ import signal
19
+
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--tpu', type=str, required=True) # Name of TPU to train on, if any
23
+ parser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters
24
+ parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
25
+ parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
26
+ parser.add_argument('--autostack', action="store_false")
27
+ parser.add_argument('--auto_layout', action="store_true")
28
+ parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
29
+ parser.add_argument('--new', action='store_true')
30
+ parser.add_argument('--test', action='store_true')
31
+ parser.add_argument('--eval', action='store_true')
32
+ parser.add_argument('--predict', action='store_true')
33
+ parser.add_argument('--no_delete_tpu', action='store_true')
34
+ parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
35
+ parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
36
+ args = parser.parse_args()
37
+
38
+ params = fetch_model_params(args.model)
39
+
40
+ ex = sacred.Experiment(args.experiment_name)
41
+ ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
42
+
43
+
44
+ def get_open_port(lo=8000, hi=8100):
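+ # Return the first port in [lo, hi) that nothing is listening on (connect_ex != 0); returns None if every port is taken.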
45
+ for i in range(lo, hi):
46
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
47
+ if s.connect_ex(('localhost', i)) != 0:
48
+ return i
49
+
50
+
51
+ def train_thread(args, tpu, id, q):
52
+ print('starting training on', tpu)
53
+
54
+ # pass binary flags through
55
+ opts = ''
56
+ for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
57
+ if args.__getattribute__(flag):
58
+ opts += ' --' + flag
59
+
60
+ for flag in ['autostack', ]:
61
+ if not args.__getattribute__(flag):
62
+ opts += ' --' + flag
63
+
64
+ cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
65
+ print('Running:', cmd)
66
+ proc = subprocess.Popen(cmd, shell=True)
67
+
68
+ # poll until it's exited
69
+ while proc.poll() is None:
70
+ time.sleep(60)
71
+ try:
72
+ nq, *nargs = q.get_nowait()
73
+ if nq == 'kill':
74
+ print('train thread received kill signal from logging thread')
75
+ # first send SIGTERM
76
+ proc.terminate()
77
+
78
+ time.sleep(60)
79
+
80
+ # if it still hasn't exited, we send SIGKILL
81
+ if proc.poll() is None:
82
+ print('SIGTERM not successful, sending SIGKILL')
83
+ proc.kill()
84
+
85
+ except queue.Empty:
86
+ pass
87
+
88
+ print('exited training!')
89
+ if proc.returncode == 0:
90
+ print('exited gracefully')
91
+ os.kill(os.getpid(), signal.SIGINT)
92
+ return
93
+
94
+ if args.no_delete_tpu:
95
+ print('recreate done, exiting train_thread - not killing tpu!')
96
+ return
97
+ print("Recreating {} in 60sec...".format(tpu))
98
+ time.sleep(60)
99
+ os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
100
+ print('recreate done, exiting train_thread')
101
+
102
+ # clear out queue
103
+ while True:
104
+ try:
105
+ q.get_nowait()
106
+ print('dropped request in queue after pu recreate')
107
+ except queue.Empty:
108
+ break
109
+
110
+
111
+ def get_json(uri, params=None, timeout=15):
112
+ resp = requests.get(uri, params=params, timeout=timeout)
113
+ resp.raise_for_status()
114
+ return resp.json()
115
+
116
+
117
+ def get_tag_sets(base_uri):
118
+ j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
119
+ assert isinstance(j, dict)
120
+ return {
121
+ run: j[run].keys()
122
+ for run in j.keys()
123
+ }
124
+
125
+
126
+ def get_scalar_data(base_uri, run, tag):
127
+ j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
128
+ assert isinstance(j, list)
129
+ return j
130
+
131
+
132
+ def get_run_data(port):
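+ # Scrape train loss, validation loss and lambada metrics from the local TensorBoard HTTP API so they can be relayed to Sacred; on error the traceback is printed and whatever was collected so far is returned.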
133
+ base_uri = f'http://localhost:{port}/'
134
+ r = {}
135
+ try:
136
+ tag_sets = get_tag_sets(base_uri)
137
+ runs = tag_sets.keys()
138
+ if '.' in runs:
139
+ if 'loss' in tag_sets['.']:
140
+ r['loss'] = get_scalar_data(base_uri, '.', 'loss')
141
+ if 'eval' in runs:
142
+ if 'loss' in tag_sets['eval']:
143
+ r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
144
+ if 'eval_lambada' in runs:
145
+ if 'lambada_acc' in tag_sets['eval_lambada']:
146
+ r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
147
+ if 'lambada_log_ppl' in tag_sets['eval_lambada']:
148
+ r['lambada_ppl'] = [
149
+ [t, s, math.exp(lp)]
150
+ for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
151
+ ]
152
+ except:
153
+ import traceback
154
+ traceback.print_exc()
155
+ return r
156
+
157
+
158
+ @ex.main
159
+ def main(_run):
160
+ print('Starting run', _run._id)
161
+ print('experiment main invoked with argv:', " ".join(sys.argv))
162
+ print('WARNING: please remember to remove old metric log files from the model directory.')
163
+
164
+ os.makedirs('run_configs', exist_ok=True)
165
+ shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))
166
+
167
+ tensorboard_port = get_open_port()
168
+ print('Tensorboard at port:', tensorboard_port)
169
+ print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
170
+ os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
171
+ atexit.register(goodbye, _run._id)
172
+
173
+ curr_step = {}
174
+ seen_predictions = set()
175
+
176
+ heartbeat_timeout = args.initial_heartbeat_timeout * 2
177
+ while True:
178
+ last_tb_log_time = time.time()
179
+ start_time = time.time()
180
+ q = queue.Queue()
181
+ trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
182
+ trainthd.start()
183
+
184
+ while trainthd.is_alive():
185
+ time.sleep(60)
186
+
187
+ if start_time + args.initial_heartbeat_timeout < time.time():
188
+ # after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower
189
+ heartbeat_timeout = args.heartbeat_timeout
190
+
191
+ print('Polling tensorboard for metrics...')
192
+ data = get_run_data(tensorboard_port)
193
+ for k in data.keys():
194
+ for ts, step, val in data[k]:
195
+ if step <= curr_step.get(k, -1):
196
+ continue
197
+ _run.log_scalar(k, val, step)
198
+ if k == 'loss':
199
+ _run.log_scalar('tb_ts', ts, step)
200
+ print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))
201
+
202
+ # found something new, so logging!
203
+ last_tb_log_time = time.time()
204
+
205
+ curr_step[k] = step
206
+
207
+ for f in glob.glob('predictions_{}_*'.format(_run._id)):
208
+ if f in seen_predictions:
209
+ continue
210
+ print('collecting prediction file', f)
211
+ ex.add_artifact(f)
212
+
213
+ seen_predictions.add(f)
214
+
215
+ # collect eval metrics from jsonl
216
+ if os.path.exists(f'eval_{_run._id}.jsonl'):
217
+ with open(f'eval_{_run._id}.jsonl') as fh:
218
+ for line in fh:
219
+ ob = json.loads(line)
220
+ val_step = ob['current_step']
221
+ val_task = ob['task']
222
+ for metr in ob.keys():
223
+ k = 'fs.' + val_task + '.' + metr
224
+ if metr in ['task', 'current_step']: continue
225
+ if val_step <= curr_step.get(k, -1): continue
226
+ _run.log_scalar(k, ob[metr], val_step)
227
+ curr_step[k] = val_step
228
+
229
+ if time.time() - last_tb_log_time > heartbeat_timeout:
230
+ # the run hasn't logged in a while, so we restart it
231
+ q.put(('kill',))
232
+
233
+ # give training thread some time to do its thing and recreate tpu
234
+ while trainthd.is_alive():
235
+ print('logging thread waiting for killing stalled run and for tpu recreate to finish')
236
+ time.sleep(60)
237
+
238
+ # reset heartbeat timeout to initial
239
+ heartbeat_timeout = args.initial_heartbeat_timeout
240
+ last_tb_log_time = time.time()
241
+
242
+
243
+ if args.no_delete_tpu:
244
+ break
245
+
246
+
247
+ def goodbye(id):
248
+ print("You are now leaving the Python sector.")
249
+ print("Sie verlassen den pythonischen Sektor.")
250
+
251
+ os.system("screen -S tensorboard_{} -X quit".format(id))
252
+
253
+
254
+ if __name__ == '__main__':
255
+ for file in glob.glob("**/*", recursive=True):
256
+ if file.split('.')[-1] in ['py']:
257
+ print('Adding', file, 'to sacred')
258
+ ex.add_source_file(file)
259
+
260
+ ex.add_config({
261
+ 'tpu_name': args.tpu,
262
+ **params
263
+ })
264
+
265
+ ex.run()
sample.py ADDED
@@ -0,0 +1,218 @@
1
+ import mesh_tensorflow as mtf
2
+ import tensorflow.compat.v1 as tf
3
+ import mesh_tensorflow.transformer as mtf_transformer
4
+
5
+ from models.utils import entmax, sample_categorical
6
+ from models.gpt2 import gpt2
7
+
8
+ def sample_autoregressive(partial_sequences,
9
+ other_features,
10
+ params,
11
+ stop_at_token=50256,
12
+ max_steps=None,
13
+ temperature=0.9,
14
+ variable_dtype=mtf.VariableDType(tf.float32),
15
+ encoder_output=None,
16
+ encoder_sequence_id=None,
17
+ encoder_inputs=None,
18
+ shared_params=None,
19
+ has_partial_sequences=True,
20
+ encoder_layer_outputs=None,
21
+ never_end=False,
22
+ remove_partial_sequences=False,
23
+ sampling_keep_top_k=-1,
24
+ sampling_use_entmax = False,
25
+ bos_id=50256,
26
+ ):
27
+ """Sample randomly one token at a time.
28
+
29
+ The partial_sequences represent partial sequences to be continued. The
30
+ first tokens of each sequence are nonzero representing the given partial
31
+ sequences and the last tokens of each sequence are zeros, representing what
32
+ needs to be filled in.
33
+
34
+ If there are no partial sequences (you want to sample from the beginning),
35
+ then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
36
+ has_partial_sequences=False (so we can skip computation).
37
+
38
+ Args:
39
+ partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
40
+ stop_at_token: an optional integer eos id. Stop when we produce it.
41
+ max_steps: an optional integer, the max number of steps to decode.
42
+ temperature: an optional floating point value between 0.0 and 1.0 0.0
43
+ means argmax, 1.0 means sample according to predicted distribution.
44
+ variable_dtype: a mtf.VariableDType
45
+ encoder_output: an optional Tensor
46
+ encoder_sequence_id: an optional Tensor
47
+ encoder_inputs: an optional Tensor
48
+ shared_params: an optional dictionary
49
+ has_partial_sequences: a boolean
50
+ encoder_layer_outputs: optional - readonly list of tensor activations when
51
+ decoding, one per each input layer + the embedding layer
52
+ never_end: a boolean - if set, then avoid generating stop_at_token
53
+ remove_partial_sequences: a boolean - whether to remove the partial
54
+ sequences from the output
55
+ sampling_keep_top_k: an integer - if not -1, only sample from the top k
56
+ logits.
57
+ bos_id: beginning of sequence id
58
+
59
+ Returns:
60
+ a Tensor with shape [<batch_dims>, length_dim]
61
+ """
62
+
63
+ inputs = partial_sequences # Partial sequences to fill in
64
+ batch_dims = inputs.shape.dims[:-1]
65
+ length_dim = inputs.shape.dims[-1]
66
+ padding_id = params.get("padding_id", 0)
67
+ slow_sampling = params.get("slow_sampling", False)
68
+
69
+
70
+ initial_position = mtf.reduce_sum(
71
+ mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts
72
+
73
+ length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
74
+ input_full_attention = True  # hardcoded to True for now
75
+ if input_full_attention:
76
+ # Vanilla autoregressive model - each position can see previous positions.
77
+ # Think this feeds in to the loop fn and tells each position where it can attend to?
78
+ read_priority = write_priority = length_range * mtf.to_int32(
79
+ mtf.greater(length_range, initial_position))
80
+ else:
81
+ read_priority = write_priority = length_range
82
+
83
+ # Builds context to pass around internally
84
+ # The 'first part' context records initial states of k / v / x
85
+
86
+ if not slow_sampling:
87
+ context_first_part = mtf_transformer.transformer.Context(
88
+ model=None,
89
+ mesh=inputs.mesh,
90
+ batch_dims=batch_dims,
91
+ length_dim=length_dim,
92
+ variable_dtype=variable_dtype,
93
+ mode="first_part",
94
+ position=length_range,
95
+ position_is_default=True,
96
+ new_states=[],
97
+ initial_position=initial_position,
98
+ sequence_id=None,
99
+ encoder_output=encoder_output,
100
+ encoder_sequence_id=encoder_sequence_id,
101
+ constant_states=[],
102
+ shared_params=shared_params,
103
+ encoder_layer_outputs=encoder_layer_outputs,
104
+ write_priority=write_priority,
105
+ read_priority=read_priority,
106
+ inputs=inputs,
107
+ encoder_inputs=encoder_inputs)
108
+
109
+ with tf.variable_scope("gpt2"):
110
+ logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part)
111
+
112
+ if not has_partial_sequences:
113
+ initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
114
+ else:
115
+ initial_states = context_first_part.new_states
116
+ else:
117
+ initial_states = []
118
+
119
+ if not has_partial_sequences:
120
+ partial_sequences_eos_count = 0
121
+
122
+ if stop_at_token is not None:
123
+ partial_sequences_eos_count = mtf.reduce_sum(
124
+ mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
125
+ reduced_dim=length_dim)
126
+
127
+ def cond_fn(position, ids, *unused_states):
128
+ """Should we run another loop iteration?"""
129
+ past_end = mtf.greater_equal(position, length_dim.size)
130
+ if max_steps:
131
+ past_end = mtf.logical_or(
132
+ past_end, mtf.greater_equal(position - initial_position, max_steps))
133
+
134
+ is_done = past_end
135
+ if stop_at_token is not None:
136
+ eos_count = mtf.reduce_sum(
137
+ mtf.to_int32(mtf.equal(ids, stop_at_token)),
138
+ reduced_dim=length_dim)
139
+ has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
140
+ is_done = mtf.logical_or(is_done, has_additional_eos)
141
+ all_done = mtf.reduce_all(is_done)
142
+ return mtf.logical_not(all_done)
143
+
144
+ def body_fn(position, ids, *states):
145
+ """One step in the decode loop."""
146
+ nonlocal sampling_keep_top_k
147
+
148
+ context = mtf_transformer.transformer.Context(
149
+ model=None,
150
+ mesh=inputs.mesh,
151
+ batch_dims=batch_dims,
152
+ length_dim=length_dim,
153
+ variable_dtype=variable_dtype,
154
+ mode="incremental",
155
+ position=position,
156
+ position_is_default=True,
157
+ states=states,
158
+ new_states=[],
159
+ initial_position=position,
160
+ sequence_id=None,
161
+ encoder_output=encoder_output,
162
+ encoder_sequence_id=encoder_sequence_id,
163
+ shared_params=shared_params,
164
+ encoder_layer_outputs=encoder_layer_outputs,
165
+ write_priority=write_priority,
166
+ read_priority=read_priority,
167
+ inputs=ids,
168
+ encoder_inputs=encoder_inputs) if not slow_sampling else None
169
+
170
+ with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
171
+ logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context)
172
+
173
+ if not sampling_use_entmax:
174
+ # If sampling_keep_top_k == -2, sample from the top 10% of the vocabulary
175
+ if sampling_keep_top_k == -2:
176
+ sampling_keep_top_k = int(logits.shape[-1].size * 0.1)
177
+
178
+ if sampling_keep_top_k != -1:
179
+ if sampling_keep_top_k <= 0:
180
+ raise ValueError("sampling_keep_top_k must either be -1 or positive.")
181
+ k_largest = mtf.nth_largest_element(
182
+ logits, n=sampling_keep_top_k,
183
+ reduced_dim=other_features["vocab_dim"])
184
+ logits = mtf.where(mtf.less_equal(logits, k_largest),
185
+ mtf.ones_like(logits) * -1e6, logits)
186
+
187
+ ids_this_step = mtf.sample_with_temperature(
188
+ logits, other_features["vocab_dim"], temperature)
189
+ else:
190
+ ids_this_step = sample_categorical(entmax(logits))
191
+
192
+ if slow_sampling:
193
+ ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)
194
+ else:
195
+ ids_this_step = mtf.reshape(ids_this_step, (batch_dims))
196
+
197
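+ # Scatter the newly sampled token into the ids tensor at the current position using a one-hot mask, then advance the position by one.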
+ one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
198
+ one_new_id = ids_this_step * one_hot
199
+ new_ids = (1 - one_hot) * ids + one_new_id
200
+ new_position = position + 1
201
+
202
+ ret = [new_position, new_ids]
203
+ if context is not None:
204
+ ret += context.new_states
205
+ return ret
206
+
207
+ while_loop_inputs = [initial_position, inputs] + initial_states
208
+ final_position, outputs = mtf.while_loop(
209
+ cond_fn, body_fn, while_loop_inputs)[:2]
210
+ del final_position
211
+ if has_partial_sequences and remove_partial_sequences:
212
+ # Remove partial sequences from outputs
213
+ partial_length = mtf.reduce_sum(
214
+ mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
215
+ reduced_dim=length_dim)
216
+ outputs = mtf.dynamic_shift(
217
+ outputs, -partial_length, length_dim, wrap=False)
218
+ return outputs
tasks.py ADDED
@@ -0,0 +1,116 @@
1
+ import os.path
2
+ import json
3
+ import requests
4
+ import numpy as np
5
+ import ftfy
6
+ from data.encoders import fetch_encoder, encode
7
+ import tensorflow as tf
8
+ import re
9
+ from functools import partial
10
+
11
+ lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
12
+ normalization = 'NFKC'
13
+
14
+
15
+ # Note: this task is called "lambada" but it really refers to OpenAI's version
16
+ # of the task, which actually differs in some ways from the task described in
17
+ # the original paper. So, strictly speaking, accuracy values from this task
18
+ # should not be compared to accuracy values from the original lambada task.
19
+ # For more information, see
20
+ # https://github.com/openai/gpt-2/issues/131
21
+
22
+ def lambada_create_tokens_data(params, path):
23
+ with open(path, 'w') as f:
24
+ req = requests.get(lambada_src_uri)
25
+ req.raise_for_status()
26
+ jsons = [json.loads(l) for l in req.iter_lines()]
27
+ texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
28
+ enc = fetch_encoder(params)
29
+ arrays = [encode(enc, t) for t in texts]
30
+ json.dump(arrays, f)
31
+ return arrays
32
+
33
+
34
+ def lambada_read_or_create_tokens_data(params, path):
35
+ # If the token file does not yet exist at the given path, create it first
36
+ if not os.path.exists(path):
37
+ return lambada_create_tokens_data(params, path)
38
+ with open(path) as f:
39
+ return json.load(f)
40
+
41
+
42
+ def bin_pack(params, tokens_data):
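+ # Greedily pack tokenized examples into rows of n_ctx tokens, appending an eos token after each example, then pad the number of rows to a multiple of eval_batch_size; unused slots are filled with a dummy token.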
43
+ eos_token = params['eos_id']
44
+ n_ctx = params['n_ctx']
45
+ dummy_token = 1
46
+ pad_batch_size = params['eval_batch_size']
47
+ bins = []
48
+ for a in tokens_data:
49
+ if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
50
+ bins.append([])
51
+ bins[-1] += a
52
+ bins[-1].append(eos_token)
53
+ while len(bins) % pad_batch_size != 0:
54
+ bins.append([])
55
+ bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
56
+ for i, b in enumerate(bins):
57
+ bins_array[i, 0:len(b)] = b
58
+ return bins_array
59
+
60
+
61
+ def lambada_init(params):
62
+ ds_configs = params['dataset_configs']
63
+ l = [
64
+ ds_configs[ds_id].get('lambada_tokens_path', "./lambada.json")
65
+ for ds_id, _, _, _ in params['datasets']
66
+ ]
67
+ assert len(l) > 0, 'lambada_tokens_path not found in the dataset config'
68
+ lt_path = l[0]
69
+ assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'
70
+
71
+ tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
72
+ bins_array = bin_pack(params, tokens_data)
73
+ params['lambada_tokens_path'] = lt_path
74
+ params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']
75
+
76
+
77
+ def lambada_get_task_info(params):
78
+ return {
79
+ 'n_steps': params['lambada_n_steps'],
80
+ }
81
+
82
+
83
+ # The LAMBADA evaluation code looks at the logits of each position just before an eos_token
84
+ def lambada_input(params):
85
+ eos_token = 50256 if params['n_vocab'] >= 50257 else 0
86
+ n_ctx = params['n_ctx']
87
+ lt_path = params['lambada_tokens_path']
88
+ tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
89
+ bins_array = bin_pack(params, tokens_data)
90
+ dataset = tf.data.Dataset.from_tensor_slices(bins_array)
91
+
92
+ def _get_output(bin):
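+ # The label at position i is the next token only when the token two places ahead is eos (i.e. the position predicting the final word); everywhere else the label is set to eos_token, which the lambada metric then ignores.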
93
+ bin = tf.cast(bin, dtype=tf.int32)
94
+ indexes = tf.range(n_ctx)
95
+ results = tf.gather(bin, (indexes + 1) % n_ctx)
96
+ eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
97
+ output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
98
+ bin = tf.reshape(bin, [n_ctx])
99
+ bin = tf.cast(bin, dtype=tf.int32)
100
+ output = tf.reshape(output, [n_ctx])
101
+ output = tf.cast(output, dtype=tf.int32)
102
+ return bin, output
103
+
104
+ dataset = dataset.map(_get_output)
105
+ dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
106
+ dataset = dataset.repeat()
107
+ return dataset
108
+
109
+
110
+ task_descriptors = {
111
+ 'lambada': {
112
+ 'init_fn': lambada_init,
113
+ 'get_task_info_fn': lambada_get_task_info,
114
+ 'input_fn': lambada_input,
115
+ }
116
+ }
test_models.py ADDED
@@ -0,0 +1,180 @@
1
+ import pytest
2
+ import traceback
3
+ import logging
4
+ from collections import defaultdict
5
+ from contextlib import contextmanager
6
+
7
+ import tensorflow as tf
8
+ tf.compat.v1.enable_eager_execution()
9
+ import mesh_tensorflow as mtf
10
+ from mesh_tensorflow import placement_mesh_impl
11
+
12
+ from inputs import mlm_sample_text
13
+ from models.gpt2 import gpt2
14
+ from models.utils import biasmask_attn_weights, entmax, sample_categorical
15
+
16
+ from sample import sample_autoregressive
17
+
18
+ # helper functions
19
+
20
+ @contextmanager
21
+ def not_raises(exception):
22
+ try:
23
+ yield
24
+ except exception:
25
+ logging.error(traceback.format_exc())
26
+ raise pytest.fail("DID RAISE {0}".format(exception))
27
+
28
+ # fixtures
29
+
30
+ params = defaultdict(lambda: None, {
31
+ "n_head": 1,
32
+ "n_ctx": 4,
33
+ "n_embd": 2,
34
+ "n_vocab": 256,
35
+ "embed_dropout": 0.,
36
+ "n_layer": 2,
37
+ "num_microbatches": 1,
38
+ "train_batch_size": 1,
39
+ "causal": True,
40
+ "attention_types": ['global', 'local'],
41
+ "res_dropout": 0.1,
42
+ "rotary_emb": True,
43
+ "activation_function": "gelu",
44
+ "moe_layers": (1,),
45
+ "num_mem_kv": 16,
46
+ "no_weight_tie": True,
47
+ "moe_params": {
48
+ 'moe_dropout_rate': 0.0
49
+ },
50
+ "mesh_shape": [],
51
+ "layout": {},
52
+ "local_attention_radius": 128,
53
+ "share_parameters": True,
54
+ "rezero": True
55
+ })
56
+
57
+ # tests
58
+
59
+ def test_model():
60
+ graph = mtf.Graph()
61
+ mesh = mtf.Mesh(graph, "my_mesh")
62
+
63
+ seq_len = params["n_ctx"]
64
+
65
+ batch_dim = mtf.Dimension("batch", 1)
66
+ sequence_dim = mtf.Dimension("sequence", seq_len)
67
+
68
+ features = {
69
+ 'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
70
+ 'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
71
+ }
72
+
73
+ # create mask
74
+
75
+ num_mem_kv = params.get('num_mem_kv', 0)
76
+ length_dim = mtf.Dimension('sequence', seq_len)
77
+ memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
78
+ embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
79
+ embd_dim = mtf.Dimension("embd", params["n_embd"])
80
+ vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
81
+
82
+ other_features = {}
83
+ variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
84
+
85
+ other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
86
+ other_features["embd_dim"] = embd_dim
87
+ other_features["vocab_dim"] = vocab_dim
88
+ other_features["embed_sequence_dim"] = embed_sequence_dim
89
+ other_features["memory_length_dim"] = memory_length_dim
90
+
91
+ with not_raises(Exception):
92
+ logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)
93
+
94
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
95
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
96
+ logits = lowering.export_to_tf_tensor(logits)
97
+
98
+
99
+ def test_sampling():
100
+ graph = mtf.Graph()
101
+ mesh = mtf.Mesh(graph, "my_mesh")
102
+
103
+ batch_dim = mtf.Dimension("batch", 1)
104
+ sequence_dim = mtf.Dimension("sequence", 1)
105
+
106
+ inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
107
+ inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)
108
+
109
+ # create mask
110
+
111
+ seq_len = params["n_ctx"]
112
+ num_mem_kv = params.get('num_mem_kv', 0)
113
+ length_dim = mtf.Dimension('sequence', seq_len)
114
+ memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
115
+ embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
116
+ embd_dim = mtf.Dimension("embd", params["n_embd"])
117
+ vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
118
+
119
+ other_features = {}
120
+
121
+ other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
122
+ other_features["embd_dim"] = embd_dim
123
+ other_features["vocab_dim"] = vocab_dim
124
+ other_features["embed_sequence_dim"] = embed_sequence_dim
125
+ other_features["memory_length_dim"] = memory_length_dim
126
+
127
+ params["mode"] = "predict"
128
+
129
+ with not_raises(Exception):
130
+ samples = sample_autoregressive(
131
+ inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
132
+ remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=True)
133
+
134
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
135
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
136
+ samples = lowering.export_to_tf_tensor(samples)
137
+
138
+ # mlm
139
+
140
+ mlm_params = defaultdict(lambda: None, {
141
+ "n_head": 1,
142
+ "n_ctx": 4,
143
+ "n_embd": 1,
144
+ "n_vocab": 256,
145
+ "embed_dropout": 0.,
146
+ "n_layer": 2,
147
+ "num_microbatches": 1,
148
+ "train_batch_size": 1,
149
+ "attention_types": ['global', 'local'],
150
+ "res_dropout": 0.1,
151
+ "mesh_shape": [],
152
+ "layout": {},
153
+ "share_parameters": True,
154
+ "mlm_training": True,
155
+ "mlm_mask_id": 3,
156
+ "mlm_cls_token_id": 4,
157
+ "mlm_random_token_prob": 0.1
158
+ })
159
+
160
+ def test_mlm_sample_text():
161
+ document = tf.random.normal((16,))
162
+ with not_raises(Exception):
163
+ features, labels = mlm_sample_text(mlm_params, document, random_documents = True)
164
+ assert features.shape == (mlm_params['n_ctx'],)
165
+
166
+ # entmax
167
+
168
+ def test_entmax():
169
+ graph = mtf.Graph()
170
+ mesh = mtf.Mesh(graph, "my_mesh")
171
+ length = mtf.Dimension("tensor_length", 8)
172
+ tensor = mtf.range(mesh, length, tf.float32)
173
+ output = entmax(tensor)
174
+ grad = mtf.gradients([output], [tensor])[0]
175
+ sample = sample_categorical(output, length)
176
+
177
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
178
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
179
+ sample = lowering.export_to_tf_tensor(sample)
180
+ grad = lowering.export_to_tf_tensor(grad)
utils.py ADDED
@@ -0,0 +1,292 @@
1
+ import re
2
+ from urllib.parse import urlparse
3
+ from shutil import rmtree
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import tensorflow.compat.v1 as tf
9
+ import tensorflow.compat.v2 as tf2
10
+ import mesh_tensorflow as mtf
11
+ import mesh_tensorflow.auto_mtf
12
+ from data.encoders import fetch_encoder
13
+ import re
14
+
15
+ def setup_logging(args):
16
+ Path("logs").mkdir(exist_ok=True)
17
+ tf.logging.set_verbosity(logging.INFO)
18
+ tf.get_logger().propagate = False # Remove double log on console
19
+ name = os.path.splitext(os.path.basename(args.model))[0]
20
+ handlers = [
21
+ logging.FileHandler(f"logs/{name}.log"),
22
+ logging.StreamHandler(sys.stdout)
23
+ ]
24
+ logger = logging.getLogger("tensorflow")
25
+ logger.handlers = handlers
26
+ return logger
27
+
28
+
29
+ def get_batch_size(params):
30
+ return params[f"{params['mode']}_batch_size"]
31
+
32
+
33
+ def add_mode_to_params(params, mode):
34
+ if mode == tf.estimator.ModeKeys.PREDICT:
35
+ params["mode"] = "predict"
36
+ elif mode == tf.estimator.ModeKeys.EVAL:
37
+ params["mode"] = "eval"
38
+ elif mode == tf.estimator.ModeKeys.TRAIN:
39
+ params["mode"] = "train"
40
+ else:
41
+ raise ValueError(f"Invalid mode {mode}")
42
+ return params
43
+
44
+
45
+ def simd_mesh_setup(params, mesh_shape, layout_rules):
46
+ """Constructs SimdMesh function - instructions on how to evenly split tensors across all TPU cores"""
47
+
48
+ num_hosts = params["context"].num_hosts
49
+ host_placement_fn = params["context"].tpu_host_placement_function
50
+ device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]
51
+ tf.logging.info(f"device_list = {device_list}")
52
+
53
+ # TODO: Better estimation of replica cache size?
54
+ replica_cache_size = 300 * 1000000 # 300M per replica
55
+
56
+ # Worker 0 caches all the TPU binaries
57
+ worker0_mem = replica_cache_size * params["context"].num_replicas
58
+ devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
59
+ var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)
60
+ mesh_devices = [""] * mesh_shape.size
61
+ mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
62
+ mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment)
63
+
64
+ return var_placer, mesh_impl
65
+
66
+
67
+ def remove_batch_from_layout(layout):
68
+ """
69
+ The tf-mesh layout splits across batch size, remove it.
70
+ Useful for prediction steps, when you no longer want large batches.
71
+
72
+ :param layout: string describing tf-mesh layout
73
+ :return: layout minus batch dimension
74
+ """
75
+ layout = layout.split(',')
76
+ ret_layout = ""
77
+ for i in layout:
78
+ if "batch" in i:
79
+ pass
80
+ else:
81
+ ret_layout += f"{i},"
82
+ return ret_layout[:-1]
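For concreteness, a minimal standalone sketch of the same filtering (the layout string below is hypothetical): every comma-separated rule that mentions "batch" is dropped and the remaining rules are re-joined.

    # hypothetical tf-mesh layout string: dimension_name:mesh_axis pairs
    layout = "batch:x,embd:y,heads:y"
    kept = ",".join(rule for rule in layout.split(",") if "batch" not in rule)
    print(kept)  # -> embd:y,heads:y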
83
+
84
+
85
+ def yes_or_no(question):
86
+ while True:
87
+ reply = str(input(question+' (y/n): ')).lower().strip()
88
+ if reply[:1] == 'y':
89
+ return True
90
+ if reply[:1] == 'n':
91
+ return False
92
+
93
+
94
+ def remove_gs_or_filepath(path):
95
+ parsed_url = urlparse(path)
96
+ if parsed_url.scheme == "gs":
97
+ os.system(f"gsutil rm -rf {path}")
98
+ return
99
+ rmtree(path)
100
+
101
+
102
+ def save_config(params_dict, logdir):
103
+ print(f"Saving config to {logdir}")
104
+ text = "{\n\n"
105
+ total_params = len(params_dict)
106
+ for count, key in enumerate(params_dict):
107
+ config_value = str(params_dict[key])
108
+ if re.search('[a-zA-Z]', config_value):
109
+ if config_value.lower() != 'true':
110
+ if config_value.lower() != 'false':
111
+ if config_value[0] != '[':
112
+ # TODO: Making a manual exception for parsing epsilon right now since it's the only number in
113
+ # scientific notation. Should fix this.
114
+ if key != "epsilon":
115
+ config_value = f'"{config_value}"'
116
+ if count == total_params - 1:
117
+ text += f'"{str(key)}"' + ' : ' + config_value + '\n\n'
118
+ else:
119
+ text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n'
120
+ text += '\n\n}'
121
+ sess = tf.InteractiveSession()
122
+ summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
123
+ summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
124
+ text = sess.run(summary_op)
125
+ summary_writer.add_summary(text, 0)
126
+ summary_writer.flush()
127
+ summary_writer.close()
128
+ tf.reset_default_graph()
129
+ print('Done!')
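The TODO above exists because values are quoted by hand. As a hedged alternative sketch (not what this repo does), serialising the dict with the standard json module sidesteps the special case for scientific-notation numbers such as epsilon; the config values below are hypothetical.

    import json
    example_config = {"epsilon": 1e-8, "lr": 0.0006, "opt_name": "adam", "train_steps": 400000}
    print(json.dumps(example_config, indent=2))  # 1e-08 is emitted without manual quoting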
130
+
131
+
132
+ def expand_attention_types_params(params_list):
133
+ newlist = []
134
+ for item in params_list:
135
+ for _ in range(item[1]):
136
+ newlist.extend(item[0])
137
+ return newlist
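A worked sketch of the expansion (the pattern below is hypothetical): each entry is a [block_pattern, repeat_count] pair, and the pattern is repeated to yield one attention type per layer.

    pattern = [[["global", "local"], 2], [["local"], 1]]
    expanded = []
    for block, repeats in pattern:
        for _ in range(repeats):
            expanded.extend(block)
    print(expanded)  # -> ['global', 'local', 'global', 'local', 'local']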
138
+
139
+
140
+ def get_n_trainable_vars(graph):
141
+ """
142
+ Gets number of trainable vars in an MTF model.
143
+
144
+ :param graph: Mesh-Tensorflow graph
145
+ :return: None
146
+ """
147
+ total_parameters = 0
148
+ for variable in graph.trainable_variables:
149
+ shape = variable.shape.dims
150
+ variable_parameters = 1
151
+ for dim in shape:
152
+ variable_parameters *= dim.size
153
+ total_parameters += variable_parameters
154
+ print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")
155
+
156
+
157
+ def print_dim_names(graph):
158
+ """
159
+ Print names of all Dimensions
160
+ :param graph: Mesh-Tensorflow graph
161
+ :return: None
162
+ """
163
+ all_dim_names = []
164
+ for variable in graph.all_variables:
165
+ names = variable.shape.dimension_names
166
+ all_dim_names.append(names)
167
+
168
+ # Print all dim names in graph & write to file
169
+ all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims
170
+ unique_dims = list(set(all_dim_names))
171
+ print("ALL DIM NAMES:")
172
+ for dim_name in unique_dims:
173
+ print(dim_name)
174
+ print('\n')
175
+
176
+
177
+ def get_graph_info(graph):
178
+ """
179
+ Wrapper fn that calculates number of trainable vars in an MTF graph & prints all dim_names to file
180
+ TODO: how to get un-trainable dim-names too, batch etc.
181
+
182
+ :param graph: Mesh-Tensorflow graph
183
+ :return: None
184
+ """
185
+ get_n_trainable_vars(graph)
186
+ print_dim_names(graph)
187
+
188
+
189
+ def loss_denominator(targets, num_microbatches):
190
+ """Denominator applied to losses.
191
+
192
+ This is usually the size of the targets tensor (omitting ensemble
193
+ dimensions). Alternatively, it is an override value passed to the
194
+ class constructor.
195
+
196
+ Args:
197
+ targets: a mtf.Tensor
198
+ num_microbatches: an integer - greater than one if the step has been
199
+ serialized into multiple microbatches to save memory.
200
+ Returns:
201
+ a float
202
+ """
203
+ ret = float(targets.shape.size) * num_microbatches
204
+ return float(ret)
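A worked example of the formula, assuming a hypothetical targets tensor with batch=8 and sequence=2048 that has been serialized into 4 microbatches:

    targets_size = 8 * 2048            # targets.shape.size
    num_microbatches = 4
    print(float(targets_size) * num_microbatches)  # -> 65536.0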
205
+
206
+ def check_dataset(input_fn, params, global_step=None):
207
+ tf.enable_eager_execution()
208
+ if global_step is not None:
209
+ dataset = input_fn(params, global_step=global_step)
210
+ else:
211
+ dataset = input_fn(params)
212
+ dataset_iter = dataset.make_one_shot_iterator()
213
+ tensor, _ = next(dataset_iter)
214
+ enc = fetch_encoder(params)
215
+
216
+ for p in tensor[:1]:
217
+ txt = enc.decode(p)
218
+
219
+ print('-' * 50)
220
+ print(txt[:500], '\n\n...\n\n', txt[-500:])
221
+ print('-' * 50)
222
+ exit()
223
+
224
+ def auto_layout(graph, mesh_shape, logits, loss):
225
+ layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
226
+ print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
227
+ quit()
228
+
229
+ def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
230
+ layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,
231
+ [logits, loss], max_mesh_shape_dimensions=4)
232
+ print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \
233
+ f"\nRe-initialize graph with selected layout & mesh shape")
234
+ quit()
235
+
236
+ def create_host_call(model_dir):
237
+ """Construct a host_call writing scalar summaries.
238
+
239
+ Borrowed from t2t.
240
+
241
+ Args:
242
+ model_dir: String containing path to train
243
+ Returns:
244
+ (fn, args) Pair to be called by TPUEstimator as the host_call.
245
+ """
246
+
247
+ graph = tf.get_default_graph()
248
+ # A list of (name, lowered tensor) tuples
249
+ summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)
250
+
251
+ def maybe_cast(tensor):
252
+ assert tensor.shape.is_compatible_with([]), tensor.name
253
+ if tensor.dtype == tf.int64:
254
+ return tf.to_int32(tensor)
255
+ if tensor.dtype == tf.bfloat16:
256
+ return tf.cast(tensor, tf.float32)
257
+ return tensor
258
+
259
+ reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]
260
+
261
+ # When no supported summaries are found, don't create host_call. Otherwise,
262
+ # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
263
+ # it, eventually causing hang.
264
+ if not reshaped_tensors:
265
+ return None
266
+
267
+ def host_call_fn(global_step, *args):
268
+ """Training host call. Creates scalar summaries for training metrics."""
269
+ # This function is executed on the CPU and should not directly reference
270
+ # any Tensors in the rest of the `model_fn`. To pass Tensors from the
271
+ # model to the `model_fn`, provide as part of the `host_call`.
272
+ global_step = tf.cast(global_step[0], tf.int64)
273
+ with tf2.summary.create_file_writer(model_dir).as_default():
274
+ # We cannot directly use any tensor from summaries, because each
275
+ # tensor here must be a concat of multiple tensors from all shards.
276
+ # Therefore, we rely on the assumption that args will have the same
277
+ # length as summaries, and all tensors in args will have the same
278
+ # order as the summaries collected above.
279
+ assert len(args) == len(summaries)
280
+ for i, tensor in enumerate(args):
281
+ name = summaries[i][0]
282
+ tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
283
+ return tf.summary.all_v2_summary_ops()
284
+
285
+ global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
286
+ return host_call_fn, [global_step_t] + reshaped_tensors
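A hedged usage sketch (the model_dir below is hypothetical; the actual call site is expected to live in model_fns.py): the returned (fn, args) pair is intended for the host_call argument of a TPUEstimatorSpec, and when no scalar summaries have been recorded the helper returns None, in which case host_call should simply be omitted.

    host_call = create_host_call(model_dir="logs/example_run")  # hypothetical directory
    if host_call is not None:
        host_call_fn, host_call_args = host_call
        # e.g. tf.estimator.tpu.TPUEstimatorSpec(..., host_call=(host_call_fn, host_call_args))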
287
+
288
+
289
+ def natural_sort(l):
290
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
291
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
292
+ return sorted(l, key = alphanum_key)
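A small usage sketch with hypothetical checkpoint names: plain sorted() would place "model-10" before "model-2", whereas the numeric-aware key keeps the files in numeric order.

    print(natural_sort(["model-10", "model-2", "model-1"]))
    # -> ['model-1', 'model-2', 'model-10']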