Add pipeline tag and library name

#1
by nielsr HF Staff - opened
Files changed (1)
  1. README.md +60 -2
README.md CHANGED
@@ -1,9 +1,11 @@
  ---
- license: mit
  datasets:
  - Skylion007/openwebtext
+ license: mit
  tags:
  - diffusion
+ pipeline_tag: text-generation
+ library_name: transformers
  ---

  # Generalized Interpolating Discrete Diffusion
@@ -36,7 +38,7 @@ Our trained checkpoints are available under the following links. All of them hav
  |-------|-------|------|
  | GIDD+ (p_u = 0.0) | [dvruette/gidd-small-p_unif-0.0](https://huggingface.co/dvruette/gidd-small-p_unif-0.0) | [dvruette/gidd-base-p_unif-0.0](https://huggingface.co/dvruette/gidd-base-p_unif-0.0) |
  | GIDD+ (p_u = 0.1) | [dvruette/gidd-small-p_unif-0.1](https://huggingface.co/dvruette/gidd-small-p_unif-0.1) | [dvruette/gidd-base-p_unif-0.1](https://huggingface.co/dvruette/gidd-base-p_unif-0.1) |
- | GIDD+ (p_u = 0.2) | dvruette/gidd-small-p_unif-0.2 | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |
+ | GIDD+ (p_u = 0.2) | [dvruette/gidd-small-p_unif-0.2](https://huggingface.co/dvruette/gidd-small-p_unif-0.2) | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |


  ## Use the Model
@@ -63,3 +65,59 @@ corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_sto
  print(corrected_texts)
  ```

+ ## Reproducing Experiments
+
+ ### Training
+
+ To reproduce the training runs from the paper, you can use the following commands.
+ In this example, we are training on a single node with 8 GPUs; feel free to adjust the `--nnodes` and `--nproc_per_node` arguments to match your setup.
+ The checkpoints will be saved under `./outputs/{YYYY-MM-DD}/{HH-MM-SS}/checkpoints/` by default.
+
+ (Optional) Log into W&B with `wandb login` for experiment tracking, or disable it via `wandb disabled` if you don't want it.
+
+ ```bash
+ # GIDD+ (p_u = 0.0)
+ torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name gidd logging.run_name="'small-gidd+-owt-pu=0.0'"
+
+ # GIDD+ (p_u > 0.0)
+ torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name gidd model.p_uniform=0.1 logging.run_name="'small-gidd+-owt-pu=0.1'"
+
+ # MDLM baseline
+ torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name mdlm logging.run_name="'small-mdlm-owt'"
+
+ # AR baseline
+ torchrun --nnodes 1 --nproc_per_node 8 gidd/train.py --config-name ar logging.run_name="'small-ar-owt'"
+ ```
+
+ ### Evaluation
+
+ There are also a couple of scripts to run inference and evaluate the trained models.
+ Note that these scripts expect the checkpoint format saved by the training script, so the checkpoints from Hugging Face are not directly compatible.
+ You can download our original training checkpoints from here: https://polybox.ethz.ch/index.php/s/BbxZcYDSoXf8aL4
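+
+ For example, something along the following lines should fetch and unpack the archive (the exact share-link download URL, the local file name, and the archive layout are assumptions, so adjust them as needed):
+
+ ```bash
+ # Fetch the checkpoint archive from the polybox share (appending /download typically gives a direct download)
+ wget -O gidd_checkpoints.zip "https://polybox.ethz.ch/index.php/s/BbxZcYDSoXf8aL4/download"
+ # Unpack it and point the eval scripts' `path=` argument at the extracted checkpoint directory
+ unzip gidd_checkpoints.zip -d gidd_checkpoints/
+ ```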
+
+ #### Generate samples
+ The following command will generate `num_samples=16` samples in `num_denoising_steps=128` iterations from the model checkpoint located at `path` and save them to `samples_path=samples.pt`.
+ ```bash
+ python gidd/eval/generate_samples.py path=./outputs/path/to/checkpoint/ samples_path=samples.pt num_samples=16 num_denoising_steps=128 batch_size=16
+ ```
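+
+ If you want a quick look at the generated samples, the saved file should be loadable with `torch.load` (an assumption about the on-disk format; the exact structure of the stored object depends on the script):
+
+ ```bash
+ # Peek at the saved samples object (assumes samples.pt was written with torch.save)
+ python -c "import torch; s = torch.load('samples.pt'); print(type(s))"
+ ```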
+
+ #### Generative PPL
+ Given a file containing samples generated with the `generate_samples.py` script, the following command will compute the generative PPL.
+ Here we assume that the diffusion model used to generate samples located at `samples.pt` uses the `gpt2` tokenizer, and we compute generative PPL using `google/gemma-2-9b` as a reference model (note that `gemma-2-9b` requires you to log into your HF account using `huggingface-cli login`).
+ The results will be saved to `metrics_path=metrics.json`.
+ ```bash
+ python gidd/eval/generative_ppl.py samples_path=samples.pt model_tokenizer=gpt2 pretrained_model=google/gemma-2-9b batch_size=4 metrics_path=metrics.json
+ ```
+
+ #### Validation loss
+ A simple helper script to compute the loss of a trained model on the entire validation split.
+ ```bash
+ python gidd/eval/loss.py path=./outputs/path/to/checkpoint/ batch_size=32
+ ```
+
+ #### Self-correction
+ This script will run the self-correction step on the samples contained in `samples.pt` (e.g. generated with the `generate_samples.py` script) and save the corrected samples to `corrected_samples.pt`.
+ The `temp` argument controls the temperature used when resampling tokens from the model (see paper for more details).
+ ```bash
+ python gidd/eval/self_correction.py path=./outputs/path/to/checkpoint/ samples_path=samples.pt corrected_samples_path=corrected_samples.pt batch_size=16 num_denoising_steps=128 temp=0.1
+ ```