Upload 21 files
Browse files- .gitignore +116 -0
- CITATION.cff +30 -0
- CODEOWNERS +1 -0
- Dockerfile +15 -0
- GPTNeo_example_notebook.ipynb +0 -0
- LICENSE +21 -0
- README.md +383 -0
- configs.py +47 -0
- docker-compose.yml +67 -0
- encoders.py +28 -0
- export.py +14 -0
- inputs.py +384 -0
- main.py +257 -0
- model_fns.py +305 -0
- optimizers.py +176 -0
- requirements.txt +18 -0
- run_experiment.py +265 -0
- sample.py +218 -0
- tasks.py +116 -0
- test_models.py +180 -0
- utils.py +292 -0
.gitignore
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# testing
|
2 |
+
.test/
|
3 |
+
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
build/
|
15 |
+
develop-eggs/
|
16 |
+
dist/
|
17 |
+
downloads/
|
18 |
+
eggs/
|
19 |
+
.eggs/
|
20 |
+
lib/
|
21 |
+
lib64/
|
22 |
+
parts/
|
23 |
+
sdist/
|
24 |
+
var/
|
25 |
+
wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
|
53 |
+
# Translations
|
54 |
+
*.mo
|
55 |
+
*.pot
|
56 |
+
|
57 |
+
# Django stuff:
|
58 |
+
*.log
|
59 |
+
local_settings.py
|
60 |
+
db.sqlite3
|
61 |
+
|
62 |
+
# Flask stuff:
|
63 |
+
instance/
|
64 |
+
.webassets-cache
|
65 |
+
|
66 |
+
# Scrapy stuff:
|
67 |
+
.scrapy
|
68 |
+
|
69 |
+
# Sphinx documentation
|
70 |
+
docs/_build/
|
71 |
+
|
72 |
+
# PyBuilder
|
73 |
+
target/
|
74 |
+
|
75 |
+
# Jupyter Notebook
|
76 |
+
.ipynb_checkpoints
|
77 |
+
|
78 |
+
# pyenv
|
79 |
+
.python-version
|
80 |
+
|
81 |
+
# celery beat schedule file
|
82 |
+
celerybeat-schedule
|
83 |
+
|
84 |
+
# SageMath parsed files
|
85 |
+
*.sage.py
|
86 |
+
|
87 |
+
# Environments
|
88 |
+
.env
|
89 |
+
.venv
|
90 |
+
env/
|
91 |
+
venv/
|
92 |
+
ENV/
|
93 |
+
env.bak/
|
94 |
+
venv.bak/
|
95 |
+
|
96 |
+
# Spyder project settings
|
97 |
+
.spyderproject
|
98 |
+
.spyproject
|
99 |
+
|
100 |
+
# Rope project settings
|
101 |
+
.ropeproject
|
102 |
+
|
103 |
+
# mkdocs documentation
|
104 |
+
/site
|
105 |
+
|
106 |
+
# mypy
|
107 |
+
.mypy_cache/
|
108 |
+
|
109 |
+
logs/
|
110 |
+
*.log
|
111 |
+
test_*
|
112 |
+
test/
|
113 |
+
.vscode
|
114 |
+
|
115 |
+
|
116 |
+
run_configs/
|
CITATION.cff
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YAML 1.2
|
2 |
+
---
|
3 |
+
authors:
|
4 |
+
- affiliation: EleutherAI
|
5 |
+
family-names: Black
|
6 |
+
given-names: Sid
|
7 |
+
- affiliation: EleutherAI
|
8 |
+
family-names: Leo
|
9 |
+
given-names: Gao
|
10 |
+
- affiliation: EleutherAI
|
11 |
+
family-names: Wang
|
12 |
+
given-names: Phil
|
13 |
+
- affiliation: EleutherAI
|
14 |
+
family-names: Leahy
|
15 |
+
given-names: Connor
|
16 |
+
- affiliation: EleutherAI
|
17 |
+
family-names: Biderman
|
18 |
+
given-names: Stella
|
19 |
+
cff-version: "1.1.0"
|
20 |
+
keywords:
|
21 |
+
- Transformers
|
22 |
+
- "Massive language model"
|
23 |
+
- "Autoregressive language model"
|
24 |
+
license: "Apache-2.0"
|
25 |
+
message: "If you use this software, please cite it using these metadata."
|
26 |
+
repository-code: "https://www.github.com/eleutherai/gpt-neo"
|
27 |
+
title: "GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow"
|
28 |
+
version: "1.0"
|
29 |
+
date-released: 2021-03-21
|
30 |
+
...
|
CODEOWNERS
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
* EleutherAI/pm-gptneo
|
Dockerfile
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15
|
2 |
+
|
3 |
+
WORKDIR /neogpt
|
4 |
+
|
5 |
+
# Make RUN commands use `bash --login`:
|
6 |
+
SHELL ["/bin/bash", "--login", "-c"]
|
7 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
8 |
+
RUN apt-get update -y && apt-get install tmux -y
|
9 |
+
RUN conda install gcc_linux-64 gxx_linux-64 -y
|
10 |
+
ADD requirements.txt .
|
11 |
+
RUN pip install -r requirements.txt
|
12 |
+
RUN apt-get install screen htop -y
|
13 |
+
RUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15
|
14 |
+
|
15 |
+
CMD tmux
|
GPTNeo_example_notebook.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2020 EleutherAI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GPT Neo
|
2 |
+
|
3 |
+
🎉 1T or bust my dudes 🎉
|
4 |
+
|
5 |
+
An implementation of model & data parallel [GPT3](https://arxiv.org/abs/2005.14165)-like models using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library.
|
6 |
+
|
7 |
+
**If you're just here to play with our pre-trained models, we strongly recommend you try out the [HuggingFace Transformer integration](https://huggingface.co/EleutherAI).**
|
8 |
+
|
9 |
+
Training and inference is officially supported on TPU and should work on GPU as well. This repository will be (mostly) archived as we move focus to our GPU-specific repo, [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).
|
10 |
+
|
11 |
+
In addition to the functionality offered by GPT-3, we also offer the following:
|
12 |
+
* [Local attention](https://arxiv.org/abs/2004.05150)
|
13 |
+
* [Linear attention](https://arxiv.org/abs/1812.01243)
|
14 |
+
* [Mixture of Experts](https://arxiv.org/abs/1701.06538)
|
15 |
+
* [Axial Positional embedding](https://arxiv.org/abs/1912.12180)
|
16 |
+
|
17 |
+
NB, while neo can *technically* run a training step at 200B+ parameters, it is very inefficient at those scales. This, as well as the fact that many GPUs became available to us, among other things, prompted us to move development over to [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/).
|
18 |
+
|
19 |
+
# Pretrained Models
|
20 |
+
|
21 |
+
**Update 21/03/2021:**
|
22 |
+
|
23 |
+
We're proud to release two pretrained GPT-Neo models trained on The Pile, the weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/public/AI/gptneo-release/).
|
24 |
+
|
25 |
+
1.3B: https://the-eye.eu/public/AI/gptneo-release/GPT3_XL/
|
26 |
+
|
27 |
+
2.7B: https://the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/
|
28 |
+
|
29 |
+
For more information on how to get these set up, see the colab notebook, or read through the rest of the readme.
|
30 |
+
|
31 |
+
## Model Evaluations
|
32 |
+
|
33 |
+
#### Linguistic Reasoning
|
34 |
+
|
35 |
+
| Model and Size | Pile BPB | Pile PPL | Wikitext PPL | Lambada PPL | Lambada Acc | Winogrande | Hellaswag |
|
36 |
+
| ---------------- | ---------- | ---------- | ------------- | ----------- | ----------- | ---------- | ----------- |
|
37 |
+
| **GPT-Neo 1.3B** | **0.7527** | **6.159** | **13.10** | **7.498** | **57.23%** | **55.01%** | **38.66%** |
|
38 |
+
| GPT-2 1.5B | 1.0468 | ----- | 17.48 | 10.634 | 51.21% | 59.40% | 40.03% |
|
39 |
+
| **GPT-Neo 2.7B** | **0.7165** | **5.646** | **11.39** | **5.626** | **62.22%** | **56.50%** | **42.73%** |
|
40 |
+
| GPT-3 Ada | 0.9631 | ----- | ----- | 9.954 | 51.60% | 52.90% | 35.93% |
|
41 |
+
|
42 |
+
#### Physical and Scientific Reasoning
|
43 |
+
|
44 |
+
| Model and Size | MathQA | PubMedQA | Piqa |
|
45 |
+
| ---------------- | ---------- | ---------- | ----------- |
|
46 |
+
| **GPT-Neo 1.3B** | **24.05%** | **54.40%** | **71.11%** |
|
47 |
+
| GPT-2 1.5B | 23.64% | 58.33% | 70.78% |
|
48 |
+
| **GPT-Neo 2.7B** | **24.72%** | **57.54%** | **72.14%** |
|
49 |
+
| GPT-3 Ada | 24.29% | 52.80% | 68.88% |
|
50 |
+
|
51 |
+
**Note:** All evaluations were done using our [evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). Some results for GPT-2 and GPT-3 are inconsistent with the values reported in the respective papers. We are currently looking into why, and would greatly appreciate feedback and further testing of our eval harness.
|
52 |
+
|
53 |
+
# Setup
|
54 |
+
|
55 |
+
```bash
|
56 |
+
git clone https://github.com/EleutherAI/GPTNeo
|
57 |
+
cd GPTNeo
|
58 |
+
pip3 install -r requirements.txt
|
59 |
+
```
|
60 |
+
# Training Setup
|
61 |
+
|
62 |
+
## TPUs:
|
63 |
+
|
64 |
+
Sign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage).
|
65 |
+
|
66 |
+
Create your VM through a google shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs and install the requirements with pip (see above).
|
67 |
+
|
68 |
+
Google colab provides tpu-v8s for free, which should be enough to finetune our models up to GPT3XL (1.5B parameter) sizes.
|
69 |
+
Click [](https://colab.research.google.com/github/EleutherAI/GPTNeo/blob/master/GPTNeo_example_notebook.ipynb) to run through our example colab notebook.
|
70 |
+
|
71 |
+
For more detailed instructions, run through our [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below.
|
72 |
+
|
73 |
+
## GPUs:
|
74 |
+
|
75 |
+
You can also choose to train GPTNeo locally on your GPUs. To do so, you can omit the Google cloud setup steps above, and git clone the repo locally. Run through the [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below, then when running main.py, you simply have to omit the `tpu` flag, and pass in GPU ids instead.
|
76 |
+
|
77 |
+
Note: Some users have reported having difficulty getting MTF to recognize their GPUs. See [here](https://github.com/EleutherAI/gpt-neo/issues/150) for details and instructions on how to fix it.
|
78 |
+
|
79 |
+
# Generating Text
|
80 |
+
|
81 |
+
Once you have a trained model, or you've downloaded one of our pre-trained models, generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so:
|
82 |
+
|
83 |
+
```bash
|
84 |
+
python3 main.py --predict --prompt <example_prompt.txt> --tpu <tpu_name> --model <config_name>
|
85 |
+
```
|
86 |
+
|
87 |
+
or, if using GPUs:
|
88 |
+
|
89 |
+
```bash
|
90 |
+
python3 main.py --predict --prompt <example_prompt.txt> --gpu_ids <device:GPU:0 device:GPU:1> --model <config_name>
|
91 |
+
```
|
92 |
+
|
93 |
+
# Training Guide
|
94 |
+
|
95 |
+
## 1. Create your Tokenizer (OPTIONAL)
|
96 |
+
|
97 |
+
We recommend you use [Huggingface's pretrained GPT2 tokenizer](https://huggingface.co/transformers/model_doc/gpt2.html#transformers.GPT2Tokenizer) with our repo (instructions provided below), but if you want to train a model with a different vocabulary size, we provide facilities to train your own tokenizer like so:
|
98 |
+
|
99 |
+
```bash
|
100 |
+
python data/train_tokenizer.py \
|
101 |
+
--base_dir ./path/to/your/txt/files \
|
102 |
+
--output_dir ./output/path \
|
103 |
+
--file_type txt \
|
104 |
+
--vocab_size 50257
|
105 |
+
|
106 |
+
# if it succeeded, you should see the message
|
107 |
+
# 'tokenizer saved at ./output/path/byte-level-bpe.tokenizer.json'
|
108 |
+
```
|
109 |
+
|
110 |
+
## 2. Tokenizing your Dataset
|
111 |
+
|
112 |
+
If you just want to test training, you can skip this step and download some dummy data like so:
|
113 |
+
|
114 |
+
```
|
115 |
+
wget https://storage.googleapis.com/connors-datasets/bundestag/bundestag_0.tfrecords
|
116 |
+
```
|
117 |
+
|
118 |
+
Then copy the data to your bucket, or if using GPUs, a local directory:
|
119 |
+
|
120 |
+
```
|
121 |
+
gsutil cp bundestag_0.tfrecords gs://<your bucket>/
|
122 |
+
```
|
123 |
+
|
124 |
+
If using your own data to train, you can use the `data/create_tfrecords.py` script to encode your text data into tfrecords.
|
125 |
+
|
126 |
+
Your data must either be in the form of lots of normal .txt files (one document per file), or in any format supported by [lm_dataformat](https://github.com/leogao2/lm_dataformat).
|
127 |
+
|
128 |
+
You can run the script without parameters to see help for all options.
|
129 |
+
|
130 |
+
In **document mode** Each example in the tfrecords is one (variably sized) document. This is to be used with the `documents_fixed` and `documents_random` sampling modes (For more details see the parameters reference section).
|
131 |
+
Document mode is the default mode.
|
132 |
+
|
133 |
+
The below command will tokenize all files in acceptable formats in *base_dir* using gpt2 tokenizer and save them to *output_dir*
|
134 |
+
```
|
135 |
+
python3 create_tfrecords.py --mode documents --input_dir <base> --name <name> --output_dir <output> --use_gpt2_tokenizer --minimum_size <min>
|
136 |
+
```
|
137 |
+
|
138 |
+
- `input_dir`: Defines the folder where your data is located. The script will encode all files present in this folder.
|
139 |
+
- `name`: Name of output files will be `name_i.tfrecords` where i is the number of the file.
|
140 |
+
- `output_dir`: Where to save the tfrecords to
|
141 |
+
- `use_gpt2_tokenizer`: Whether to use the pretrained HuggingFace GPT2 tokenizer, in which case the separator will be set to [50256].
|
142 |
+
- `encoder_path`: if not using the pretrained gpt2 tokenizer, use this flag to provide a path to your generated tokenizer json.
|
143 |
+
- `separator`: Written in list format, the separator token(s) to insert between documents (e.g. "[0]"). Will depend on your encoder.
|
144 |
+
- `minimum_size`: The minimum size (in tokens) a document must have, otherwise it is discarded. This is what will later determine your `stitch` parameter: `stitch * minimum_size` must always be greater or equal `n_ctx` (For more details see the parameters reference section).
|
145 |
+
|
146 |
+
## 4. Using a Dataset in a Model
|
147 |
+
|
148 |
+
To use a dataset in a model, you must first register that dataset under `./configs/dataset_configs` folder. First choose a filename with a `.json` extension. That filename will serve as the dataset identification. The config should be filled out the following manner.
|
149 |
+
|
150 |
+
If you have a dataset encoded using the pretrained gpt2 tokenizer, you can specify that like so:
|
151 |
+
|
152 |
+
```json
|
153 |
+
{
|
154 |
+
"n_vocab": 50257,
|
155 |
+
"path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
|
156 |
+
"eval_path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords",
|
157 |
+
"tokenizer_is_pretrained": true,
|
158 |
+
"tokenizer_path": "gpt2"
|
159 |
+
}
|
160 |
+
```
|
161 |
+
|
162 |
+
or if you've trained a custom tokenizer, like so:
|
163 |
+
|
164 |
+
```json
|
165 |
+
{
|
166 |
+
"n_vocab": 32768,
|
167 |
+
"path": "./path/to/your/*.tfrecords",
|
168 |
+
"eval_path": "./path/to/your/eval/*.tfrecords",
|
169 |
+
"tokenizer_path": "./path/to/your/byte-level-bpe.tokenizer.json"
|
170 |
+
}
|
171 |
+
```
|
172 |
+
|
173 |
+
Finally, in your model config, add the filename that you created above to the `datasets` array.
|
174 |
+
|
175 |
+
The `<dataset id>` will be the filename, excluding the `.json`, that you created above
|
176 |
+
|
177 |
+
```
|
178 |
+
"datasets": [[<dataset id>, <stitch>, <datatype>, <weight>]] # datasets key defines at run time how each dataset is processed for training
|
179 |
+
```
|
180 |
+
|
181 |
+
## 5. Choose a model configuration
|
182 |
+
|
183 |
+
Once you have your datasets set up, find a suitable config in `/configs`.
|
184 |
+
|
185 |
+
Here we use a GPT3-XL sized model as an example, but there are many more in `./configs`, all of which have short summaries in the Available Configs section.
|
186 |
+
|
187 |
+
All you need to do is edit the dataset id as described above, and edit `model_path` (where logs and checkpoints will be saved) to point to a cloud bucket you have write access to (or local path, if using GPUs).
|
188 |
+
|
189 |
+
```json
|
190 |
+
{
|
191 |
+
"n_head": 32,
|
192 |
+
"n_vocab": 50257,
|
193 |
+
"embed_dropout": 0.1,
|
194 |
+
"lr": 0.0002,
|
195 |
+
"lr_decay": "cosine",
|
196 |
+
"warmup_steps": 3000,
|
197 |
+
"beta1": 0.9,
|
198 |
+
"beta2": 0.95,
|
199 |
+
"epsilon": 1e-8,
|
200 |
+
"opt_name": "adam",
|
201 |
+
"weight_decay": 0.1,
|
202 |
+
"train_batch_size": 512,
|
203 |
+
"attn_dropout": 0.1,
|
204 |
+
"train_steps": 286150,
|
205 |
+
"eval_steps": 0,
|
206 |
+
"predict_steps": 1,
|
207 |
+
"res_dropout": 0.1,
|
208 |
+
"eval_batch_size": 128,
|
209 |
+
"predict_batch_size": 1,
|
210 |
+
"iterations": 2500,
|
211 |
+
"n_embd": 2048,
|
212 |
+
"datasets": [["your_dataset_name", 25, "documents_random", 1.0]],
|
213 |
+
"model_path": "gs://neo-models/GPT3_XL",
|
214 |
+
"n_ctx": 2048,
|
215 |
+
"n_layer": 24,
|
216 |
+
"scale_by_depth": true,
|
217 |
+
"scale_by_in": false,
|
218 |
+
"attention_types" : [[["global"],24]],
|
219 |
+
"mesh_shape": "x:128,y:2",
|
220 |
+
"layout": "batch:x,memory_length:y,embd:y",
|
221 |
+
"activation_function": "gelu",
|
222 |
+
"recompute_grad": true,
|
223 |
+
"gradient_clipping": 1.0,
|
224 |
+
"tokens_per_mb_per_replica": 2048
|
225 |
+
}
|
226 |
+
```
|
227 |
+
|
228 |
+
|
229 |
+
## 6. Run Training
|
230 |
+
|
231 |
+
```
|
232 |
+
python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --tpu <tpu-name>
|
233 |
+
```
|
234 |
+
|
235 |
+
- `tpu`: Name of the TPU to use.
|
236 |
+
- `steps_per_checkpoint`: The frequency in steps at which to save checkpoints.
|
237 |
+
- `--auto_layout` and `--auto_layout_and_mesh_shape` (Optional): Disable training and instead auto generate a memory efficient `layout` (and `mesh_shape`)
|
238 |
+
- `gpu_ids`: if training using GPUs, omit the `tpu` flag and pass in the ids of your gpus. In the example below, we train on 3 GPUs, specifying their device ids delimited by spaces:
|
239 |
+
|
240 |
+
```
|
241 |
+
python3 main.py --model <your_config_name> --steps_per_checkpoint <n> --gpu_ids <device:GPU:0 device:GPU:1>
|
242 |
+
```
|
243 |
+
|
244 |
+
# Available Configs
|
245 |
+
|
246 |
+
We have several model sizes available, but some of our configs require large TPUs and will need tweaking to run on smaller machines, or GPUs. Below is a short guide to each model in the configs directory:
|
247 |
+
|
248 |
+
TODO
|
249 |
+
|
250 |
+
# Extra Features:
|
251 |
+
|
252 |
+
## Training (with Sacred)
|
253 |
+
|
254 |
+
[Sacred](https://github.com/IDSIA/sacred) helps track experiments and is much nicer to work with than tensorboard.
|
255 |
+
|
256 |
+
To setup:
|
257 |
+
|
258 |
+
1. Install Docker and Docker-compose
|
259 |
+
|
260 |
+
2. Run `docker-compose up`
|
261 |
+
|
262 |
+
To use:
|
263 |
+
|
264 |
+
1. Ensure model_dir doesn't have any metric logs in it (it trips up the metric stuff for tensorboard, which assumes that it's a continuation of the existing run). You can use `gsutil rm -r ...` to delete model dir
|
265 |
+
|
266 |
+
2. Run `python3 run_experiment.py --tpu sometpuhere --model someconfig.json` Options are the same as `main.py`.
|
267 |
+
|
268 |
+
3. You can go to http://server_ip_goes_here:8081/ to see the Omniboard overview. If you prefer to see a tensorboard, the script also spins one up and automatically assigns it a port. The script should print out the tensorboard port near the top of the log.
|
269 |
+
|
270 |
+
## Peeking at a Dataset
|
271 |
+
|
272 |
+
If you are ever confused by the dataset of a particular config file, you can easily check the minimum and maximum token ids with a single command. This is useful for making sure that the vocabulary size of the model is at least as large as the maximum token id. Tensorflow will not error if you try to gather on a matrix with out of bounds indices, so you need to make sure your vocabulary size is sufficiently large.
|
273 |
+
|
274 |
+
```bash
|
275 |
+
python main --model {config_name} --check_dataset
|
276 |
+
```
|
277 |
+
|
278 |
+
## Masked Language Modeling
|
279 |
+
|
280 |
+
In addition to being able to train large GPT's, this repository also allows you to easily do masked language modeling (BERT, RoBERTa). In order to do so, you must follow two additional steps.
|
281 |
+
|
282 |
+
1. When tokenizing your dataset, you must reserve a special id for the `[mask]` token.
|
283 |
+
|
284 |
+
2. In the configs, you will have to define two additional fields
|
285 |
+
|
286 |
+
```python
|
287 |
+
"mlm_training": true, # must be set to true
|
288 |
+
"mlm_mask_id": <mask id> # the mask id that you reserved from above
|
289 |
+
```
|
290 |
+
|
291 |
+
That's all you need to train a model with the MLM objective, good for any type of data that you have encoded properly. If you would like to tweak the other related hyperparameters, please continue reading.
|
292 |
+
|
293 |
+
```python
|
294 |
+
"mlm_cls_token_id": <cls token id>, # auto append specified CLS token id on the left
|
295 |
+
"mlm_mask_prob": 0.15, # the probability of masking a token, defaults to 15%
|
296 |
+
"mlm_same_token_prob": 0.10, # probability of keeping the token the same, defaults to 10%
|
297 |
+
"mlm_random_token_prob": 0.10, # probability of tokens that are replaced with random tokens, 10% was recommended by the BERT paper
|
298 |
+
"mlm_mask_ignore_ids": [<cls token>, <sep token>] # ignore masking other special tokens, if any
|
299 |
+
```
|
300 |
+
|
301 |
+
## Parameter Reference
|
302 |
+
|
303 |
+
Pick a valid config from `/configs` and tweak the parameters as needed:
|
304 |
+
|
305 |
+
- `n_heads`: The number of attention heads.
|
306 |
+
- `n_embd`: Size of the hidden layers, must be divisible by `n_heads`.
|
307 |
+
- `n_vocab`: Vocabulary size.
|
308 |
+
- `embed_dropout`, `res_dropout`, `attn_dropout`: Dropout probability for word embedding/residuals/attention
|
309 |
+
- `lr`: Learning rate
|
310 |
+
- `warmup_steps`: Number of steps before full learning rate is reached (linear ramp from `0` to `lr`).
|
311 |
+
- `lr_decay`: `cosine` or `linear`.
|
312 |
+
- `opt_name`: `adam` or `adafactor`.
|
313 |
+
- `beta1`, `beta2` and `epsilon`: `adam` optimizer params.
|
314 |
+
- `beta1`, `ada_epsilon1` and `ada_epsilon2`: `adafactor` optimizer params.
|
315 |
+
- `weight_decay`: Weight decay parameter, if not present no weight decay is used (the weight decay fix for Adam is used) (default: 0.01) (optional).
|
316 |
+
- `train_batch_size`: Batch size during training.
|
317 |
+
- `train_steps`: Number of training steps (batches), set to roughly ~1 epoch for now (total number of tokens in your dataset / number of tokens per batch (= `train_batch_size` / `n_ctx`)).
|
318 |
+
- `eval_steps`: Number of steps to run for each evaluation. Set to `0` for no eval. i.e After every checkpoint, the model is tested for `eval_steps`
|
319 |
+
- `iterations`: Number of steps queued to the TPU, must be smaller than `steps_per_checkpoint`. (default: 500)
|
320 |
+
- `datasets`: List of tfrecords datasets to use. Each dataset is a list with the following parameters: `[train glob , eval glob, stitch, sampling_mode, weight]`. So for example for a single dataset (note the double list): `[["bundestag_*.tfrecords", "", 10, "random_sample", 1.0]]`
|
321 |
+
+ `dataset_id`: The name of a dataset configuration file in `./configs/dataset_configs`
|
322 |
+
+ `stitch`: If `sampling_mode` `random_sample` is used, the input pipeline samples this amount of texts into one to sample from. You must select stitch so that `stitch * minimum_document_length >= n_ctx`
|
323 |
+
+ `sampling_mode`: `chunks` (tfrecords are preprocessed into the correct length and are read sequentially) or `documents_random` (`stitch` amount of documents are concatenated and then a `n_ctx` chunk is randomly subsampled)
|
324 |
+
+ `weights`: How much relative weight this dataset should have compared to others
|
325 |
+
- `model`: Which model to train. Currently only `GPT` is supported, and it defaults to this if not present.
|
326 |
+
- `model_path`: Google storage bucket location (or local path, if using GPUs) to save model checkpoints and logs.
|
327 |
+
- `n_ctx`: Size of context window. Default is 2048
|
328 |
+
- `n_layer`: Number of layers (blocks) in the model.
|
329 |
+
- `scale_by_depth`: If true, the weight initialization of layers are scaled by their depth as in the GPT2 paper.
|
330 |
+
- `scale_by_in`: If true, the weight initialization of layers are scaled by their number of inputs as in the GPT2 paper.
|
331 |
+
- `mesh_shape`: A Mesh is an n-dimensional array of processors with named dimensions used for parallelism in the mesh-tensorflow library. Each Tensor is split evenly across mesh dimensions according to the layout (see below). The 'mesh_shape' is the shape of this array, and must be equal to the number of processors. e.g., for a v3-128 TPU "mesh_shape": “x:16,y:8”.
|
332 |
+
- `layout`: A Tensor is laid out on its mesh with one slice on each processor. A Tensor "layout", is an injective partial map specifying which dimensions of the tensor are (evenly) split across which dimensions of the mesh. No dimension of a tensor may be split across two dimensions of its mesh and no two dimensions of a tensor may be split across the same dimension of its mesh. The user defines a global set of layout rules in the form of (tensor-dimension-name, mesh-dimension-name) pairs. A dimension of a tensor is split across a dimension of its mesh if there is a matching rule, e.g. (for the above example mesh_shape: "layout":"batch:x,heads:y"
|
333 |
+
- `activation_function`: `selu` (self normalizing) or `gelu` (used by OA), activation function used in feed-forward passes. (default: gelu)
|
334 |
+
- `attention_types`: the type of attention for each layer in a list of the following format [[["attention_type"], n_layers]]. e.g. for a 12 layer net [[["global"], 12]] or [[["local"], 10], [["global"], 2]].
|
335 |
+
+ Choose from: `linear`, `global`, `local` or `none`. We have found a 50/50 mix of `global` and `linear` to work well. `none` allows you to create feed-forward only layers for more efficient [PAR Transformer](https://arxiv.org/abs/2009.04534) models.
|
336 |
+
- `precision`: `float32` or `bfloat16`.
|
337 |
+
- `tokens_per_mb_per_replica`: If not None, will split the batch up into smaller microbatches containing `tokens_per_mb_per_replica` tokens to avoid OOMs. Gradients are accumulated locally and reduced once. IMPORTANT: mb refers to *minibatch* not megabyte here.
|
338 |
+
|
339 |
+
**Mixture of Experts**
|
340 |
+
|
341 |
+
- `moe_layers`: A list of layer numbers to append a [mixture of experts](https://arxiv.org/abs/1701.06538) layer onto. E.G: `[2,4,6,8,10,12]`.
|
342 |
+
We have experimentally found a moe layer for every two self-attention layers to work well.
|
343 |
+
- `moe_params`: a dictionary of additional kwargs to pass in to the moe layer. E.G
|
344 |
+
`{"moe_dropout_rate": 0.0 }`
|
345 |
+
|
346 |
+
**Experimental features**
|
347 |
+
|
348 |
+
- `axial_pos_emb_`: If true, uses [axial positional embedding](https://arxiv.org/abs/1912.12180.
|
349 |
+
- `mlp_glu`: If true, uses a gated linear unit variant of feed forward layers.
|
350 |
+
- `scalenorm`: If true, uses scalenorm instead of layernorm.
|
351 |
+
- `rezero`: If true, uses [rezero](https://www.groundai.com/project/rezero-is-all-you-need-fast-convergence-at-large-depth/1) instead of layernorm.
|
352 |
+
- `num_mem_kv`: adds memory / key values from the [all-attention paper](https://arxiv.org/pdf/1907.01470.pdf). Param is an int with the number of desired mem/key values.
|
353 |
+
- `macaron`: if true - uses a [macaron transformer](https://arxiv.org/pdf/1906.02762.pdf) for each layer block.
|
354 |
+
|
355 |
+
## TODO:
|
356 |
+
|
357 |
+
- [x] finalize documentation
|
358 |
+
- [ ] update configs
|
359 |
+
|
360 |
+
## Citing GPT-Neo
|
361 |
+
|
362 |
+
If you have found GPT-Neo helpful in your work, you can cite this repository as
|
363 |
+
|
364 |
+
```
|
365 |
+
@software{gpt-neo,
|
366 |
+
author = {Black, Sid and Gao, Leo and Wang, Phil and Leahy, Connor and Biderman, Stella},
|
367 |
+
title = {{GPT-Neo}: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow},
|
368 |
+
url = {http://github.com/eleutherai/gpt-neo},
|
369 |
+
version = {1.0},
|
370 |
+
year = {2021},
|
371 |
+
}
|
372 |
+
```
|
373 |
+
The version number should be replaced with the version number you are using, and the year corresponds to the project's open-source release.
|
374 |
+
|
375 |
+
If you are specifically interested in citing the GPT-Neo models trained on [the Pile](https://arxiv.org/abs/2101.00027), we would appreciate also citing
|
376 |
+
```
|
377 |
+
@article{gao2020pile,
|
378 |
+
title={The Pile: An 800GB Dataset of Diverse Text for Language Modeling},
|
379 |
+
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others},
|
380 |
+
journal={arXiv preprint arXiv:2101.00027},
|
381 |
+
year={2020}
|
382 |
+
}
|
383 |
+
```
|
configs.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
DATASETS = {}
|
6 |
+
|
7 |
+
for path in Path("configs/dataset_configs").glob("*.json"):
|
8 |
+
dataset_id = path.stem
|
9 |
+
DATASETS[dataset_id] = json.loads(path.read_text())
|
10 |
+
|
11 |
+
|
12 |
+
def fetch_model_params(model):
|
13 |
+
model_path = model if model.endswith(".json") else f"configs/{model}.json"
|
14 |
+
with open(model_path) as f:
|
15 |
+
params = json.load(f)
|
16 |
+
|
17 |
+
dataset_ids = []
|
18 |
+
for d in params.get("datasets"):
|
19 |
+
if isinstance(d, list):
|
20 |
+
dataset_ids.append(d[0])
|
21 |
+
else:
|
22 |
+
dataset_ids.append(d)
|
23 |
+
no_datasets = params.get("no_dataset", False)
|
24 |
+
assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"
|
25 |
+
|
26 |
+
datasets = {}
|
27 |
+
last_dataset = None
|
28 |
+
for dataset_id in dataset_ids:
|
29 |
+
assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
|
30 |
+
dataset = DATASETS[dataset_id]
|
31 |
+
assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
|
32 |
+
datasets[dataset_id] = dataset
|
33 |
+
last_dataset = dataset
|
34 |
+
|
35 |
+
if last_dataset is not None:
|
36 |
+
params["padding_id"] = last_dataset.get("padding_id", 0)
|
37 |
+
params["eos_id"] = last_dataset.get("eos_id", 1)
|
38 |
+
|
39 |
+
params["dataset_configs"] = datasets
|
40 |
+
|
41 |
+
# Set some other parameter defaults
|
42 |
+
params["mlm_training"] = params.get("mlm_training") == True
|
43 |
+
params["causal"] = not params["mlm_training"]
|
44 |
+
|
45 |
+
# Set all other parameter values to default to None
|
46 |
+
params = defaultdict(lambda: None, params)
|
47 |
+
return params
|
docker-compose.yml
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: '3'
|
2 |
+
services:
|
3 |
+
|
4 |
+
mongo:
|
5 |
+
image: mongo
|
6 |
+
ports:
|
7 |
+
- 127.0.0.1:27017:27017
|
8 |
+
environment:
|
9 |
+
MONGO_INITDB_ROOT_USERNAME: user
|
10 |
+
MONGO_INITDB_ROOT_PASSWORD: password
|
11 |
+
MONGO_INITDB_DATABASE: db
|
12 |
+
expose:
|
13 |
+
- 27017
|
14 |
+
networks:
|
15 |
+
- omniboard
|
16 |
+
volumes:
|
17 |
+
- ./data:/data/db
|
18 |
+
|
19 |
+
mongoClientTemp:
|
20 |
+
image: mongo:latest
|
21 |
+
container_name: mongoClientTemp
|
22 |
+
links:
|
23 |
+
- mongo:mongo
|
24 |
+
command: mongo --host mongo -u user -p password --eval "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});"
|
25 |
+
depends_on:
|
26 |
+
- mongo
|
27 |
+
networks:
|
28 |
+
- omniboard
|
29 |
+
|
30 |
+
omniboard_readonly:
|
31 |
+
#image: vivekratnavel/omniboard:latest
|
32 |
+
build: https://github.com/lucidrains/omniboard.git
|
33 |
+
command: ["--mu", "mongodb://readonly:password@mongo:27017/db"]
|
34 |
+
ports:
|
35 |
+
- 0.0.0.0:8081:9000
|
36 |
+
networks:
|
37 |
+
- omniboard
|
38 |
+
depends_on:
|
39 |
+
- mongo
|
40 |
+
|
41 |
+
omniboard:
|
42 |
+
#image: vivekratnavel/omniboard:latest
|
43 |
+
build: https://github.com/lucidrains/omniboard.git
|
44 |
+
command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"]
|
45 |
+
expose:
|
46 |
+
- 9000
|
47 |
+
networks:
|
48 |
+
- omniboard
|
49 |
+
depends_on:
|
50 |
+
- mongo
|
51 |
+
|
52 |
+
nginx:
|
53 |
+
image: dhswt/nginx-basic-auth:1.3
|
54 |
+
environment:
|
55 |
+
- HTPASSWD=isaac: #put passwd here
|
56 |
+
- FORWARD_HOST=omniboard
|
57 |
+
- FORWARD_PORT=9000
|
58 |
+
networks:
|
59 |
+
- omniboard
|
60 |
+
depends_on:
|
61 |
+
- omniboard
|
62 |
+
ports:
|
63 |
+
- 0.0.0.0:8080:80
|
64 |
+
expose:
|
65 |
+
- 8080
|
66 |
+
networks:
|
67 |
+
omniboard:
|
encoders.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tokenizers import Tokenizer
|
2 |
+
from transformers import GPT2Tokenizer, GPT2TokenizerFast
|
3 |
+
|
4 |
+
def fetch_encoder(params):
|
5 |
+
no_dataset = params.get('no_dataset', False)
|
6 |
+
if no_dataset:
|
7 |
+
return None
|
8 |
+
|
9 |
+
dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
|
10 |
+
path = dataset["tokenizer_path"]
|
11 |
+
is_pretrained = dataset.get("tokenizer_is_pretrained", False)
|
12 |
+
|
13 |
+
if is_pretrained:
|
14 |
+
tok = GPT2TokenizerFast.from_pretrained(path)
|
15 |
+
|
16 |
+
# Will add a padding token id of 50257 at run-time
|
17 |
+
tok.add_special_tokens({'pad_token': '<|padding|>'})
|
18 |
+
return tok
|
19 |
+
|
20 |
+
return Tokenizer.from_file(path)
|
21 |
+
|
22 |
+
|
23 |
+
# GPT2Tokenizer and Tokenizer have different ways of fetching token ids
|
24 |
+
def encode(encoder, text, gpt=True):
|
25 |
+
result = encoder.encode(text)
|
26 |
+
if isinstance(result, list):
|
27 |
+
return result
|
28 |
+
return result.ids
|
export.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow.compat.v1 as tf
|
2 |
+
|
3 |
+
def export_model(estimator, export_dir, params,
|
4 |
+
checkpoint_path=None):
|
5 |
+
|
6 |
+
|
7 |
+
def serving_input_receiver_fn():
|
8 |
+
t = tf.placeholder(dtype=tf.int64,
|
9 |
+
shape=[1, params["n_ctx"]],
|
10 |
+
name='input_example_tensor')
|
11 |
+
return tf.estimator.export.ServingInputReceiver(t, t)
|
12 |
+
|
13 |
+
return estimator.export_saved_model(
|
14 |
+
export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)
|
inputs.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow.compat.v1 as tf
|
3 |
+
from functools import partial
|
4 |
+
from data.encoders import encode
|
5 |
+
import random
|
6 |
+
import re
|
7 |
+
import logging
|
8 |
+
from itertools import cycle
|
9 |
+
from utils import natural_sort
|
10 |
+
|
11 |
+
|
12 |
+
### IN USE ###
|
13 |
+
|
14 |
+
def _get_number_of_documents(filename):
|
15 |
+
# extracts number of files from a filename formatted "<name>_<num_documents>.tfrecords."
|
16 |
+
# if no pattern is matched, returns None
|
17 |
+
match = re.search("_(\d{1,}).tfrecords$", filename)
|
18 |
+
return int(match.group(1)) if match is not None else match
|
19 |
+
|
20 |
+
|
21 |
+
def _get_number_of_documents_by_iteration(filename):
|
22 |
+
# extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename
|
23 |
+
# this could be very slow.
|
24 |
+
logging.warning(
|
25 |
+
"inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length")
|
26 |
+
count = 0
|
27 |
+
for item in tf.io.tf_record_iterator(filename):
|
28 |
+
count += 1
|
29 |
+
return count
|
30 |
+
|
31 |
+
|
32 |
+
def _get_skip_index(all_files, n_batches):
|
33 |
+
prev_cumsum = 0
|
34 |
+
cumsum = 0
|
35 |
+
global_n_documents = None
|
36 |
+
for count, f in cycle(enumerate(all_files)):
|
37 |
+
prev_cumsum = cumsum
|
38 |
+
if _get_number_of_documents(f) is not None:
|
39 |
+
cumsum += _get_number_of_documents(f)
|
40 |
+
elif global_n_documents is None:
|
41 |
+
global_n_documents = _get_number_of_documents_by_iteration(f)
|
42 |
+
cumsum += global_n_documents
|
43 |
+
else:
|
44 |
+
cumsum += global_n_documents
|
45 |
+
if cumsum == n_batches:
|
46 |
+
remainder = 0
|
47 |
+
skip_idx = count + 1
|
48 |
+
elif cumsum > n_batches:
|
49 |
+
remainder = n_batches - prev_cumsum
|
50 |
+
skip_idx = count
|
51 |
+
break
|
52 |
+
return skip_idx, remainder
|
53 |
+
|
54 |
+
|
55 |
+
def _parse_function(example_proto):
|
56 |
+
features = {
|
57 |
+
"text": tf.VarLenFeature(tf.int64)
|
58 |
+
}
|
59 |
+
parsed_features = tf.parse_single_example(example_proto, features)
|
60 |
+
return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])
|
61 |
+
|
62 |
+
|
63 |
+
def autoregressive_sample_text(params, x):
|
64 |
+
vals1 = x[:params["n_ctx"]]
|
65 |
+
vals2 = x[1:params["n_ctx"] + 1]
|
66 |
+
|
67 |
+
vals1 = tf.reshape(vals1, [params["n_ctx"]])
|
68 |
+
vals2 = tf.reshape(vals2, [params["n_ctx"]])
|
69 |
+
vals1 = tf.cast(vals1, dtype=tf.int32)
|
70 |
+
vals2 = tf.cast(vals2, dtype=tf.int32)
|
71 |
+
return vals1, vals2
|
72 |
+
|
73 |
+
|
74 |
+
def sequential_input(params, global_step=None, eval=False):
|
75 |
+
"""
|
76 |
+
Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:
|
77 |
+
|
78 |
+
- has the number of documents for each tfrecord file encoded in the title in the format
|
79 |
+
<name>_<n_documents>.tfrecords.
|
80 |
+
|
81 |
+
OR
|
82 |
+
|
83 |
+
- has a fixed number of documents per tfrecord file.
|
84 |
+
|
85 |
+
If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.
|
86 |
+
If this isn't the case, it may result in errors, or some samples being missed.
|
87 |
+
|
88 |
+
This means we can calculate the number of samples we've seen so far using the global step,
|
89 |
+
and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.
|
90 |
+
|
91 |
+
If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
|
92 |
+
performance, as it results in less repeated data.
|
93 |
+
"""
|
94 |
+
if not eval:
|
95 |
+
assert global_step is not None
|
96 |
+
logging.warning(
|
97 |
+
"Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
|
98 |
+
batch_size = params['eval_batch_size' if eval else 'train_batch_size']
|
99 |
+
|
100 |
+
filenames = []
|
101 |
+
for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params
|
102 |
+
path_key = 'path' if not eval else 'eval_path'
|
103 |
+
path = dataset_config[path_key]
|
104 |
+
filenames.extend(
|
105 |
+
tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs
|
106 |
+
|
107 |
+
filenames = natural_sort(filenames)
|
108 |
+
shuffle_filenames = params.get("shuffle_input_filenames", True)
|
109 |
+
if shuffle_filenames:
|
110 |
+
seed = params.get('seed', 1) # shuffle deterministically
|
111 |
+
random.seed(seed)
|
112 |
+
random.shuffle(filenames)
|
113 |
+
|
114 |
+
dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity
|
115 |
+
|
116 |
+
if not eval:
|
117 |
+
# skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
|
118 |
+
skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
|
119 |
+
"train_batch_size"]) # TODO: fix for > 1 epoch
|
120 |
+
dataset = dataset.skip(skip_idx) # skip to skip idx
|
121 |
+
|
122 |
+
# read tfrecord examples and skip remainder
|
123 |
+
dataset = dataset.apply(tf.data.TFRecordDataset)
|
124 |
+
dataset = dataset.skip(remainder)
|
125 |
+
else:
|
126 |
+
# shuffle filenames if in eval mode
|
127 |
+
dataset = dataset.shuffle(len(filenames))
|
128 |
+
dataset = dataset.apply(tf.data.TFRecordDataset)
|
129 |
+
|
130 |
+
# parse the tokenized data from the tfrecord files and shuffle
|
131 |
+
dataset = dataset.map(_parse_function, num_parallel_calls=1)
|
132 |
+
dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)
|
133 |
+
|
134 |
+
# batch data and repeat to infinity
|
135 |
+
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
|
136 |
+
return dataset.repeat()
|
137 |
+
|
138 |
+
|
139 |
+
def pred_input(params, logger, enc=None,
|
140 |
+
path_to_prompt=""):
|
141 |
+
unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
|
142 |
+
"previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
|
143 |
+
"researchers was the fact that the unicorns spoke perfect English."
|
144 |
+
|
145 |
+
text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read()
|
146 |
+
tokens = encode(enc, text)
|
147 |
+
|
148 |
+
if len(tokens) > params["n_ctx"]:
|
149 |
+
logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
|
150 |
+
tokens = tokens[len(tokens) - params["n_ctx"]:]
|
151 |
+
if len(tokens) < params["n_ctx"]:
|
152 |
+
tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
|
153 |
+
t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
|
154 |
+
dataset = tf.data.Dataset.from_tensors(t)
|
155 |
+
|
156 |
+
def _dummy_labels(x):
|
157 |
+
return x, x
|
158 |
+
|
159 |
+
dataset = dataset.map(_dummy_labels)
|
160 |
+
return dataset
|
161 |
+
|
162 |
+
|
163 |
+
def handle_pred_output(predictions, logger, enc, params, out_name="test"):
|
164 |
+
with tf.gfile.Open(f"{out_name}.txt", "w") as f:
|
165 |
+
for i, p in enumerate(predictions):
|
166 |
+
p = p["outputs"]
|
167 |
+
|
168 |
+
# remove eos + padding ids from output
|
169 |
+
idx = np.argmax(p == params['eos_id'])
|
170 |
+
if idx > 0:
|
171 |
+
p = p[:idx]
|
172 |
+
idx = np.argmax(p == params['padding_id'])
|
173 |
+
if idx > 0:
|
174 |
+
p = p[:idx]
|
175 |
+
|
176 |
+
text = enc.decode(p)
|
177 |
+
f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
|
178 |
+
f.write(text)
|
179 |
+
f.write("\n" + "=" * 80 + "\n")
|
180 |
+
|
181 |
+
logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
|
182 |
+
logger.info(text)
|
183 |
+
logger.info("\n" + "=" * 80 + "\n")
|
184 |
+
|
185 |
+
|
186 |
+
### DEPRECATED ###
|
187 |
+
|
188 |
+
def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
|
189 |
+
logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
|
190 |
+
i = 0 if not eval else 1
|
191 |
+
|
192 |
+
weights = []
|
193 |
+
datasets = []
|
194 |
+
|
195 |
+
for dataset in params["datasets"]:
|
196 |
+
dataset_id, stitch, datatype, weight = dataset
|
197 |
+
|
198 |
+
assert dataset_id in params[
|
199 |
+
'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
|
200 |
+
dataset_config = params['dataset_configs'][dataset_id]
|
201 |
+
|
202 |
+
path_key = 'path' if not eval else 'eval_path'
|
203 |
+
path = dataset_config[path_key]
|
204 |
+
|
205 |
+
datasets.append(text_dataset(
|
206 |
+
tf.io.gfile.glob(path),
|
207 |
+
params,
|
208 |
+
stitch=stitch,
|
209 |
+
datatype=datatype,
|
210 |
+
batch=False,
|
211 |
+
sample_text_fn=sample_text_fn
|
212 |
+
))
|
213 |
+
|
214 |
+
weights.append(weight)
|
215 |
+
|
216 |
+
batch_size = params['eval_batch_size' if eval else 'train_batch_size']
|
217 |
+
|
218 |
+
seed = params.get('seed', None)
|
219 |
+
dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
|
220 |
+
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
|
221 |
+
return dataset
|
222 |
+
|
223 |
+
|
224 |
+
def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
|
225 |
+
seed = params.get('seed', None)
|
226 |
+
deterministic = seed is not None
|
227 |
+
num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE
|
228 |
+
|
229 |
+
dataset = tf.data.Dataset.from_tensor_slices(files)
|
230 |
+
|
231 |
+
if deterministic:
|
232 |
+
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
|
233 |
+
else:
|
234 |
+
dataset = dataset.apply(
|
235 |
+
tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
|
236 |
+
|
237 |
+
if "documents" in datatype:
|
238 |
+
def _parse_function(example_proto):
|
239 |
+
features = {
|
240 |
+
# "hash": tf.VarLenFeature(tf.string),
|
241 |
+
"text": tf.VarLenFeature(tf.int64)
|
242 |
+
}
|
243 |
+
parsed_features = tf.parse_single_example(example_proto, features)
|
244 |
+
return parsed_features["text"], parsed_features["text"].dense_shape[0]
|
245 |
+
else:
|
246 |
+
def _parse_function(example_proto):
|
247 |
+
features = {
|
248 |
+
"text": tf.VarLenFeature(tf.int64)
|
249 |
+
}
|
250 |
+
parsed_features = tf.parse_single_example(example_proto, features)
|
251 |
+
return parsed_features["text"] # Assuming the text is not sparse
|
252 |
+
|
253 |
+
dataset = dataset.map(_parse_function, num_parallel_calls=1)
|
254 |
+
|
255 |
+
# Subsample method
|
256 |
+
if "documents" in datatype:
|
257 |
+
# Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples
|
258 |
+
# to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that
|
259 |
+
# stitch * min(characters_in_text) >= amount
|
260 |
+
def _stitch_text(x, y):
|
261 |
+
x = tf.sparse.to_dense(x)
|
262 |
+
|
263 |
+
def _get_x(i):
|
264 |
+
return tf.gather(x[i], tf.range(y[i]))
|
265 |
+
|
266 |
+
out = _get_x(0)
|
267 |
+
eos_id = params['eos_id']
|
268 |
+
|
269 |
+
for i in range(1, stitch):
|
270 |
+
out = tf.concat([out, [eos_id], _get_x(i)], axis=0) # text1<|endoftext|>text2
|
271 |
+
|
272 |
+
return out
|
273 |
+
|
274 |
+
# Hack-y way to stitch together multiple texts
|
275 |
+
|
276 |
+
dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,
|
277 |
+
num_parallel_calls=num_parallel_calls)
|
278 |
+
|
279 |
+
# Sample 1024(+1) tokens from the stitched together text
|
280 |
+
is_random_documents = datatype == "documents_random"
|
281 |
+
if sample_text_fn is not None:
|
282 |
+
_sample_text = partial(sample_text_fn, random_documents=is_random_documents)
|
283 |
+
else:
|
284 |
+
_sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
|
285 |
+
_sample_text = partial(_sample_text, params)
|
286 |
+
|
287 |
+
dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)
|
288 |
+
|
289 |
+
if batch:
|
290 |
+
dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)
|
291 |
+
|
292 |
+
dataset = dataset.repeat()
|
293 |
+
|
294 |
+
return dataset
|
295 |
+
|
296 |
+
|
297 |
+
def autoregressive_sample_text_random_documents(params, x):
|
298 |
+
seed = params.get('seed', None)
|
299 |
+
s = tf.size(x)
|
300 |
+
r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
|
301 |
+
r1 = tf.range(r, r + params["n_ctx"])
|
302 |
+
r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
|
303 |
+
r1 = tf.reshape(r1, [params["n_ctx"]]) # Somehow, this makes the compiler happy
|
304 |
+
r2 = tf.reshape(r2, [params[
|
305 |
+
"n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input
|
306 |
+
vals1 = tf.gather(x, r1)
|
307 |
+
vals2 = tf.gather(x, r2)
|
308 |
+
|
309 |
+
vals1 = tf.reshape(vals1, [params["n_ctx"]])
|
310 |
+
vals2 = tf.reshape(vals2, [params["n_ctx"]])
|
311 |
+
vals1 = tf.cast(vals1, dtype=tf.int32)
|
312 |
+
vals2 = tf.cast(vals2, dtype=tf.int32)
|
313 |
+
return vals1, vals2
|
314 |
+
|
315 |
+
|
316 |
+
def mlm_sample_text(params, x, random_documents=False):
|
317 |
+
seed = params.get('seed', None)
|
318 |
+
ctx_len = params["n_ctx"]
|
319 |
+
assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'
|
320 |
+
|
321 |
+
mask_id = params['mlm_mask_id']
|
322 |
+
cls_token_id = params.get('mlm_cls_token_id', None)
|
323 |
+
num_tokens = params.get('n_vocab', None)
|
324 |
+
|
325 |
+
mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
|
326 |
+
mask_ignore_ids.add(cls_token_id)
|
327 |
+
|
328 |
+
mask_prob = params.get('mlm_mask_prob', 0.15)
|
329 |
+
same_token_prob = params.get('mlm_same_token_prob', 0.10)
|
330 |
+
random_token_prob = params.get('mlm_random_token_prob', 0.)
|
331 |
+
|
332 |
+
seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)
|
333 |
+
|
334 |
+
if random_documents:
|
335 |
+
s = tf.size(x)
|
336 |
+
r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
|
337 |
+
r1 = tf.range(r, r + seq_len)
|
338 |
+
r1 = tf.reshape(r1, [seq_len])
|
339 |
+
features = tf.gather(x, r1)
|
340 |
+
else:
|
341 |
+
features = x[:seq_len]
|
342 |
+
|
343 |
+
# add cls token id if specified by `mlm_cls_token_id`
|
344 |
+
if cls_token_id is not None:
|
345 |
+
features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)
|
346 |
+
|
347 |
+
features = tf.cast(features, dtype=tf.int32)
|
348 |
+
shape = features.shape
|
349 |
+
|
350 |
+
# determine which tokens are mask-able
|
351 |
+
can_mask = tf.not_equal(features, 0)
|
352 |
+
for ignore_id in mask_ignore_ids:
|
353 |
+
can_mask &= tf.not_equal(features, ignore_id)
|
354 |
+
|
355 |
+
# generate boolean mask for masking ids
|
356 |
+
mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
|
357 |
+
mask_mask &= can_mask
|
358 |
+
|
359 |
+
# generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
|
360 |
+
replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
|
361 |
+
1 - same_token_prob)
|
362 |
+
|
363 |
+
# randomly replace some tokens with random tokens before masking
|
364 |
+
if random_token_prob > 0:
|
365 |
+
random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
|
366 |
+
random_token_prob)
|
367 |
+
random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)
|
368 |
+
|
369 |
+
# make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
|
370 |
+
random_can_mask = tf.not_equal(random_tokens, 0)
|
371 |
+
for ignore_id in mask_ignore_ids:
|
372 |
+
random_can_mask &= tf.not_equal(random_tokens, ignore_id)
|
373 |
+
|
374 |
+
features = tf.where(random_token_mask & random_can_mask, random_tokens, features)
|
375 |
+
|
376 |
+
# mask the tokens
|
377 |
+
mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
|
378 |
+
masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)
|
379 |
+
|
380 |
+
# labels will be set to 0 for all non-masked tokens
|
381 |
+
labels = tf.where(mask_mask, tf.zeros(shape, dtype=tf.int32), features)
|
382 |
+
|
383 |
+
masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
|
384 |
+
return masked_features, labels
|
main.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""GPT-like model in Mesh-Tensorflow"""
|
2 |
+
|
3 |
+
from functools import partial
|
4 |
+
import mesh_tensorflow as mtf
|
5 |
+
import tensorflow.compat.v1 as tf
|
6 |
+
from tensorflow.python.tpu import tpu_config, tpu_estimator
|
7 |
+
from tensorflow_estimator.python.estimator import estimator as estimator_lib
|
8 |
+
from utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \
|
9 |
+
check_dataset
|
10 |
+
from inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text
|
11 |
+
from export import export_model
|
12 |
+
from model_fns import model_fn
|
13 |
+
from data.encoders import fetch_encoder
|
14 |
+
from configs import fetch_model_params
|
15 |
+
from tasks import task_descriptors
|
16 |
+
import argparse
|
17 |
+
import json
|
18 |
+
import numpy
|
19 |
+
|
20 |
+
|
21 |
+
def parse_args():
|
22 |
+
# Parse command line arguments
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
|
25 |
+
parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
|
26 |
+
help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'")
|
27 |
+
parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
|
28 |
+
parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
|
29 |
+
parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
|
30 |
+
"efficient layout according to MTF auto layout.")
|
31 |
+
parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
|
32 |
+
help="If set, generates and prints the most memory efficient layout and mesh shape according to"
|
33 |
+
" MTF auto layout.")
|
34 |
+
parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
|
35 |
+
"starts a new training run")
|
36 |
+
parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
|
37 |
+
parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
|
38 |
+
parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
|
39 |
+
"defaults to unicorns.",
|
40 |
+
default="")
|
41 |
+
parser.add_argument("--check_dataset", action="store_true",
|
42 |
+
help="If set, outputs sample from the dataset and quits.")
|
43 |
+
parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
|
44 |
+
parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
|
45 |
+
parser.add_argument("--export", action="store_true", help="If set, will export the model.")
|
46 |
+
args = parser.parse_args()
|
47 |
+
assert args.model is not None, "Model must be set"
|
48 |
+
return args
|
49 |
+
|
50 |
+
|
51 |
+
def main(args):
|
52 |
+
# Setup logging
|
53 |
+
logger = setup_logging(args)
|
54 |
+
|
55 |
+
# Read params of model
|
56 |
+
params = fetch_model_params(args.model)
|
57 |
+
|
58 |
+
# Fetch appropriate input functions
|
59 |
+
input_fn = params.get("input_fn", "sequential_input")
|
60 |
+
if input_fn == "sequential_input":
|
61 |
+
input_fn = sequential_input
|
62 |
+
elif input_fn == "generic_text":
|
63 |
+
input_fn = generic_text
|
64 |
+
pred_input_fn = pred_input
|
65 |
+
handle_pred_output_fn = handle_pred_output
|
66 |
+
|
67 |
+
# get current step
|
68 |
+
current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
|
69 |
+
logger.info(f"Current step {current_step}")
|
70 |
+
|
71 |
+
if params["mlm_training"]:
|
72 |
+
mlm_sample_text_fn = partial(mlm_sample_text, params)
|
73 |
+
input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)
|
74 |
+
if args.check_dataset:
|
75 |
+
check_dataset(input_fn, params)
|
76 |
+
|
77 |
+
|
78 |
+
# Fetch encoder per params
|
79 |
+
encoder = fetch_encoder(params)
|
80 |
+
|
81 |
+
pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)
|
82 |
+
|
83 |
+
# Sample from Dataset if check dataset flag is on
|
84 |
+
if args.check_dataset:
|
85 |
+
check_dataset(input_fn, params, global_step=current_step)
|
86 |
+
|
87 |
+
# Confirm deletion of checkpoint files if --new flag is set
|
88 |
+
if args.new:
|
89 |
+
if yes_or_no(f"Are you sure you want to remove '{params['model_path']}' to start afresh?"):
|
90 |
+
remove_gs_or_filepath(params["model_path"])
|
91 |
+
else:
|
92 |
+
exit()
|
93 |
+
|
94 |
+
# Save config to logdir for experiment management
|
95 |
+
save_config(params, params["model_path"])
|
96 |
+
|
97 |
+
# Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
|
98 |
+
mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
|
99 |
+
params["num_cores"] = mesh_shape.size
|
100 |
+
params["auto_layout"] = args.auto_layout
|
101 |
+
params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
|
102 |
+
params["use_tpu"] = True if not args.tpu is None else False
|
103 |
+
params["gpu_ids"] = args.gpu_ids
|
104 |
+
params["steps_per_checkpoint"] = args.steps_per_checkpoint
|
105 |
+
# Expand attention types param
|
106 |
+
params["attention_types"] = expand_attention_types_params(params["attention_types"])
|
107 |
+
assert len(params["attention_types"]) == params["n_layer"] # Assert that the length of expanded list = num layers
|
108 |
+
params["predict_batch_size"] = params.get("predict_batch_size", 1) # Default to 1
|
109 |
+
params["predict"] = args.predict
|
110 |
+
params['model'] = params.get("model", "GPT") # Default model selection to GPT since it's the only option for now
|
111 |
+
params["export"] = args.export
|
112 |
+
# Set sampling parameters
|
113 |
+
params["sampling_use_entmax"] = args.entmax_sampling
|
114 |
+
|
115 |
+
# Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
|
116 |
+
# moe layers are present
|
117 |
+
params["slow_sampling"] = True if params["moe_layers"] is not None else False
|
118 |
+
|
119 |
+
logger.info(f"params = {params}")
|
120 |
+
|
121 |
+
# Get eval tasks from params
|
122 |
+
eval_tasks = params.get("eval_tasks", [])
|
123 |
+
has_predict_or_eval_steps_or_eval_tasks = params["predict_steps"] > 0 or params["eval_steps"] > 0 or len(
|
124 |
+
eval_tasks) > 0
|
125 |
+
|
126 |
+
for t in eval_tasks:
|
127 |
+
assert t in task_descriptors, f"Eval task '{t}' is not known"
|
128 |
+
task_descriptors[t]["init_fn"](params)
|
129 |
+
|
130 |
+
# Set up TPUs and Estimator
|
131 |
+
if args.tpu == "colab":
|
132 |
+
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None
|
133 |
+
else:
|
134 |
+
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None
|
135 |
+
|
136 |
+
config = tpu_config.RunConfig(
|
137 |
+
cluster=tpu_cluster_resolver,
|
138 |
+
model_dir=params["model_path"],
|
139 |
+
save_checkpoints_steps=None, # Disable the default saver
|
140 |
+
save_checkpoints_secs=None, # Disable the default saver
|
141 |
+
log_step_count_steps=params["iterations"],
|
142 |
+
save_summary_steps=params["iterations"],
|
143 |
+
tpu_config=tpu_config.TPUConfig(
|
144 |
+
num_shards=mesh_shape.size,
|
145 |
+
iterations_per_loop=params["iterations"],
|
146 |
+
num_cores_per_replica=1,
|
147 |
+
per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
|
148 |
+
|
149 |
+
estimator = tpu_estimator.TPUEstimator(
|
150 |
+
use_tpu=params["use_tpu"],
|
151 |
+
model_fn=model_fn,
|
152 |
+
config=config,
|
153 |
+
train_batch_size=params["train_batch_size"],
|
154 |
+
eval_batch_size=params["train_batch_size"],
|
155 |
+
predict_batch_size=params["predict_batch_size"],
|
156 |
+
params=params)
|
157 |
+
|
158 |
+
def _make_task_estimator(task):
|
159 |
+
task_params = params.copy()
|
160 |
+
task_params["eval_task"] = task
|
161 |
+
return tpu_estimator.TPUEstimator(
|
162 |
+
use_tpu=params["use_tpu"],
|
163 |
+
model_fn=model_fn,
|
164 |
+
config=config,
|
165 |
+
train_batch_size=params["train_batch_size"],
|
166 |
+
eval_batch_size=params["eval_batch_size"],
|
167 |
+
predict_batch_size=params["predict_batch_size"],
|
168 |
+
params=task_params)
|
169 |
+
|
170 |
+
eval_task_estimators = {
|
171 |
+
task: _make_task_estimator(task)
|
172 |
+
for task in eval_tasks
|
173 |
+
}
|
174 |
+
|
175 |
+
if args.export:
|
176 |
+
export_model(estimator, "export", params)
|
177 |
+
return
|
178 |
+
|
179 |
+
if args.predict:
|
180 |
+
# Predict
|
181 |
+
predictions = estimator.predict(input_fn=pred_input_fn)
|
182 |
+
logger.info("Predictions generated")
|
183 |
+
enc = fetch_encoder(params)
|
184 |
+
handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
|
185 |
+
return
|
186 |
+
|
187 |
+
def save_eval_results(task, eval_results):
|
188 |
+
def as_python(x):
|
189 |
+
if isinstance(x, numpy.generic):
|
190 |
+
return x.item()
|
191 |
+
return x
|
192 |
+
eval_results = {k: as_python(v) for k, v in eval_results.items()}
|
193 |
+
with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:
|
194 |
+
json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)
|
195 |
+
fh.write('\n')
|
196 |
+
|
197 |
+
def run_eval():
|
198 |
+
logger.info("Running evaluation...")
|
199 |
+
eval_results = estimator.evaluate(
|
200 |
+
input_fn=partial(input_fn, eval=True),
|
201 |
+
steps=params["eval_steps"])
|
202 |
+
logger.info(f"Eval results: {eval_results}")
|
203 |
+
save_eval_results('validation', eval_results)
|
204 |
+
|
205 |
+
def run_eval_tasks():
|
206 |
+
for task in eval_tasks:
|
207 |
+
logger.info(f"Starting evaluation task '{task}'")
|
208 |
+
task_info = task_descriptors[task]["get_task_info_fn"](params)
|
209 |
+
task_estimator = eval_task_estimators[task]
|
210 |
+
task_input_fn = task_descriptors[task]["input_fn"]
|
211 |
+
eval_results = task_estimator.evaluate(
|
212 |
+
input_fn=task_input_fn,
|
213 |
+
steps=task_info["n_steps"],
|
214 |
+
name=task)
|
215 |
+
logger.info(f"Eval task '{task}' results: {eval_results}")
|
216 |
+
save_eval_results(task, eval_results)
|
217 |
+
|
218 |
+
if args.eval:
|
219 |
+
run_eval_tasks()
|
220 |
+
if params["eval_steps"] > 0:
|
221 |
+
run_eval()
|
222 |
+
return
|
223 |
+
|
224 |
+
|
225 |
+
elif has_predict_or_eval_steps_or_eval_tasks:
|
226 |
+
# Eval and train - stop and predict and/or eval every checkpoint
|
227 |
+
while current_step < params["train_steps"]:
|
228 |
+
next_checkpoint = min(current_step + args.steps_per_checkpoint,
|
229 |
+
params["train_steps"])
|
230 |
+
|
231 |
+
estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)
|
232 |
+
current_step = next_checkpoint
|
233 |
+
|
234 |
+
if params["predict_steps"] > 0:
|
235 |
+
logger.info("Running prediction...")
|
236 |
+
predictions = estimator.predict(input_fn=pred_input_fn)
|
237 |
+
enc = fetch_encoder(params)
|
238 |
+
handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
|
239 |
+
|
240 |
+
if params["eval_steps"] > 0:
|
241 |
+
run_eval()
|
242 |
+
|
243 |
+
if eval_tasks:
|
244 |
+
run_eval_tasks()
|
245 |
+
|
246 |
+
return
|
247 |
+
else:
|
248 |
+
# Else, just train
|
249 |
+
while current_step < params["train_steps"]:
|
250 |
+
# Else, don't stop and restart
|
251 |
+
estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"])
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == "__main__":
|
255 |
+
tf.disable_v2_behavior()
|
256 |
+
args = parse_args()
|
257 |
+
main(args)
|
model_fns.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mesh_tensorflow as mtf
|
2 |
+
import tensorflow.compat.v1 as tf
|
3 |
+
from tensorflow.python.tpu import tpu_estimator
|
4 |
+
import mesh_tensorflow.transformer as mtf_transformer
|
5 |
+
from optimizers import get_optimizer
|
6 |
+
from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
|
7 |
+
get_batch_size, auto_layout, auto_layout_and_mesh_shape)
|
8 |
+
from models.utils import biasmask_attn_weights
|
9 |
+
from tensorflow.python.ops import resources
|
10 |
+
from sample import sample_autoregressive
|
11 |
+
from models.gpt2 import gpt2
|
12 |
+
import math
|
13 |
+
|
14 |
+
|
15 |
+
def model_fn(features, labels, mode, params):
|
16 |
+
# Get global step
|
17 |
+
global_step = tf.train.get_global_step()
|
18 |
+
|
19 |
+
# Construct mtf graph + mesh from params
|
20 |
+
graph = mtf.Graph()
|
21 |
+
mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
|
22 |
+
layout_rules = mtf.convert_to_layout_rules(params["layout"])
|
23 |
+
|
24 |
+
# Mesh setup
|
25 |
+
if params["use_tpu"]:
|
26 |
+
var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
|
27 |
+
else:
|
28 |
+
var_placer = None
|
29 |
+
gpu_ids = params["gpu_ids"]
|
30 |
+
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
|
31 |
+
mesh_shape, layout_rules, gpu_ids)
|
32 |
+
|
33 |
+
# Trainable variable precision
|
34 |
+
# Store to checkpoints in master type, train in slice type, compute in activation type
|
35 |
+
if params["precision"] == "bfloat16":
|
36 |
+
variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
|
37 |
+
activation_dtype=tf.bfloat16)
|
38 |
+
else:
|
39 |
+
variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)
|
40 |
+
|
41 |
+
# Build mtf mesh object
|
42 |
+
mesh = mtf.Mesh(graph, "my_mesh", var_placer)
|
43 |
+
|
44 |
+
# Build mtf_features & seq length dict for getting number of microbatches
|
45 |
+
# We need to pack inputs into a dict to pass into serialize_training_step
|
46 |
+
features_dict = {"inputs": features, "labels": labels}
|
47 |
+
sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}
|
48 |
+
|
49 |
+
params = add_mode_to_params(params, mode)
|
50 |
+
batch_size = get_batch_size(params)
|
51 |
+
|
52 |
+
batch_dim = mtf.Dimension("batch", batch_size)
|
53 |
+
batch_dims = [batch_dim]
|
54 |
+
feature_length = sequence_length_dict["inputs"]
|
55 |
+
length_dim = mtf.Dimension("sequence", feature_length)
|
56 |
+
|
57 |
+
mtf_features = {}
|
58 |
+
for key, x in features_dict.items():
|
59 |
+
if x is not None:
|
60 |
+
feature_shape = mtf.Shape(batch_dims + [length_dim])
|
61 |
+
if type(features_dict[key]) == dict:
|
62 |
+
features_dict[key] = features_dict[key]["feature"]
|
63 |
+
x = tf.cast(features_dict[key], tf.int32)
|
64 |
+
x = tf.reshape(x, feature_shape.to_integer_list)
|
65 |
+
mtf_features[key] = mtf.import_fully_replicated(
|
66 |
+
mesh, x, feature_shape, name=key)
|
67 |
+
|
68 |
+
# Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
|
69 |
+
other_features = {}
|
70 |
+
memory_length_dim = mtf.Dimension("memory_length", length_dim.size)
|
71 |
+
|
72 |
+
attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None
|
73 |
+
|
74 |
+
# Add attn_bias into mtf_features
|
75 |
+
other_features["attn_bias"] = attn_bias
|
76 |
+
|
77 |
+
# Define other Dimensions that we'll need inside the model
|
78 |
+
embd_dim = mtf.Dimension("embd", params["n_embd"])
|
79 |
+
vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
|
80 |
+
# We need this because gathering when both the args have the same dimension in them breaks things
|
81 |
+
# This dim is specifically for the weights
|
82 |
+
# This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
|
83 |
+
embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])
|
84 |
+
|
85 |
+
other_features["embd_dim"] = embd_dim
|
86 |
+
other_features["vocab_dim"] = vocab_dim
|
87 |
+
other_features["embed_sequence_dim"] = embed_sequence_dim
|
88 |
+
other_features["memory_length_dim"] = memory_length_dim
|
89 |
+
|
90 |
+
if mode == tf.estimator.ModeKeys.PREDICT:
|
91 |
+
# Set up the model for prediction
|
92 |
+
inputs = mtf_features["inputs"]
|
93 |
+
if params["remove_partial_sequences"] is None:
|
94 |
+
params["remove_partial_sequences"] = False
|
95 |
+
|
96 |
+
export = params.get("export", False)
|
97 |
+
|
98 |
+
if not export:
|
99 |
+
mtf_samples = sample_autoregressive(
|
100 |
+
inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
|
101 |
+
remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
|
102 |
+
sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])
|
103 |
+
|
104 |
+
else:
|
105 |
+
with mtf.utils.outside_all_rewrites():
|
106 |
+
with tf.variable_scope('gpt2'):
|
107 |
+
mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
|
108 |
+
variable_dtype=variable_dtype, context=None)
|
109 |
+
|
110 |
+
mtf_samples = mtf.anonymize(mtf_samples)
|
111 |
+
inputs = mtf.anonymize(inputs)
|
112 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
|
113 |
+
inputs = lowering.export_to_tf_tensor(inputs)
|
114 |
+
outputs = lowering.export_to_tf_tensor(mtf_samples)
|
115 |
+
predictions = {
|
116 |
+
"inputs": inputs,
|
117 |
+
"outputs": outputs}
|
118 |
+
|
119 |
+
def scaffold_fn():
|
120 |
+
return tf.train.Scaffold(
|
121 |
+
local_init_op=tf.group(
|
122 |
+
tf.train.Scaffold.default_local_init_op(),
|
123 |
+
lowering.copy_masters_to_slices(),
|
124 |
+
name="mtf_local_init_op"),
|
125 |
+
ready_op=tf.concat(
|
126 |
+
[tf.report_uninitialized_variables(),
|
127 |
+
resources.report_uninitialized_resources()],
|
128 |
+
axis=0,
|
129 |
+
name="mtf_ready_op"))
|
130 |
+
|
131 |
+
return tpu_estimator.TPUEstimatorSpec(
|
132 |
+
mode=tf.estimator.ModeKeys.PREDICT,
|
133 |
+
predictions=predictions,
|
134 |
+
scaffold_fn=scaffold_fn,
|
135 |
+
prediction_hooks=[mtf.MtfRestoreHook(lowering)])
|
136 |
+
|
137 |
+
# We're not predicting, so we better be training or evaluating
|
138 |
+
assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)
|
139 |
+
|
140 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
141 |
+
# Gets number of microbatches per batch for serialized training
|
142 |
+
# if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
|
143 |
+
num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,
|
144 |
+
sequence_length=sequence_length_dict,
|
145 |
+
mesh_shape=mesh_shape,
|
146 |
+
layout_rules=layout_rules,
|
147 |
+
tokens_per_microbatch_per_replica=
|
148 |
+
params["tokens_per_mb_per_replica"]))
|
149 |
+
else:
|
150 |
+
num_microbatches = 1
|
151 |
+
|
152 |
+
params["num_microbatches"] = num_microbatches # Add num microbatches to params
|
153 |
+
|
154 |
+
if num_microbatches > 1:
|
155 |
+
|
156 |
+
# For serialize_training_step we need to modify the model to output results in a dict
|
157 |
+
def serialized_fn(mtf_features):
|
158 |
+
if params["model"] == "GPT":
|
159 |
+
with tf.variable_scope('gpt2'):
|
160 |
+
logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
|
161 |
+
variable_dtype=variable_dtype)
|
162 |
+
return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
|
163 |
+
else:
|
164 |
+
raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
|
165 |
+
|
166 |
+
# Serialize the training step - Gradients are accumulated locally and reduced once.
|
167 |
+
var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
|
168 |
+
loss = output_dict["loss"]
|
169 |
+
loss_batch = output_dict["loss_batch"]
|
170 |
+
logits = output_dict["logits"]
|
171 |
+
else:
|
172 |
+
# If we're not splitting into microbatches, return logits & loss as is
|
173 |
+
if params["model"] == "GPT":
|
174 |
+
with mtf.utils.outside_all_rewrites():
|
175 |
+
with tf.variable_scope('gpt2'):
|
176 |
+
logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
|
177 |
+
variable_dtype=variable_dtype, context=None)
|
178 |
+
else:
|
179 |
+
raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
|
180 |
+
|
181 |
+
# Auto layout generation
|
182 |
+
if params["auto_layout"]:
|
183 |
+
auto_layout(graph, mesh_shape, logits, loss)
|
184 |
+
if params["auto_layout_and_mesh_shape"]:
|
185 |
+
auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)
|
186 |
+
|
187 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
188 |
+
# In TRAIN mode, get optimizer
|
189 |
+
if params["num_microbatches"] > 1:
|
190 |
+
# If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
|
191 |
+
# So we pass them in here
|
192 |
+
_, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
|
193 |
+
inp_var_grads=var_grads)
|
194 |
+
else:
|
195 |
+
# Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
|
196 |
+
_, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
|
197 |
+
# Log summaries to tensorboard
|
198 |
+
mtf.scalar_summary("loss", loss)
|
199 |
+
# Log gradients if in params
|
200 |
+
if params["log_grads"] not in [None, False]:
|
201 |
+
for g in var_grads:
|
202 |
+
grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
|
203 |
+
mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
|
204 |
+
else:
|
205 |
+
# For now, we can only export fully-replicated tensors.
|
206 |
+
# This has to be done before lowering or they will not be included in the graph
|
207 |
+
mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
|
208 |
+
max_logits = mtf.argmax(logits, vocab_dim)
|
209 |
+
del logits
|
210 |
+
fully_replicated_mean_logits = mtf.anonymize(mean_logits)
|
211 |
+
fully_replicated_max_logits = mtf.anonymize(max_logits)
|
212 |
+
fully_replicated_loss_batch = mtf.anonymize(loss_batch)
|
213 |
+
|
214 |
+
# Gets & prints info about no. trainable vars in the model & dimension names
|
215 |
+
get_graph_info(graph)
|
216 |
+
|
217 |
+
# 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
|
218 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
|
219 |
+
tf_loss = lowering.export_to_tf_tensor(loss)
|
220 |
+
tf_loss = tf.cast(tf_loss, tf.float32)
|
221 |
+
|
222 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
223 |
+
# Use our patched version until mtf updates theirs
|
224 |
+
host_call = create_host_call(params['model_path'])
|
225 |
+
mtf.utils.remove_summaries()
|
226 |
+
|
227 |
+
# Creates train_op
|
228 |
+
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
|
229 |
+
tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step
|
230 |
+
tf.logging.info(f"tf_update_ops: {tf_update_ops}")
|
231 |
+
train_op = tf.group(tf_update_ops)
|
232 |
+
else:
|
233 |
+
tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
|
234 |
+
tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
|
235 |
+
tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))
|
236 |
+
|
237 |
+
with mtf.utils.outside_all_rewrites():
|
238 |
+
# Copy master variables to slices. Must be called first.
|
239 |
+
restore_hook = mtf.MtfRestoreHook(lowering)
|
240 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
241 |
+
# Set up the checkpoint server and return the TPUEstimatorSpec
|
242 |
+
saver = tf.train.Saver(
|
243 |
+
tf.global_variables(),
|
244 |
+
sharded=True,
|
245 |
+
max_to_keep=10,
|
246 |
+
keep_checkpoint_every_n_hours=2,
|
247 |
+
defer_build=False,
|
248 |
+
save_relative_paths=True)
|
249 |
+
tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
|
250 |
+
saver_listener = mtf.MtfCheckpointSaverListener(lowering)
|
251 |
+
saver_hook = tf.train.CheckpointSaverHook(
|
252 |
+
params["model_path"],
|
253 |
+
save_steps=params["steps_per_checkpoint"],
|
254 |
+
saver=saver,
|
255 |
+
listeners=[saver_listener])
|
256 |
+
|
257 |
+
return tpu_estimator.TPUEstimatorSpec(
|
258 |
+
tf.estimator.ModeKeys.TRAIN,
|
259 |
+
loss=tf_loss,
|
260 |
+
host_call=host_call,
|
261 |
+
train_op=train_op,
|
262 |
+
training_hooks=[restore_hook, saver_hook])
|
263 |
+
|
264 |
+
elif mode == tf.estimator.ModeKeys.EVAL:
|
265 |
+
# Evaluation metrics
|
266 |
+
def _perplexity(loss):
|
267 |
+
perplexity = tf.exp(loss)
|
268 |
+
return tf.metrics.mean(perplexity)
|
269 |
+
|
270 |
+
def _bits_per_byte(loss):
|
271 |
+
bpb = loss * (0.29335 / math.log(2))
|
272 |
+
return tf.metrics.mean(bpb)
|
273 |
+
|
274 |
+
def _metric_fn(tf_mean_logits, tf_loss_batch):
|
275 |
+
mean_logits = tf.metrics.mean(tf_mean_logits)
|
276 |
+
loss = tf.reduce_mean(tf_loss_batch)
|
277 |
+
perp = _perplexity(loss)
|
278 |
+
bpb = _bits_per_byte(loss)
|
279 |
+
return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}
|
280 |
+
|
281 |
+
def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
|
282 |
+
eos_token = params["eos_id"]
|
283 |
+
answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
|
284 |
+
|
285 |
+
correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
|
286 |
+
accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
|
287 |
+
|
288 |
+
# I guess tf_loss_batch has z_loss and maybe other stuff added to it
|
289 |
+
# so maybe this should be calculated separately in the future
|
290 |
+
answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
|
291 |
+
log_perplexity = tf.metrics.mean(answer_loss)
|
292 |
+
|
293 |
+
return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}
|
294 |
+
|
295 |
+
eval_task = params["eval_task"]
|
296 |
+
if eval_task == "lambada":
|
297 |
+
eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
|
298 |
+
else:
|
299 |
+
eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])
|
300 |
+
|
301 |
+
return tpu_estimator.TPUEstimatorSpec(
|
302 |
+
tf.estimator.ModeKeys.EVAL,
|
303 |
+
evaluation_hooks=[restore_hook],
|
304 |
+
loss=tf_loss,
|
305 |
+
eval_metrics=eval_metrics)
|
optimizers.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import re
|
6 |
+
import mesh_tensorflow as mtf
|
7 |
+
import tensorflow.compat.v1 as tf
|
8 |
+
|
9 |
+
def clip_by_global_norm(grads, clip_norm):
|
10 |
+
"""Clip the grads by global norm."""
|
11 |
+
global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
|
12 |
+
multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
|
13 |
+
clipped_grads = [None if t is None else t * multiplier for t in grads]
|
14 |
+
return clipped_grads, global_norm
|
15 |
+
|
16 |
+
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
|
17 |
+
"""Creates and returns an optimizer training op."""
|
18 |
+
global_step = tf.train.get_or_create_global_step()
|
19 |
+
|
20 |
+
learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
|
21 |
+
clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)
|
22 |
+
|
23 |
+
if inp_var_grads is None:
|
24 |
+
var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
|
25 |
+
else:
|
26 |
+
var_grads = inp_var_grads
|
27 |
+
|
28 |
+
# Cast to full precision
|
29 |
+
var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]
|
30 |
+
|
31 |
+
# decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
|
32 |
+
end_step = params.get("lr_decay_end", params["train_steps"])
|
33 |
+
|
34 |
+
if params["lr_decay"] == "linear":
|
35 |
+
learning_rate = tf.train.polynomial_decay(
|
36 |
+
learning_rate,
|
37 |
+
global_step,
|
38 |
+
end_step,
|
39 |
+
end_learning_rate=params["lr"]*0.1, # Decrease to 10% of initial LR according to GPT-3 paper
|
40 |
+
power=1.0,
|
41 |
+
cycle=False)
|
42 |
+
elif params["lr_decay"] == "cosine":
|
43 |
+
learning_rate = tf.train.cosine_decay(
|
44 |
+
learning_rate,
|
45 |
+
global_step,
|
46 |
+
end_step,
|
47 |
+
alpha=0.1 # Alpha is min lr value as a fraction of init lr.
|
48 |
+
)
|
49 |
+
|
50 |
+
if params["warmup_steps"] > 0:
|
51 |
+
global_steps_int = tf.cast(global_step, tf.int32)
|
52 |
+
warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)
|
53 |
+
|
54 |
+
dtype = variable_dtype.slice_dtype
|
55 |
+
|
56 |
+
global_steps_float = tf.cast(global_steps_int, dtype)
|
57 |
+
warmup_steps_float = tf.cast(warmup_steps_int, dtype)
|
58 |
+
|
59 |
+
warmup_percent_done = global_steps_float / warmup_steps_float
|
60 |
+
warmup_learning_rate = learning_rate * warmup_percent_done
|
61 |
+
|
62 |
+
is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
|
63 |
+
learning_rate = ((1.0 - is_warmup) * learning_rate +
|
64 |
+
is_warmup * warmup_learning_rate)
|
65 |
+
|
66 |
+
learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
|
67 |
+
mtf.scalar_summary("lr", learning_rate)
|
68 |
+
|
69 |
+
if params["opt_name"].lower() == "adam":
|
70 |
+
optimizer = AdamWeightDecayOptimizer(
|
71 |
+
learning_rate=learning_rate,
|
72 |
+
weight_decay_rate=params["weight_decay"],
|
73 |
+
beta_1=params["beta1"],
|
74 |
+
beta_2=params["beta2"],
|
75 |
+
epsilon=params["epsilon"],
|
76 |
+
exclude_from_weight_decay=["norm", "bias"],
|
77 |
+
variable_dtype=variable_dtype
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
optimizer = mtf.optimize.AdafactorOptimizer(
|
81 |
+
learning_rate=params["lr"],
|
82 |
+
decay_rate=params["weight_decay"],
|
83 |
+
beta1=params["beta1"],
|
84 |
+
epsilon1=params["ada_epsilon1"],
|
85 |
+
epsilon2=params["ada_epsilon2"]
|
86 |
+
)
|
87 |
+
|
88 |
+
if params["gradient_clipping"] is not None:
|
89 |
+
(var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)
|
90 |
+
|
91 |
+
update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
|
92 |
+
return learning_rate, update_ops, var_grads_fp
|
93 |
+
|
94 |
+
|
95 |
+
class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
|
96 |
+
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
97 |
+
|
98 |
+
def __init__(self,
|
99 |
+
learning_rate,
|
100 |
+
weight_decay_rate=0.0,
|
101 |
+
beta_1=0.9,
|
102 |
+
beta_2=0.999,
|
103 |
+
epsilon=1e-6,
|
104 |
+
exclude_from_weight_decay=None,
|
105 |
+
variable_dtype=None):
|
106 |
+
"""Constructs a AdamWeightDecayOptimizer."""
|
107 |
+
|
108 |
+
self.learning_rate = learning_rate
|
109 |
+
self.weight_decay_rate = weight_decay_rate
|
110 |
+
self.beta_1 = beta_1
|
111 |
+
self.beta_2 = beta_2
|
112 |
+
self.epsilon = epsilon
|
113 |
+
self.exclude_from_weight_decay = exclude_from_weight_decay
|
114 |
+
self.variable_dtype = variable_dtype
|
115 |
+
|
116 |
+
def apply_grad(self, grad, var):
|
117 |
+
"""See base class."""
|
118 |
+
if grad is None:
|
119 |
+
tf.logging.warning("Gradient is None for variable %s" % var.name)
|
120 |
+
return []
|
121 |
+
|
122 |
+
grad = mtf.to_float(grad)
|
123 |
+
|
124 |
+
assignments = []
|
125 |
+
|
126 |
+
m = mtf.get_variable(
|
127 |
+
var.mesh, var.name + "/adam_m", var.shape,
|
128 |
+
initializer=tf.zeros_initializer(),
|
129 |
+
# master_dtype=self.variable_dtype.master_dtype,
|
130 |
+
# slice_dtype=self.variable_dtype.slice_dtype,
|
131 |
+
# activation_dtype=self.variable_dtype.activation_dtype,
|
132 |
+
trainable=False)
|
133 |
+
|
134 |
+
v = mtf.get_variable(
|
135 |
+
var.mesh, var.name + "/adam_v", var.shape,
|
136 |
+
initializer=tf.zeros_initializer(),
|
137 |
+
# master_dtype=self.variable_dtype.master_dtype,
|
138 |
+
# slice_dtype=self.variable_dtype.slice_dtype,
|
139 |
+
# activation_dtype=self.variable_dtype.activation_dtype,
|
140 |
+
trainable=False)
|
141 |
+
|
142 |
+
# Standard Adam update.
|
143 |
+
next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
|
144 |
+
next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)
|
145 |
+
|
146 |
+
update = next_m / (mtf.sqrt(next_v) + self.epsilon)
|
147 |
+
|
148 |
+
# Just adding the square of the weights to the loss function is *not*
|
149 |
+
# the correct way of using L2 regularization/weight decay with Adam,
|
150 |
+
# since that will interact with the m and v parameters in strange ways.
|
151 |
+
#
|
152 |
+
# Instead we want to decay the weights in a manner that doesn't interact
|
153 |
+
# with the m/v parameters. This is equivalent to adding the square
|
154 |
+
# of the weights to the loss with plain (non-momentum) SGD.
|
155 |
+
if self._do_use_weight_decay(var.name):
|
156 |
+
update += mtf.to_float(var.value) * self.weight_decay_rate
|
157 |
+
|
158 |
+
update_with_lr = self.learning_rate * update
|
159 |
+
|
160 |
+
var_update = mtf.assign_sub(var, update_with_lr)
|
161 |
+
|
162 |
+
assignments.extend(
|
163 |
+
[var_update,
|
164 |
+
mtf.assign(m, next_m),
|
165 |
+
mtf.assign(v, next_v)])
|
166 |
+
return assignments
|
167 |
+
|
168 |
+
def _do_use_weight_decay(self, param_name):
|
169 |
+
"""Whether to use L2 weight decay for `param_name`."""
|
170 |
+
if not self.weight_decay_rate:
|
171 |
+
return False
|
172 |
+
if self.exclude_from_weight_decay:
|
173 |
+
for r in self.exclude_from_weight_decay:
|
174 |
+
if re.search(r, param_name) is not None:
|
175 |
+
return False
|
176 |
+
return True
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
google-api-python-client
|
2 |
+
jsonlines
|
3 |
+
lm_dataformat
|
4 |
+
mesh-tensorflow==0.1.18
|
5 |
+
numpy
|
6 |
+
oauth2client
|
7 |
+
ortools
|
8 |
+
pytest
|
9 |
+
sacred
|
10 |
+
tensorflow==2.5.1
|
11 |
+
tensorflow-datasets==3.2.1
|
12 |
+
tokenizers==0.9.4
|
13 |
+
transformers==4.1.1
|
14 |
+
tpunicorn
|
15 |
+
absl-py
|
16 |
+
ftfy
|
17 |
+
sacred
|
18 |
+
pymongo
|
run_experiment.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import atexit
|
2 |
+
import sacred
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
import math
|
6 |
+
import subprocess
|
7 |
+
import shutil
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
import threading
|
11 |
+
import requests
|
12 |
+
import glob
|
13 |
+
from configs import fetch_model_params
|
14 |
+
import socket
|
15 |
+
import subprocess
|
16 |
+
import queue
|
17 |
+
import sys
|
18 |
+
import signal
|
19 |
+
|
20 |
+
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument('--tpu', type=str, required=True) # Name of TPU to train on, if any
|
23 |
+
parser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters
|
24 |
+
parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
|
25 |
+
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
|
26 |
+
parser.add_argument('--autostack', action="store_false")
|
27 |
+
parser.add_argument('--auto_layout', action="store_true")
|
28 |
+
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
|
29 |
+
parser.add_argument('--new', action='store_true')
|
30 |
+
parser.add_argument('--test', action='store_true')
|
31 |
+
parser.add_argument('--eval', action='store_true')
|
32 |
+
parser.add_argument('--predict', action='store_true')
|
33 |
+
parser.add_argument('--no_delete_tpu', action='store_true')
|
34 |
+
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
|
35 |
+
parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
params = fetch_model_params(args.model)
|
39 |
+
|
40 |
+
ex = sacred.Experiment(args.experiment_name)
|
41 |
+
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
|
42 |
+
|
43 |
+
|
44 |
+
def get_open_port(lo=8000, hi=8100):
|
45 |
+
for i in range(lo, hi):
|
46 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
47 |
+
if s.connect_ex(('localhost', i)) != 0:
|
48 |
+
return i
|
49 |
+
|
50 |
+
|
51 |
+
def train_thread(args, tpu, id, q):
|
52 |
+
print('starting training on', tpu)
|
53 |
+
|
54 |
+
# pass binary flags through
|
55 |
+
opts = ''
|
56 |
+
for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
|
57 |
+
if args.__getattribute__(flag):
|
58 |
+
opts += ' --' + flag
|
59 |
+
|
60 |
+
for flag in ['autostack', ]:
|
61 |
+
if not args.__getattribute__(flag):
|
62 |
+
opts += ' --' + flag
|
63 |
+
|
64 |
+
cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
|
65 |
+
print('Running:', cmd)
|
66 |
+
proc = subprocess.Popen(cmd, shell=True)
|
67 |
+
|
68 |
+
# poll until it's exited
|
69 |
+
while proc.poll() is None:
|
70 |
+
time.sleep(60)
|
71 |
+
try:
|
72 |
+
nq, *nargs = q.get_nowait()
|
73 |
+
if nq == 'kill':
|
74 |
+
print('train thread recieved kill signal from logging thread')
|
75 |
+
# first send SIGTERM
|
76 |
+
proc.terminate()
|
77 |
+
|
78 |
+
time.sleep(60)
|
79 |
+
|
80 |
+
# if it still hasn't exited, we send SIGKILL
|
81 |
+
if proc.poll() is None:
|
82 |
+
print('SIGTERM not successful, sending SIGKILL')
|
83 |
+
proc.kill()
|
84 |
+
|
85 |
+
except queue.Empty:
|
86 |
+
pass
|
87 |
+
|
88 |
+
print('exited training!')
|
89 |
+
if proc.returncode == 0:
|
90 |
+
print('exited gracefully')
|
91 |
+
os.kill(os.getpid(), signal.SIGINT)
|
92 |
+
return
|
93 |
+
|
94 |
+
if args.no_delete_tpu:
|
95 |
+
print('recreate done, exiting train_thread - not killing tpu!')
|
96 |
+
return
|
97 |
+
print("Recreating {} in 60sec...".format(tpu))
|
98 |
+
time.sleep(60)
|
99 |
+
os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
|
100 |
+
print('recreate done, exiting train_thread')
|
101 |
+
|
102 |
+
# clear out queue
|
103 |
+
while True:
|
104 |
+
try:
|
105 |
+
q.get_nowait()
|
106 |
+
print('dropped request in queue after pu recreate')
|
107 |
+
except queue.Empty:
|
108 |
+
break
|
109 |
+
|
110 |
+
|
111 |
+
def get_json(uri, params=None, timeout=15):
|
112 |
+
resp = requests.get(uri, params=params, timeout=timeout)
|
113 |
+
resp.raise_for_status()
|
114 |
+
return resp.json()
|
115 |
+
|
116 |
+
|
117 |
+
def get_tag_sets(base_uri):
|
118 |
+
j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
|
119 |
+
assert isinstance(j, dict)
|
120 |
+
return {
|
121 |
+
run: j[run].keys()
|
122 |
+
for run in j.keys()
|
123 |
+
}
|
124 |
+
|
125 |
+
|
126 |
+
def get_scalar_data(base_uri, run, tag):
|
127 |
+
j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
|
128 |
+
assert isinstance(j, list)
|
129 |
+
return j
|
130 |
+
|
131 |
+
|
132 |
+
def get_run_data(port):
|
133 |
+
base_uri = f'http://localhost:{port}/'
|
134 |
+
r = {}
|
135 |
+
try:
|
136 |
+
tag_sets = get_tag_sets(base_uri)
|
137 |
+
runs = tag_sets.keys()
|
138 |
+
if '.' in runs:
|
139 |
+
if 'loss' in tag_sets['.']:
|
140 |
+
r['loss'] = get_scalar_data(base_uri, '.', 'loss')
|
141 |
+
if 'eval' in runs:
|
142 |
+
if 'loss' in tag_sets['eval']:
|
143 |
+
r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
|
144 |
+
if 'eval_lambada' in runs:
|
145 |
+
if 'lambada_acc' in tag_sets['eval_lambada']:
|
146 |
+
r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
|
147 |
+
if 'lambada_log_ppl' in tag_sets['eval_lambada']:
|
148 |
+
r['lambada_ppl'] = [
|
149 |
+
[t, s, math.exp(lp)]
|
150 |
+
for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
|
151 |
+
]
|
152 |
+
except:
|
153 |
+
import traceback
|
154 |
+
traceback.print_exc()
|
155 |
+
return r
|
156 |
+
|
157 |
+
|
158 |
+
@ex.main
|
159 |
+
def main(_run):
|
160 |
+
print('Starting run', _run._id)
|
161 |
+
print('experiment main invoked with argv:', " ".join(sys.argv))
|
162 |
+
print('WARNING: please remember to remove old metric log files from the model directory.')
|
163 |
+
|
164 |
+
os.makedirs('run_configs', exist_ok=True)
|
165 |
+
shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))
|
166 |
+
|
167 |
+
tensorboard_port = get_open_port()
|
168 |
+
print('Tensorboard at port:', tensorboard_port)
|
169 |
+
print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
|
170 |
+
os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
|
171 |
+
atexit.register(goodbye, _run._id)
|
172 |
+
|
173 |
+
curr_step = {}
|
174 |
+
seen_predictions = set()
|
175 |
+
|
176 |
+
heartbeat_timeout = args.initial_heartbeat_timeout * 2
|
177 |
+
while True:
|
178 |
+
last_tb_log_time = time.time()
|
179 |
+
start_time = time.time()
|
180 |
+
q = queue.Queue()
|
181 |
+
trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
|
182 |
+
trainthd.start()
|
183 |
+
|
184 |
+
while trainthd.is_alive():
|
185 |
+
time.sleep(60)
|
186 |
+
|
187 |
+
if start_time + args.initial_heartbeat_timeout < time.time():
|
188 |
+
# after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower
|
189 |
+
heartbeat_timeout = args.heartbeat_timeout
|
190 |
+
|
191 |
+
print('Polling tensorboard for metrics...')
|
192 |
+
data = get_run_data(tensorboard_port)
|
193 |
+
for k in data.keys():
|
194 |
+
for ts, step, val in data[k]:
|
195 |
+
if step <= curr_step.get(k, -1):
|
196 |
+
continue
|
197 |
+
_run.log_scalar(k, val, step)
|
198 |
+
if k == 'loss':
|
199 |
+
_run.log_scalar('tb_ts', ts, step)
|
200 |
+
print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))
|
201 |
+
|
202 |
+
# found something new, so logging!
|
203 |
+
last_tb_log_time = time.time()
|
204 |
+
|
205 |
+
curr_step[k] = step
|
206 |
+
|
207 |
+
for f in glob.glob('predictions_{}_*'.format(_run._id)):
|
208 |
+
if f in seen_predictions:
|
209 |
+
continue
|
210 |
+
print('collecting prediction file', f)
|
211 |
+
ex.add_artifact(f)
|
212 |
+
|
213 |
+
seen_predictions.add(f)
|
214 |
+
|
215 |
+
# collect eval metrics from jsonl
|
216 |
+
if os.path.exists(f'eval_{_run._id}.jsonl'):
|
217 |
+
with open(f'eval_{_run._id}.jsonl') as fh:
|
218 |
+
for line in fh:
|
219 |
+
ob = json.loads(line)
|
220 |
+
val_step = ob['global_step']
|
221 |
+
val_task = ob['task']
|
222 |
+
for metr in ob.keys():
|
223 |
+
k = 'fs.' + val_task + '.' + metr
|
224 |
+
if metr in ['task', 'global_step']: continue
|
225 |
+
if val_step <= curr_step.get(k, -1): continue
|
226 |
+
_run.log_scalar(k, ob[metr], val_step)
|
227 |
+
curr_step[k] = val_step
|
228 |
+
|
229 |
+
if time.time() - last_tb_log_time > heartbeat_timeout:
|
230 |
+
# the run hasn't logged in a while, so we restart it
|
231 |
+
q.put(('kill',))
|
232 |
+
|
233 |
+
# give training thread some time to do its thing and recreate tpu
|
234 |
+
while trainthd.is_alive():
|
235 |
+
print('logging thread waiting for killing stalled run and for tpu recreate to finish')
|
236 |
+
time.sleep(60)
|
237 |
+
|
238 |
+
# reset heartbeat timeout to initial
|
239 |
+
heartbeat_timeout = args.initial_heartbeat_timeout
|
240 |
+
last_tb_log_time = time.time()
|
241 |
+
|
242 |
+
|
243 |
+
if args.no_delete_tpu:
|
244 |
+
break
|
245 |
+
|
246 |
+
|
247 |
+
def goodbye(id):
|
248 |
+
print("You are now leaving the Python sector.")
|
249 |
+
print("Sie verlassen den pythonischen Sektor.")
|
250 |
+
|
251 |
+
os.system("screen -S tensorboard_{} -X quit".format(id))
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == '__main__':
|
255 |
+
for file in glob.glob("**/*", recursive=True):
|
256 |
+
if file.split('.')[-1] in ['py']:
|
257 |
+
print('Adding', file, 'to sacred')
|
258 |
+
ex.add_source_file(file)
|
259 |
+
|
260 |
+
ex.add_config({
|
261 |
+
'tpu_name': args.tpu,
|
262 |
+
**params
|
263 |
+
})
|
264 |
+
|
265 |
+
ex.run()
|
sample.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mesh_tensorflow as mtf
|
2 |
+
import tensorflow.compat.v1 as tf
|
3 |
+
import mesh_tensorflow.transformer as mtf_transformer
|
4 |
+
|
5 |
+
from models.utils import entmax, sample_categorical
|
6 |
+
from models.gpt2 import gpt2
|
7 |
+
|
8 |
+
def sample_autoregressive(partial_sequences,
|
9 |
+
other_features,
|
10 |
+
params,
|
11 |
+
stop_at_token=50256,
|
12 |
+
max_steps=None,
|
13 |
+
temperature=0.9,
|
14 |
+
variable_dtype=mtf.VariableDType(tf.float32),
|
15 |
+
encoder_output=None,
|
16 |
+
encoder_sequence_id=None,
|
17 |
+
encoder_inputs=None,
|
18 |
+
shared_params=None,
|
19 |
+
has_partial_sequences=True,
|
20 |
+
encoder_layer_outputs=None,
|
21 |
+
never_end=False,
|
22 |
+
remove_partial_sequences=False,
|
23 |
+
sampling_keep_top_k=-1,
|
24 |
+
sampling_use_entmax = False,
|
25 |
+
bos_id=50256,
|
26 |
+
):
|
27 |
+
"""Sample randomly one token at a time.
|
28 |
+
|
29 |
+
The partial_sequences represent partial sequences to be continued. The
|
30 |
+
first tokens of each sequence are nonzero representing the given partial
|
31 |
+
sequences and the last tokens of each sequence are zeros, representing what
|
32 |
+
needs to be filled in.
|
33 |
+
|
34 |
+
If there are no partial sequences (you want to sample from the beginning),
|
35 |
+
then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
|
36 |
+
has_partial_sequences=False (so we can skip computation).
|
37 |
+
|
38 |
+
Args:
|
39 |
+
partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
|
40 |
+
stop_at_token: an optional integer eos id. Stop when we produce it.
|
41 |
+
max_steps: an optional integer, the max number of steps to decode.
|
42 |
+
temperature: an optional floating point value between 0.0 and 1.0 0.0
|
43 |
+
means argmax, 1.0 means sample according to predicted distribution.
|
44 |
+
variable_dtype: a mtf.VariableDType
|
45 |
+
encoder_output: an optional Tensor
|
46 |
+
encoder_sequence_id: an optional Tensor
|
47 |
+
encoder_inputs: an optional Tensor
|
48 |
+
shared_params: an optional dictionary
|
49 |
+
has_partial_sequences: a boolean
|
50 |
+
encoder_layer_outputs: optional - readonly list of tensor activations when
|
51 |
+
decoding, one per each input layer + the embedding layer
|
52 |
+
never_end: a boolean - if set, then avoid generating stop_at_token
|
53 |
+
remove_partial_sequences: a boolean - whether to remove the partial
|
54 |
+
sequences from the output
|
55 |
+
sampling_keep_top_k: an integer - if not -1, only sample from the top k
|
56 |
+
logits.
|
57 |
+
bos_id: beginning of sequence id
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
a Tensor with shape [<batch_dims>, length_dim]
|
61 |
+
"""
|
62 |
+
|
63 |
+
inputs = partial_sequences # Partial sequences to fill in
|
64 |
+
batch_dims = inputs.shape.dims[:-1]
|
65 |
+
length_dim = inputs.shape.dims[-1]
|
66 |
+
padding_id = params.get("padding_id", 0)
|
67 |
+
slow_sampling = params.get("slow_sampling", False)
|
68 |
+
|
69 |
+
|
70 |
+
initial_position = mtf.reduce_sum(
|
71 |
+
mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts
|
72 |
+
|
73 |
+
length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
|
74 |
+
input_full_attention = True # for now hardcode this to true bc lazy
|
75 |
+
if input_full_attention:
|
76 |
+
# Vanilla autoregressive model - each position can see previous positions.
|
77 |
+
# Think this feeds in to the loop fn and tells each position where it can attend to?
|
78 |
+
read_priority = write_priority = length_range * mtf.to_int32(
|
79 |
+
mtf.greater(length_range, initial_position))
|
80 |
+
else:
|
81 |
+
read_priority = write_priority = length_range
|
82 |
+
|
83 |
+
# Builds context to pass around internally
|
84 |
+
# The 'first part' context records initial states of k / v / x
|
85 |
+
|
86 |
+
if not slow_sampling:
|
87 |
+
context_first_part = mtf_transformer.transformer.Context(
|
88 |
+
model=None,
|
89 |
+
mesh=inputs.mesh,
|
90 |
+
batch_dims=batch_dims,
|
91 |
+
length_dim=length_dim,
|
92 |
+
variable_dtype=variable_dtype,
|
93 |
+
mode="first_part",
|
94 |
+
position=length_range,
|
95 |
+
position_is_default=True,
|
96 |
+
new_states=[],
|
97 |
+
initial_position=initial_position,
|
98 |
+
sequence_id=None,
|
99 |
+
encoder_output=encoder_output,
|
100 |
+
encoder_sequence_id=encoder_sequence_id,
|
101 |
+
constant_states=[],
|
102 |
+
shared_params=shared_params,
|
103 |
+
encoder_layer_outputs=encoder_layer_outputs,
|
104 |
+
write_priority=write_priority,
|
105 |
+
read_priority=read_priority,
|
106 |
+
inputs=inputs,
|
107 |
+
encoder_inputs=encoder_inputs)
|
108 |
+
|
109 |
+
with tf.variable_scope("gpt2"):
|
110 |
+
logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part)
|
111 |
+
|
112 |
+
if not has_partial_sequences:
|
113 |
+
initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
|
114 |
+
else:
|
115 |
+
initial_states = context_first_part.new_states
|
116 |
+
else:
|
117 |
+
initial_states = []
|
118 |
+
|
119 |
+
if not has_partial_sequences:
|
120 |
+
partial_sequences_eos_count = 0
|
121 |
+
|
122 |
+
if stop_at_token is not None:
|
123 |
+
partial_sequences_eos_count = mtf.reduce_sum(
|
124 |
+
mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
|
125 |
+
reduced_dim=length_dim)
|
126 |
+
|
127 |
+
def cond_fn(position, ids, *unused_states):
|
128 |
+
"""Should we run another loop iteration?"""
|
129 |
+
past_end = mtf.greater_equal(position, length_dim.size)
|
130 |
+
if max_steps:
|
131 |
+
past_end = mtf.logical_or(
|
132 |
+
past_end, mtf.greater_equal(position - initial_position, max_steps))
|
133 |
+
|
134 |
+
is_done = past_end
|
135 |
+
if stop_at_token is not None:
|
136 |
+
eos_count = mtf.reduce_sum(
|
137 |
+
mtf.to_int32(mtf.equal(ids, stop_at_token)),
|
138 |
+
reduced_dim=length_dim)
|
139 |
+
has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
|
140 |
+
is_done = mtf.logical_or(is_done, has_additional_eos)
|
141 |
+
all_done = mtf.reduce_all(is_done)
|
142 |
+
return mtf.logical_not(all_done)
|
143 |
+
|
144 |
+
def body_fn(position, ids, *states):
|
145 |
+
"""One step in the decode loop."""
|
146 |
+
nonlocal sampling_keep_top_k
|
147 |
+
|
148 |
+
context = mtf_transformer.transformer.Context(
|
149 |
+
model=None,
|
150 |
+
mesh=inputs.mesh,
|
151 |
+
batch_dims=batch_dims,
|
152 |
+
length_dim=length_dim,
|
153 |
+
variable_dtype=variable_dtype,
|
154 |
+
mode="incremental",
|
155 |
+
position=position,
|
156 |
+
position_is_default=True,
|
157 |
+
states=states,
|
158 |
+
new_states=[],
|
159 |
+
initial_position=position,
|
160 |
+
sequence_id=None,
|
161 |
+
encoder_output=encoder_output,
|
162 |
+
encoder_sequence_id=encoder_sequence_id,
|
163 |
+
shared_params=shared_params,
|
164 |
+
encoder_layer_outputs=encoder_layer_outputs,
|
165 |
+
write_priority=write_priority,
|
166 |
+
read_priority=read_priority,
|
167 |
+
inputs=ids,
|
168 |
+
encoder_inputs=encoder_inputs) if not slow_sampling else None
|
169 |
+
|
170 |
+
with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
|
171 |
+
logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context)
|
172 |
+
|
173 |
+
if not sampling_use_entmax:
|
174 |
+
# By default, do top_k sampling of 0.9
|
175 |
+
if sampling_keep_top_k == -2:
|
176 |
+
sampling_keep_top_k = int(logits.shape[-1].size * 0.1)
|
177 |
+
|
178 |
+
if sampling_keep_top_k != -1:
|
179 |
+
if sampling_keep_top_k <= 0:
|
180 |
+
raise ValueError("sampling_keep_top_k must either be -1 or positive.")
|
181 |
+
k_largest = mtf.nth_largest_element(
|
182 |
+
logits, n=sampling_keep_top_k,
|
183 |
+
reduced_dim=other_features["vocab_dim"])
|
184 |
+
logits = mtf.where(mtf.less_equal(logits, k_largest),
|
185 |
+
mtf.ones_like(logits) * -1e6, logits)
|
186 |
+
|
187 |
+
ids_this_step = mtf.sample_with_temperature(
|
188 |
+
logits, other_features["vocab_dim"], temperature)
|
189 |
+
else:
|
190 |
+
ids_this_step = sample_categorical(entmax(logits))
|
191 |
+
|
192 |
+
if slow_sampling:
|
193 |
+
ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)
|
194 |
+
else:
|
195 |
+
ids_this_step = mtf.reshape(ids_this_step, (batch_dims))
|
196 |
+
|
197 |
+
one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
|
198 |
+
one_new_id = ids_this_step * one_hot
|
199 |
+
new_ids = (1 - one_hot) * ids + one_new_id
|
200 |
+
new_position = position + 1
|
201 |
+
|
202 |
+
ret = [new_position, new_ids]
|
203 |
+
if context is not None:
|
204 |
+
ret += context.new_states
|
205 |
+
return ret
|
206 |
+
|
207 |
+
while_loop_inputs = [initial_position, inputs] + initial_states
|
208 |
+
final_position, outputs = mtf.while_loop(
|
209 |
+
cond_fn, body_fn, while_loop_inputs)[:2]
|
210 |
+
del final_position
|
211 |
+
if has_partial_sequences and remove_partial_sequences:
|
212 |
+
# Remove partial sequences from outputs
|
213 |
+
partial_length = mtf.reduce_sum(
|
214 |
+
mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
|
215 |
+
reduced_dim=length_dim)
|
216 |
+
outputs = mtf.dynamic_shift(
|
217 |
+
outputs, -partial_length, length_dim, wrap=False)
|
218 |
+
return outputs
|
tasks.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import numpy as np
|
5 |
+
import ftfy
|
6 |
+
from data.encoders import fetch_encoder, encode
|
7 |
+
import tensorflow as tf
|
8 |
+
import re
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
|
12 |
+
normalization = 'NFKC'
|
13 |
+
|
14 |
+
|
15 |
+
# Note: this task is called "lambada" but it really refers to OpenAI's version
|
16 |
+
# of the task, which actually differs in some ways from the task described in
|
17 |
+
# the original paper. So, strictly speaking, accuracy values from this task
|
18 |
+
# should not be compared to accuracy values from the original lambada task.
|
19 |
+
# For more information, see
|
20 |
+
# https://github.com/openai/gpt-2/issues/131
|
21 |
+
|
22 |
+
def lambada_create_tokens_data(params, path):
|
23 |
+
with open(path, 'w') as f:
|
24 |
+
req = requests.get(lambada_src_uri)
|
25 |
+
req.raise_for_status()
|
26 |
+
jsons = [json.loads(l) for l in req.iter_lines()]
|
27 |
+
texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
|
28 |
+
enc = fetch_encoder(params)
|
29 |
+
arrays = [encode(enc, t) for t in texts]
|
30 |
+
json.dump(arrays, f)
|
31 |
+
return arrays
|
32 |
+
|
33 |
+
|
34 |
+
def lambada_read_or_create_tokens_data(params, path):
|
35 |
+
# if you tell me where the file should go, i will helpfully create it for you
|
36 |
+
if not os.path.exists(path):
|
37 |
+
return lambada_create_tokens_data(params, path)
|
38 |
+
with open(path) as f:
|
39 |
+
return json.load(f)
|
40 |
+
|
41 |
+
|
42 |
+
def bin_pack(params, tokens_data):
|
43 |
+
eos_token = params['eos_id']
|
44 |
+
n_ctx = params['n_ctx']
|
45 |
+
dummy_token = 1
|
46 |
+
pad_batch_size = params['eval_batch_size']
|
47 |
+
bins = []
|
48 |
+
for a in tokens_data:
|
49 |
+
if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
|
50 |
+
bins.append([])
|
51 |
+
bins[-1] += a
|
52 |
+
bins[-1].append(eos_token)
|
53 |
+
while len(bins) % pad_batch_size != 0:
|
54 |
+
bins.append([])
|
55 |
+
bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
|
56 |
+
for i, b in enumerate(bins):
|
57 |
+
bins_array[i, 0:len(b)] = b
|
58 |
+
return bins_array
|
59 |
+
|
60 |
+
|
61 |
+
def lambada_init(params):
|
62 |
+
ds_configs = params['dataset_configs']
|
63 |
+
l = [
|
64 |
+
ds_configs[ds_id].get('lambada_tokens_path', "./lambada.json")
|
65 |
+
for ds_id, _, _, _ in params['datasets']
|
66 |
+
]
|
67 |
+
assert len(l) > 0, 'lambada_tokens_path not found in the dataset config'
|
68 |
+
lt_path = l[0]
|
69 |
+
assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'
|
70 |
+
|
71 |
+
tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
|
72 |
+
bins_array = bin_pack(params, tokens_data)
|
73 |
+
params['lambada_tokens_path'] = lt_path
|
74 |
+
params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']
|
75 |
+
|
76 |
+
|
77 |
+
def lambada_get_task_info(params):
|
78 |
+
return {
|
79 |
+
'n_steps': params['lambada_n_steps'],
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
# The LAMBADA evaluation code looks at the logits of each position just before an eos_token
|
84 |
+
def lambada_input(params):
|
85 |
+
eos_token = 50256 if params['n_vocab'] >= 50257 else 0
|
86 |
+
n_ctx = params['n_ctx']
|
87 |
+
lt_path = params['lambada_tokens_path']
|
88 |
+
tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
|
89 |
+
bins_array = bin_pack(params, tokens_data)
|
90 |
+
dataset = tf.data.Dataset.from_tensor_slices(bins_array)
|
91 |
+
|
92 |
+
def _get_output(bin):
|
93 |
+
bin = tf.cast(bin, dtype=tf.int32)
|
94 |
+
indexes = tf.range(n_ctx)
|
95 |
+
results = tf.gather(bin, (indexes + 1) % n_ctx)
|
96 |
+
eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
|
97 |
+
output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
|
98 |
+
bin = tf.reshape(bin, [n_ctx])
|
99 |
+
bin = tf.cast(bin, dtype=tf.int32)
|
100 |
+
output = tf.reshape(output, [n_ctx])
|
101 |
+
output = tf.cast(output, dtype=tf.int32)
|
102 |
+
return bin, output
|
103 |
+
|
104 |
+
dataset = dataset.map(_get_output)
|
105 |
+
dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
|
106 |
+
dataset = dataset.repeat()
|
107 |
+
return dataset
|
108 |
+
|
109 |
+
|
110 |
+
task_descriptors = {
|
111 |
+
'lambada': {
|
112 |
+
'init_fn': lambada_init,
|
113 |
+
'get_task_info_fn': lambada_get_task_info,
|
114 |
+
'input_fn': lambada_input,
|
115 |
+
}
|
116 |
+
}
|
test_models.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
import traceback
|
3 |
+
import logging
|
4 |
+
from collections import defaultdict
|
5 |
+
from contextlib import contextmanager
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
tf.compat.v1.enable_eager_execution()
|
9 |
+
import mesh_tensorflow as mtf
|
10 |
+
from mesh_tensorflow import placement_mesh_impl
|
11 |
+
|
12 |
+
from inputs import mlm_sample_text
|
13 |
+
from models.gpt2 import gpt2
|
14 |
+
from models.utils import biasmask_attn_weights, entmax, sample_categorical
|
15 |
+
|
16 |
+
from sample import sample_autoregressive
|
17 |
+
|
18 |
+
# helper functions
|
19 |
+
|
20 |
+
@contextmanager
|
21 |
+
def not_raises(exception):
|
22 |
+
try:
|
23 |
+
yield
|
24 |
+
except exception:
|
25 |
+
logging.error(traceback.format_exc())
|
26 |
+
raise pytest.fail("DID RAISE {0}".format(exception))
|
27 |
+
|
28 |
+
# fixtures
|
29 |
+
|
30 |
+
params = defaultdict(lambda: None, {
|
31 |
+
"n_head": 1,
|
32 |
+
"n_ctx": 4,
|
33 |
+
"n_embd": 2,
|
34 |
+
"n_vocab": 256,
|
35 |
+
"embed_dropout": 0.,
|
36 |
+
"n_layer": 2,
|
37 |
+
"num_microbatches": 1,
|
38 |
+
"train_batch_size": 1,
|
39 |
+
"causal": True,
|
40 |
+
"attention_types": ['global', 'local'],
|
41 |
+
"res_dropout": 0.1,
|
42 |
+
"rotary_emb": True,
|
43 |
+
"activation_function": "gelu",
|
44 |
+
"moe_layers": (1,),
|
45 |
+
"num_mem_kv": 16,
|
46 |
+
"no_weight_tie": True,
|
47 |
+
"moe_params": {
|
48 |
+
'moe_dropout_rate': 0.0
|
49 |
+
},
|
50 |
+
"mesh_shape": [],
|
51 |
+
"layout": {},
|
52 |
+
"local_attention_radius": 128,
|
53 |
+
"share_parameters": True,
|
54 |
+
"rezero": True
|
55 |
+
})
|
56 |
+
|
57 |
+
# tests
|
58 |
+
|
59 |
+
def test_model():
|
60 |
+
graph = mtf.Graph()
|
61 |
+
mesh = mtf.Mesh(graph, "my_mesh")
|
62 |
+
|
63 |
+
seq_len = params["n_ctx"]
|
64 |
+
|
65 |
+
batch_dim = mtf.Dimension("batch", 1)
|
66 |
+
sequence_dim = mtf.Dimension("sequence", seq_len)
|
67 |
+
|
68 |
+
features = {
|
69 |
+
'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
|
70 |
+
'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
|
71 |
+
}
|
72 |
+
|
73 |
+
# create mask
|
74 |
+
|
75 |
+
num_mem_kv = params.get('num_mem_kv', 0)
|
76 |
+
length_dim = mtf.Dimension('sequence', seq_len)
|
77 |
+
memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
|
78 |
+
embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
|
79 |
+
embd_dim = mtf.Dimension("embd", params["n_embd"])
|
80 |
+
vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
|
81 |
+
|
82 |
+
other_features = {}
|
83 |
+
variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
|
84 |
+
|
85 |
+
other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
|
86 |
+
other_features["embd_dim"] = embd_dim
|
87 |
+
other_features["vocab_dim"] = vocab_dim
|
88 |
+
other_features["embed_sequence_dim"] = embed_sequence_dim
|
89 |
+
other_features["memory_length_dim"] = memory_length_dim
|
90 |
+
|
91 |
+
with not_raises(Exception):
|
92 |
+
logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)
|
93 |
+
|
94 |
+
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
|
95 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
|
96 |
+
logits = lowering.export_to_tf_tensor(logits)
|
97 |
+
|
98 |
+
|
99 |
+
def test_sampling():
|
100 |
+
graph = mtf.Graph()
|
101 |
+
mesh = mtf.Mesh(graph, "my_mesh")
|
102 |
+
|
103 |
+
batch_dim = mtf.Dimension("batch", 1)
|
104 |
+
sequence_dim = mtf.Dimension("sequence", 1)
|
105 |
+
|
106 |
+
inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
|
107 |
+
inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)
|
108 |
+
|
109 |
+
# create mask
|
110 |
+
|
111 |
+
seq_len = params["n_ctx"]
|
112 |
+
num_mem_kv = params.get('num_mem_kv', 0)
|
113 |
+
length_dim = mtf.Dimension('sequence', seq_len)
|
114 |
+
memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
|
115 |
+
embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
|
116 |
+
embd_dim = mtf.Dimension("embd", params["n_embd"])
|
117 |
+
vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
|
118 |
+
|
119 |
+
other_features = {}
|
120 |
+
|
121 |
+
other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
|
122 |
+
other_features["embd_dim"] = embd_dim
|
123 |
+
other_features["vocab_dim"] = vocab_dim
|
124 |
+
other_features["embed_sequence_dim"] = embed_sequence_dim
|
125 |
+
other_features["memory_length_dim"] = memory_length_dim
|
126 |
+
|
127 |
+
params["mode"] = "predict"
|
128 |
+
|
129 |
+
with not_raises(Exception):
|
130 |
+
samples = sample_autoregressive(
|
131 |
+
inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
|
132 |
+
remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=True)
|
133 |
+
|
134 |
+
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
|
135 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
|
136 |
+
samples = lowering.export_to_tf_tensor(samples)
|
137 |
+
|
138 |
+
# mlm
|
139 |
+
|
140 |
+
mlm_params = defaultdict(lambda: None, {
|
141 |
+
"n_head": 1,
|
142 |
+
"n_ctx": 4,
|
143 |
+
"n_embd": 1,
|
144 |
+
"n_vocab": 256,
|
145 |
+
"embed_dropout": 0.,
|
146 |
+
"n_layer": 2,
|
147 |
+
"num_microbatches": 1,
|
148 |
+
"train_batch_size": 1,
|
149 |
+
"attention_types": ['global', 'local'],
|
150 |
+
"res_dropout": 0.1,
|
151 |
+
"mesh_shape": [],
|
152 |
+
"layout": {},
|
153 |
+
"share_parameters": True,
|
154 |
+
"mlm_training": True,
|
155 |
+
"mlm_mask_id": 3,
|
156 |
+
"mlm_cls_token_id": 4,
|
157 |
+
"mlm_random_token_prob": 0.1
|
158 |
+
})
|
159 |
+
|
160 |
+
def test_mlm_sample_text():
|
161 |
+
document = tf.random.normal((16,))
|
162 |
+
with not_raises(Exception):
|
163 |
+
features, labels = mlm_sample_text(mlm_params, document, random_documents = True)
|
164 |
+
assert features.shape == (mlm_params['n_ctx'],)
|
165 |
+
|
166 |
+
# entmax
|
167 |
+
|
168 |
+
def test_entmax():
|
169 |
+
graph = mtf.Graph()
|
170 |
+
mesh = mtf.Mesh(graph, "my_mesh")
|
171 |
+
length = mtf.Dimension("tensor_length", 8)
|
172 |
+
tensor = mtf.range(mesh, length, tf.float32)
|
173 |
+
output = entmax(tensor)
|
174 |
+
grad = mtf.gradients([output], [tensor])[0]
|
175 |
+
sample = sample_categorical(output, length)
|
176 |
+
|
177 |
+
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
|
178 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
|
179 |
+
sample = lowering.export_to_tf_tensor(sample)
|
180 |
+
grad = lowering.export_to_tf_tensor(grad)
|
utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from urllib.parse import urlparse
|
3 |
+
from shutil import rmtree
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
import sys
|
8 |
+
import tensorflow.compat.v1 as tf
|
9 |
+
import tensorflow.compat.v2 as tf2
|
10 |
+
import mesh_tensorflow as mtf
|
11 |
+
import mesh_tensorflow.auto_mtf
|
12 |
+
from data.encoders import fetch_encoder
|
13 |
+
import re
|
14 |
+
|
15 |
+
def setup_logging(args):
|
16 |
+
Path("logs").mkdir(exist_ok=True)
|
17 |
+
tf.logging.set_verbosity(logging.INFO)
|
18 |
+
tf.get_logger().propagate = False # Remove double log on console
|
19 |
+
name = os.path.splitext(os.path.basename(args.model))[0]
|
20 |
+
handlers = [
|
21 |
+
logging.FileHandler(f"logs/{name}.log"),
|
22 |
+
logging.StreamHandler(sys.stdout)
|
23 |
+
]
|
24 |
+
logger = logging.getLogger("tensorflow")
|
25 |
+
logger.handlers = handlers
|
26 |
+
return logger
|
27 |
+
|
28 |
+
|
29 |
+
def get_batch_size(params):
|
30 |
+
return params[f"{params['mode']}_batch_size"]
|
31 |
+
|
32 |
+
|
33 |
+
def add_mode_to_params(params, mode):
|
34 |
+
if mode == tf.estimator.ModeKeys.PREDICT:
|
35 |
+
params["mode"] = "predict"
|
36 |
+
elif mode == tf.estimator.ModeKeys.EVAL:
|
37 |
+
params["mode"] = "eval"
|
38 |
+
elif mode == tf.estimator.ModeKeys.TRAIN:
|
39 |
+
params["mode"] = "train"
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Invalid mode {mode}")
|
42 |
+
return params
|
43 |
+
|
44 |
+
|
45 |
+
def simd_mesh_setup(params, mesh_shape, layout_rules):
|
46 |
+
"""Constructs SimdMesh function - instructions on how to evenly split tensors across all TPU cores"""
|
47 |
+
|
48 |
+
num_hosts = params["context"].num_hosts
|
49 |
+
host_placement_fn = params["context"].tpu_host_placement_function
|
50 |
+
device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]
|
51 |
+
tf.logging.info(f"device_list = {device_list}")
|
52 |
+
|
53 |
+
# TODO: Better estimation of replica cache size?
|
54 |
+
replica_cache_size = 300 * 1000000 # 300M per replica
|
55 |
+
|
56 |
+
# Worker 0 caches all the TPU binaries
|
57 |
+
worker0_mem = replica_cache_size * params["context"].num_replicas
|
58 |
+
devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
|
59 |
+
var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)
|
60 |
+
mesh_devices = [""] * mesh_shape.size
|
61 |
+
mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
|
62 |
+
mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment)
|
63 |
+
|
64 |
+
return var_placer, mesh_impl
|
65 |
+
|
66 |
+
|
67 |
+
def remove_batch_from_layout(layout):
|
68 |
+
"""
|
69 |
+
The tf-mesh layout splits across batch size, remove it.
|
70 |
+
Useful for prediction steps, when you no longer want large batches.
|
71 |
+
|
72 |
+
:param layout: string describing tf-mesh layout
|
73 |
+
:return: layout minus batch dimension
|
74 |
+
"""
|
75 |
+
layout = layout.split(',')
|
76 |
+
ret_layout = ""
|
77 |
+
for i in layout:
|
78 |
+
if "batch" in i:
|
79 |
+
pass
|
80 |
+
else:
|
81 |
+
ret_layout += f"{i},"
|
82 |
+
return ret_layout[:-1]
|
83 |
+
|
84 |
+
|
85 |
+
def yes_or_no(question):
|
86 |
+
while True:
|
87 |
+
reply = str(input(question+' (y/n): ')).lower().strip()
|
88 |
+
if reply[:1] == 'y':
|
89 |
+
return True
|
90 |
+
if reply[:1] == 'n':
|
91 |
+
return False
|
92 |
+
|
93 |
+
|
94 |
+
def remove_gs_or_filepath(path):
|
95 |
+
parsed_url = urlparse(path)
|
96 |
+
if parsed_url.scheme == "gs":
|
97 |
+
os.system(f"gsutil rm -rf {path}")
|
98 |
+
return
|
99 |
+
rmtree(path)
|
100 |
+
|
101 |
+
|
102 |
+
def save_config(params_dict, logdir):
|
103 |
+
print(f"Saving config to {logdir}")
|
104 |
+
text = "{\n\n"
|
105 |
+
total_params = len(params_dict)
|
106 |
+
for count, key in enumerate(params_dict):
|
107 |
+
config_value = str(params_dict[key])
|
108 |
+
if re.search('[a-zA-Z]', config_value):
|
109 |
+
if config_value.lower() != 'true':
|
110 |
+
if config_value.lower() != 'false':
|
111 |
+
if config_value[0] != '[':
|
112 |
+
# TODO: Making a manual exception for parsing epsilon right now since it's the only number in
|
113 |
+
# scientific notation. Should fix this.
|
114 |
+
if key != "epsilon":
|
115 |
+
config_value = f'"{config_value}"'
|
116 |
+
if count == total_params - 1:
|
117 |
+
text += f'"{str(key)}"' + ' : ' + config_value + '\n\n'
|
118 |
+
else:
|
119 |
+
text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n'
|
120 |
+
text += '\n\n}'
|
121 |
+
sess = tf.InteractiveSession()
|
122 |
+
summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
|
123 |
+
summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
|
124 |
+
text = sess.run(summary_op)
|
125 |
+
summary_writer.add_summary(text, 0)
|
126 |
+
summary_writer.flush()
|
127 |
+
summary_writer.close()
|
128 |
+
tf.reset_default_graph()
|
129 |
+
print('Done!')
|
130 |
+
|
131 |
+
|
132 |
+
def expand_attention_types_params(params_list):
|
133 |
+
newlist = []
|
134 |
+
for item in params_list:
|
135 |
+
for _ in range(item[1]):
|
136 |
+
newlist.extend(item[0])
|
137 |
+
return newlist
|
138 |
+
|
139 |
+
|
140 |
+
def get_n_trainable_vars(graph):
|
141 |
+
"""
|
142 |
+
Gets number of trainable vars in a MTF model.
|
143 |
+
|
144 |
+
:param graph: Mesh-Tensorflow graph
|
145 |
+
:return: None
|
146 |
+
"""
|
147 |
+
total_parameters = 0
|
148 |
+
for variable in graph.trainable_variables:
|
149 |
+
shape = variable.shape.dims
|
150 |
+
variable_parameters = 1
|
151 |
+
for dim in shape:
|
152 |
+
variable_parameters *= dim.size
|
153 |
+
total_parameters += variable_parameters
|
154 |
+
print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")
|
155 |
+
|
156 |
+
|
157 |
+
def print_dim_names(graph):
|
158 |
+
"""
|
159 |
+
Print names of all Dimensions
|
160 |
+
:param graph: Mesh-Tensorflow graph
|
161 |
+
:return: None
|
162 |
+
"""
|
163 |
+
all_dim_names = []
|
164 |
+
for variable in graph.all_variables:
|
165 |
+
names = variable.shape.dimension_names
|
166 |
+
all_dim_names.append(names)
|
167 |
+
|
168 |
+
# Print all dim names in graph & write to file
|
169 |
+
all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims
|
170 |
+
unique_dims = list(set(all_dim_names))
|
171 |
+
print("ALL DIM NAMES:")
|
172 |
+
for dim_name in unique_dims:
|
173 |
+
print(dim_name)
|
174 |
+
print('\n')
|
175 |
+
|
176 |
+
|
177 |
+
def get_graph_info(graph):
|
178 |
+
"""
|
179 |
+
Wrapper fn that calculates number of trainable vars in an MTF graph & prints all dim_names to file
|
180 |
+
TODO: how to get un-trainable dim-names too, batch etc.
|
181 |
+
|
182 |
+
:param graph: Mesh-Tensorflow graph
|
183 |
+
:return: None
|
184 |
+
"""
|
185 |
+
get_n_trainable_vars(graph)
|
186 |
+
print_dim_names(graph)
|
187 |
+
|
188 |
+
|
189 |
+
def loss_denominator(targets, num_microbatches):
|
190 |
+
"""Denominator applied to losses.
|
191 |
+
|
192 |
+
This is usually the size of the targets tensor (omitting ensemble
|
193 |
+
dimensions). Alternatively, it is an override value passed to the
|
194 |
+
class constructor.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
targets: a mtf.Tensor
|
198 |
+
num_microbatches: an integer - greater than one if the step has been
|
199 |
+
serialized into multiple microbatches to save memory.
|
200 |
+
Returns:
|
201 |
+
a float
|
202 |
+
"""
|
203 |
+
ret = float(targets.shape.size) * num_microbatches
|
204 |
+
return float(ret)
|
205 |
+
|
206 |
+
def check_dataset(input_fn, params, global_step=None):
|
207 |
+
tf.enable_eager_execution()
|
208 |
+
if global_step is not None:
|
209 |
+
dataset = input_fn(params, global_step=global_step)
|
210 |
+
else:
|
211 |
+
dataset = input_fn(params)
|
212 |
+
dataset_iter = dataset.make_one_shot_iterator()
|
213 |
+
tensor, _ = next(dataset_iter)
|
214 |
+
enc = fetch_encoder(params)
|
215 |
+
|
216 |
+
for p in tensor[:1]:
|
217 |
+
txt = enc.decode(p)
|
218 |
+
|
219 |
+
print('-' * 50)
|
220 |
+
print(txt[:500], '\n\n...\n\n', txt[-500:])
|
221 |
+
print('-' * 50)
|
222 |
+
exit()
|
223 |
+
|
224 |
+
def auto_layout(graph, mesh_shape, logits, loss):
|
225 |
+
layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
|
226 |
+
print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
|
227 |
+
quit()
|
228 |
+
|
229 |
+
def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
|
230 |
+
layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,
|
231 |
+
[logits, loss], max_mesh_shape_dimensions=4)
|
232 |
+
print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \
|
233 |
+
f"\nRe-initialize graph with selected layout & mesh shape")
|
234 |
+
quit()
|
235 |
+
|
236 |
+
def create_host_call(model_dir):
|
237 |
+
"""Construct a host_call writing scalar summaries.
|
238 |
+
|
239 |
+
Borrowed from t2t.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
model_dir: String containing path to train
|
243 |
+
Returns:
|
244 |
+
(fn, args) Pair to be called by TPUEstimator as the host_call.
|
245 |
+
"""
|
246 |
+
|
247 |
+
graph = tf.get_default_graph()
|
248 |
+
# A list of (name, lowered tensor) tuples
|
249 |
+
summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)
|
250 |
+
|
251 |
+
def maybe_cast(tensor):
|
252 |
+
assert tensor.shape.is_compatible_with([]), tensor.name
|
253 |
+
if tensor.dtype == tf.int64:
|
254 |
+
return tf.to_int32(tensor)
|
255 |
+
if tensor.dtype == tf.bfloat16:
|
256 |
+
return tf.cast(tensor, tf.float32)
|
257 |
+
return tensor
|
258 |
+
|
259 |
+
reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]
|
260 |
+
|
261 |
+
# When no supported summaries are found, don't create host_call. Otherwise,
|
262 |
+
# TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
|
263 |
+
# it, eventually causing hang.
|
264 |
+
if not reshaped_tensors:
|
265 |
+
return None
|
266 |
+
|
267 |
+
def host_call_fn(global_step, *args):
|
268 |
+
"""Training host call. Creates scalar summaries for training metrics."""
|
269 |
+
# This function is executed on the CPU and should not directly reference
|
270 |
+
# any Tensors in the rest of the `model_fn`. To pass Tensors from the
|
271 |
+
# model to the `model_fn`, provide as part of the `host_call`.
|
272 |
+
global_step = tf.cast(global_step[0], tf.int64)
|
273 |
+
with tf2.summary.create_file_writer(model_dir).as_default():
|
274 |
+
# We cannot directly use any tensor from summaries, because each
|
275 |
+
# tensor here must be a concat of multiple tensors from all shards.
|
276 |
+
# Therefore, we rely on the assumption that args wil have the same
|
277 |
+
# length as summaries, and all tensors in args will have the same
|
278 |
+
# order of self._tup_summaries.
|
279 |
+
assert len(args) == len(summaries)
|
280 |
+
for i, tensor in enumerate(args):
|
281 |
+
name = summaries[i][0]
|
282 |
+
tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
|
283 |
+
return tf.summary.all_v2_summary_ops()
|
284 |
+
|
285 |
+
global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
|
286 |
+
return host_call_fn, [global_step_t] + reshaped_tensors
|
287 |
+
|
288 |
+
|
289 |
+
def natural_sort(l):
|
290 |
+
convert = lambda text: int(text) if text.isdigit() else text.lower()
|
291 |
+
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
|
292 |
+
return sorted(l, key = alphanum_key)
|