Why is fine-tuning FLAN-T5-LARGE with 4-bit quantization and LoRA so slow?
Fine-tuning this model on an English -> SQL query task is taking far too long: even with a training set of 7,500 datapoints, 2 epochs take almost an hour to finish. The quantization and LoRA parameters are below.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,  # True
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
finetune_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# for gradient checkpointing and 4-bit training
finetune_model.gradient_checkpointing_enable()
finetune_model = prepare_model_for_kbit_training(finetune_model)
# LoRA config
lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,  # change
    target_modules=["q", "k", "v", "o", "wi", "wo"],
    bias="none",
    task_type="SEQ_2_SEQ_LM",  # T5 is an encoder-decoder model, not a causal LM
)
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=5e-3,
    num_train_epochs=2,  # 3
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=500,
    fp16=True,
    # tf32=True,
    # bf16=False,
    gradient_accumulation_steps=16,
    # gradient_checkpointing=True,
    lr_scheduler_type="linear",
    optim="paged_adamw_8bit",
)
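For context, here is a quick back-of-the-envelope sketch of the optimizer step count under these settings (assuming a single GPU and the 7,500-example training set mentioned above):

# Rough step count under the settings above (assumes one GPU, 7,500 training examples).
train_examples = 7500
effective_batch = 4 * 16            # per_device_train_batch_size * gradient_accumulation_steps
steps_per_epoch = train_examples // effective_batch   # ~117 optimizer updates per epoch
total_steps = steps_per_epoch * 2   # 2 epochs -> ~234 optimizer updates
print(effective_batch, steps_per_epoch, total_steps)  # 64 117 234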
Fine-tuning FLAN-T5-LARGE under these parameter settings takes a long time (~2 hrs for 2 epochs). For a 75,000-example training set, it took almost 5 hrs.
I am using an A100 (40 GB) GPU on Colab.
Any ideas or suggestions for speeding this up?
@scigeek
AFAIK you should set bf16=True (for T5 in general; not sure if this is still the case). The weights might be cast back and forth, which can cause lag (likely fp16 -> bf16 -> 4-bit, from what I can guess). The model was trained on TPUs with bf16, which is also why fp16 can cause weird errors.
4-bit itself normally already requires casting from an NVIDIA-hardware-native precision, so that is a lot of casting.
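Something like this is what I mean (a minimal sketch against your snippet above; the names are yours, and every argument I don't show stays as you had it):

# Keep everything in a bf16 compute path so weights don't round-trip through fp16.
finetune_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # bnb_4bit_compute_dtype=torch.bfloat16 as before
    torch_dtype=torch.bfloat16,      # was torch.float16
    device_map="auto",
)

training_args = TrainingArguments(
    output_dir=output_dir,
    bf16=True,    # was fp16=True
    fp16=False,
    # ...keep the rest of your arguments (batch sizes, accumulation, optim, etc.) unchanged
)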
@merve
Thanks for the information. Setting bf16=True did improve the speed (with 7,500 training datapoints); it took about an hour for 2 epochs on an A100 in Colab. The eval scores are still low; I am hoping they will improve once I increase the training set size. At the current pace, scaling the training set up to 75,000 may take several hours to complete the fine-tuning.
Closing this issue then!