Add pipeline tag and library name #1
by nielsr (HF Staff) - opened

README.md CHANGED

@@ -1,9 +1,11 @@
 ---
-license: mit
 datasets:
 - Skylion007/openwebtext
+license: mit
 tags:
 - diffusion
+pipeline_tag: text-generation
+library_name: transformers
 ---

 # Generalized Interpolating Discrete Diffusion

@@ -36,7 +38,7 @@ Our trained checkpoints are available under the following links. All of them hav
 |-------|-------|------|
 | GIDD+ (p_u = 0.0) | [dvruette/gidd-small-p_unif-0.0](https://huggingface.co/dvruette/gidd-small-p_unif-0.0) | [dvruette/gidd-base-p_unif-0.0](https://huggingface.co/dvruette/gidd-base-p_unif-0.0) |
 | GIDD+ (p_u = 0.1) | [dvruette/gidd-small-p_unif-0.1](https://huggingface.co/dvruette/gidd-small-p_unif-0.1) | [dvruette/gidd-base-p_unif-0.1](https://huggingface.co/dvruette/gidd-base-p_unif-0.1) |
-| GIDD+ (p_u = 0.2) | dvruette/gidd-small-p_unif-0.2 | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |
+| GIDD+ (p_u = 0.2) | [dvruette/gidd-small-p_unif-0.2](https://huggingface.co/dvruette/gidd-small-p_unif-0.2) | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |


 ## Use the Model

@@ -63,3 +65,59 @@ corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_sto
 print(corrected_texts)
 ```
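A side note on the metadata added in the first hunk: `pipeline_tag: text-generation` files the checkpoints under the Hub's text-generation filter, and `library_name: transformers` tells the Hub to show a transformers loading snippet. The sketch below illustrates what that implies; the Auto classes and `trust_remote_code` usage are an assumption on my part, and the README's own "Use the Model" section remains the authoritative example.

```python
from transformers import AutoModel, AutoTokenizer

# Assumption: the GIDD checkpoints expose their custom architecture via
# trust_remote_code, as implied by `library_name: transformers`.
repo = "dvruette/gidd-base-p_unif-0.0"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
print(model.config)
```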

The rest of the third hunk (@@ -63,3 +65,59 @@) adds the following new section to the README:

## Reproducing Experiments

### Training

To reproduce the training runs from the paper, you can use the following commands.
In this example we train on a single node with 8 GPUs; adjust the `--nnodes` and `--nproc_per_node` arguments to match your setup (see the multi-node sketch after the command block below).
The checkpoints will be saved under `./outputs/{YYYY-MM-DD}/{HH-MM-SS}/checkpoints/` by default.

(Optional) Log into W&B with `wandb login` for experiment tracking, or disable it via `wandb disabled` if you don't need/want it.

```bash
# GIDD+ (p_u = 0.0)
torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name gidd logging.run_name="'small-gidd+-owt-pu=0.0'"

# GIDD+ (p_u > 0.0)
torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name gidd model.p_uniform=0.1 logging.run_name="'small-gidd+-owt-pu=0.1'"

# MDLM baseline
torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name mdlm logging.run_name="'small-mdlm-owt'"

# AR baseline
torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name ar logging.run_name="'small-ar-owt'"
```
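For multi-node training, the same entry point should work with torchrun's rendezvous flags. This is a sketch rather than a command from the paper: the endpoint and node count are placeholders for your own cluster.

```bash
# Hypothetical 2-node x 8-GPU launch; run on every node with its own --node_rank.
# $MASTER_ADDR and the port are placeholders for your cluster's head node.
torchrun --nnodes 2 --nproc_per_node 8 --node_rank 0 \
  --rdzv_backend c10d --rdzv_endpoint "$MASTER_ADDR:29500" \
  gidd/train.py --config-name gidd logging.run_name="'small-gidd+-owt-pu=0.0'"
```
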

### Evaluation

There are also a couple of scripts to run inference and evaluate the trained models.
Note that these scripts expect the checkpoint format saved by the training script, so the checkpoints from Hugging Face are not directly compatible.
You can download our original training checkpoints from here: https://polybox.ethz.ch/index.php/s/BbxZcYDSoXf8aL4

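If you prefer to fetch the archive from the command line, ownCloud-style shares usually expose a direct download endpoint. This is an assumption about the share above; fall back to downloading through the browser if it fails.

```bash
# Assumption: the polybox share supports the standard ownCloud "/download" endpoint.
wget "https://polybox.ethz.ch/index.php/s/BbxZcYDSoXf8aL4/download" -O gidd_checkpoints.zip
unzip gidd_checkpoints.zip -d ./checkpoints/
```
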
#### Generate samples

The following command will generate `num_samples=16` samples in `num_denoising_steps=128` denoising steps from the model checkpoint located at `path` and save them to `samples_path=samples.pt`.

```bash
python gidd/eval/generate_samples.py path=./outputs/path/to/checkpoint/ samples_path=samples.pt num_samples=16 num_denoising_steps=128 batch_size=16
```
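The exact structure of `samples.pt` is defined by `generate_samples.py`; assuming it is a regular `torch.save` file, you can peek at it like this before feeding it to the evaluation scripts.

```python
import torch

# weights_only=False because the file may contain arbitrary Python objects,
# not just tensors -- only do this for files you generated yourself.
samples = torch.load("samples.pt", map_location="cpu", weights_only=False)
print(type(samples), len(samples))
```
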

#### Generative PPL

Given a file containing samples generated with the `generate_samples.py` script, the following command will compute the generative PPL.
Here we assume that the diffusion model used to generate the samples located at `samples.pt` uses the `gpt2` tokenizer, and we compute generative PPL using `google/gemma-2-9b` as a reference model (note that `gemma-2-9b` requires you to log into your HF account using `huggingface-cli login`).
The results will be saved to `metrics_path=metrics.json`.

```bash
python gidd/eval/generative_ppl.py samples_path=samples.pt model_tokenizer=gpt2 pretrained_model=google/gemma-2-9b batch_size=4 metrics_path=metrics.json
```
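For intuition, generative PPL scores the generated texts under a strong reference model: decode the samples with the generator's tokenizer, then measure the reference model's perplexity on those texts. The sketch below is illustrative only, not the repository's script, and assumes `samples.pt` holds gpt2 token-id sequences.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumption: samples.pt contains a tensor/list of gpt2 token-id sequences.
gen_tok = AutoTokenizer.from_pretrained("gpt2")
ref_tok = AutoTokenizer.from_pretrained("google/gemma-2-9b")
ref_lm = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b", torch_dtype=torch.bfloat16, device_map="auto"
)

samples = torch.load("samples.pt", map_location="cpu", weights_only=False)
texts = [gen_tok.decode(ids, skip_special_tokens=True) for ids in samples]

total_nll, total_tokens = 0.0, 0
for text in texts:
    enc = ref_tok(text, return_tensors="pt").to(ref_lm.device)
    with torch.no_grad():
        loss = ref_lm(**enc, labels=enc["input_ids"]).loss  # mean NLL per predicted token
    n_pred = enc["input_ids"].size(1) - 1  # a causal LM predicts all but the first token
    total_nll += loss.item() * n_pred
    total_tokens += n_pred

print("Generative PPL:", torch.exp(torch.tensor(total_nll / total_tokens)).item())
```
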

#### Validation loss

A simple helper script to compute the loss of a trained model on the entire validation split.

```bash
python gidd/eval/loss.py path=./outputs/path/to/checkpoint/ batch_size=32
```

#### Self-correction

This script will run the self-correction step on the samples contained in `samples.pt` (e.g. generated with the `generate_samples.py` script) and save the corrected samples to `corrected_samples.pt`.
The `temp` argument controls the temperature used when resampling tokens from the model (see paper for more details).

```bash
python gidd/eval/self_correction.py path=./outputs/path/to/checkpoint/ samples_path=samples.pt corrected_samples_path=corrected_samples.pt batch_size=16 num_denoising_steps=128 temp=0.1
```
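
As a rough intuition for `temp`: the lower the temperature, the more the resampling distribution concentrates on the model's top candidates. A small, self-contained illustration of temperature sampling (not taken from the repository):

```python
import torch

def sample_with_temperature(logits: torch.Tensor, temp: float) -> torch.Tensor:
    """Sample one token id per position from (..., vocab_size) logits."""
    probs = torch.softmax(logits / temp, dim=-1)
    flat = probs.reshape(-1, probs.size(-1))
    return torch.multinomial(flat, num_samples=1).reshape(probs.shape[:-1])

logits = torch.randn(2, 8, 50257)              # dummy (batch, seq, vocab) logits
tokens = sample_with_temperature(logits, 0.1)  # low temp ~ near-greedy resampling
print(tokens.shape)                            # torch.Size([2, 8])
```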