Why fine tuning FLAN-T5-LARGE with 4bit Quantization and LORA is so slow?

#23
by scigeek - opened

The training for fine-tuning this model on an English -> SQL query task is taking way too long. Even with a training set of 7,500 data points, 2 epochs take almost an hour to finish. The quantization and LoRA parameters are provided below.

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, prepare_model_for_kbit_training

# 4-bit NF4 quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,  # also tried True
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "google/flan-t5-large"

tokenizer = AutoTokenizer.from_pretrained(model_id)
finetune_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# enable gradient checkpointing and prepare the quantized model for 4-bit training
finetune_model.gradient_checkpointing_enable()
finetune_model = prepare_model_for_kbit_training(finetune_model)

# LoRA config
lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    # flan-t5 uses a gated feed-forward (wi_0 / wi_1), so target those explicitly
    target_modules=["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    bias="none",
    task_type="SEQ_2_SEQ_LM",  # T5 is an encoder-decoder, not a causal LM
)

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=5e-3,
    num_train_epochs=2,  # was 3
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=500,
    fp16=True,
    # tf32=True,
    # bf16=False,
    gradient_accumulation_steps=16,
    # gradient_checkpointing=True,
    lr_scheduler_type="linear",
    optim="paged_adamw_8bit",
)
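
For context, the post does not show how the model and config are handed to the trainer. Below is a minimal sketch of how these pieces are usually wired together; train_ds and eval_ds are hypothetical placeholders for pre-tokenized datasets and are not part of the original post.

from peft import get_peft_model
from transformers import DataCollatorForSeq2Seq, Trainer

peft_model = get_peft_model(finetune_model, lora_config)
peft_model.print_trainable_parameters()  # sanity check: only LoRA adapters should be trainable

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_ds,  # hypothetical tokenized English -> SQL dataset
    eval_dataset=eval_ds,    # hypothetical tokenized eval split
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=peft_model),
)
trainer.train()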

Fine-tuning FLAN-T5-LARGE under these parameter settings takes a long time (~2 hrs for 2 epochs). For a 75,000-example training set, it took almost 5 hrs.

I am using an A100 (40 GB) GPU on Colab.
Any ideas or suggestions to improve this slow progress?
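
A quick back-of-the-envelope check of what these timings imply per step (numbers taken from the post above; this is only an estimate):

# rough throughput implied by the reported ~1 hour for 2 epochs on 7,500 examples
samples = 7500
micro_batch = 4      # per_device_train_batch_size
grad_accum = 16      # gradient_accumulation_steps
epochs = 2
seconds = 60 * 60    # reported wall-clock time

micro_steps = epochs * samples / micro_batch   # forward/backward passes
optim_steps = micro_steps / grad_accum         # optimizer updates
print(f"{micro_steps:.0f} micro-batches, {optim_steps:.0f} optimizer steps")
print(f"~{seconds / micro_steps:.2f} s per micro-batch")   # roughly 1 s each

At roughly one second per micro-batch of 4, much of the per-step cost plausibly comes from 4-bit dequantization, gradient-checkpointing recomputation, and dtype casting rather than the optimizer itself.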

scigeek changed discussion title from Why fine tuning FLAT-T5-LARGE with 4bit Quantization and LORA is so slow? to Why fine tuning FLAN-T5-LARGE with 4bit Quantization and LORA is so slow?

@scigeek AFAIK you should set bf16=True (for T5 in general; not sure if this is still the case). Otherwise the weights might be cast back and forth, which can cause lag (likely fp16 -> bf16 -> 4-bit, from what I can guess). The model was trained on TPUs with bf16, which is also why fp16 can cause weird errors.
4-bit itself normally already requires casting from an NVIDIA-hardware-native precision, so it's a lot of casting.
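
Concretely, a minimal sketch of the dtype-consistent setup (same model and quantization as above, but keeping everything in bfloat16 so no extra fp16 <-> bf16 casts are needed):

# keep the compute dtype, model dtype, and mixed-precision flag all in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

finetune_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,  # was float16 in the original post
    device_map="auto",
)

training_args = TrainingArguments(
    output_dir=output_dir,
    bf16=True,   # bf16 mixed precision instead of fp16=True
    fp16=False,
    # ... remaining arguments unchanged ...
)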

@merve Thanks for the information. Setting bf16=True did improve the speed (with 7,500 training data points): it took about an hour for 2 epochs on an A100 in Colab. The eval scores are still low; I am hoping they will improve once I increase the training set size. At the current rate, scaling the training set up to 75,000 may take multiple hours to complete the fine-tuning.

Closing this issue then!

merve changed discussion status to closed
