Update README.md (#4)
Opened by Jiqing

This PR ports the training example in README.md from CPU to Habana Gaudi (HPU): the standard `Trainer`/`TrainingArguments` are swapped for `optimum-habana`'s `GaudiTrainer`/`GaudiTrainingArguments`, the model is moved to the `hpu` device, batches are padded to a fixed length, and a `GaudiConfig` with fused kernels is passed to the trainer.

README.md CHANGED
```diff
@@ -78,8 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
             "lr": training_args.learning_rate * lr_ratio
         },
     ]
-    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
     optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
 
     return optimizer
```
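`GaudiTrainer` extends `transformers.Trainer`, so `get_optimizer_cls_and_kwargs` behaves as before; the swap keeps the helper consistent with the Gaudi classes used below. The import block of the script is not part of this diff; for the new references to resolve, something along these lines is assumed (a minimal sketch, assuming the `optimum-habana` package is installed):

```python
# Not shown in this diff: imports the Gaudi-ported script relies on.
# Assumption: optimum-habana is installed, e.g. `pip install optimum-habana`.
import torch
import transformers
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, HfArgumentParser
from datasets import load_dataset
from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
```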
```diff
@@ -104,8 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
 
 
 if __name__ == "__main__":
-    device = torch.device("cpu")
+    device = torch.device("hpu")
     raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
     model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
     tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
```
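A note on the new `"hpu"` device string: PyTorch only recognizes it once the Habana bridge module has been loaded. A minimal sketch of what that looks like (the bridge import is an assumption, it is not visible in this diff):

```python
import torch

# Loading the Habana PyTorch bridge registers the "hpu" device with PyTorch.
# Assumption: the full script (or the optimum-habana runtime) takes care of this import.
import habana_frameworks.torch.core as htcore

device = torch.device("hpu")
print(device)  # hpu
```

Once `use_habana=True` is set in the next hunk, `GaudiTrainer` should also handle placing the model on the HPU, so the explicit `.to(device)` mainly matters for code that runs before the trainer is built.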
```diff
@@ -117,8 +117,8 @@ if __name__ == "__main__":
         'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
-        'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
-    training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
+        'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana": True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
+    training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
 
     def tokenize_protein(example, tokenizer=None):
         protein_seq = example["prot_seq"]
```
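The three added keys (`use_habana`, `use_lazy_mode`, `use_hpu_graphs_for_inference`) only exist on `GaudiTrainingArguments`, which is why the parser target changes as well; with plain `TrainingArguments` and `allow_extra_keys=False`, the parse would raise on the unknown keys. A minimal sketch of the equivalent direct construction (the `output_dir` value is illustrative, not taken from this diff):

```python
from optimum.habana import GaudiTrainingArguments

training_args = GaudiTrainingArguments(
    output_dir="./esm1b_localization",   # illustrative path, not from the diff
    use_habana=True,                     # run training on Habana HPUs
    use_lazy_mode=True,                  # HPU lazy (graph) execution mode
    use_hpu_graphs_for_inference=True,   # capture HPU graphs to speed up evaluation
    bf16=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
```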
```diff
@@ -134,8 +134,8 @@ if __name__ == "__main__":
     for split in ["train", "validation", "test"]:
         raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
 
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
 
     transformers.utils.logging.set_verbosity_info()
     log_level = training_args.get_process_log_level()
```
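Padding every batch to a fixed 1024 tokens trades some memory for static input shapes, which matters on HPU: lazy mode and HPU graphs recompile when tensor shapes change, so dynamic per-batch padding would trigger frequent recompilation. A small self-contained check of the resulting shapes (illustrative, not part of the PR):

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)

# Two protein sequences of very different lengths still yield one static batch shape.
batch = data_collator([tokenizer("MKT"), tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")])
print(batch["input_ids"].shape)  # torch.Size([2, 1024])
```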
```diff
@@ -144,16 +144,16 @@ if __name__ == "__main__":
     optimizer = create_optimizer(model)
     scheduler = create_scheduler(training_args, optimizer)
 
+    gaudi_config = GaudiConfig()
+    gaudi_config.use_fused_adam = True
+    gaudi_config.use_fused_clip_norm = True
 
 
     # build trainer
-    trainer = Trainer(
+    trainer = GaudiTrainer(
         model=model,
+        gaudi_config=gaudi_config,
         args=training_args,
         train_dataset=raw_dataset["train"],
         eval_dataset=raw_dataset["validation"],
```
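For reference, the Gaudi-specific wiring introduced across these hunks, collected in one place. This is a sketch, not the full README script: it assumes the `model`, `training_args`, and `raw_dataset` objects built earlier, and omits the trainer arguments that fall outside this hunk.

```python
from optimum.habana import GaudiConfig, GaudiTrainer

# Enable Habana's fused AdamW and fused gradient-norm-clipping kernels.
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

# build trainer
trainer = GaudiTrainer(
    model=model,                             # loaded with AutoModel.from_pretrained(...) above
    gaudi_config=gaudi_config,               # new: Gaudi runtime configuration
    args=training_args,                      # the GaudiTrainingArguments parsed above
    train_dataset=raw_dataset["train"],
    eval_dataset=raw_dataset["validation"],
)
```

From here the usual `trainer.train()` and `trainer.evaluate()` calls run on the HPU, with evaluation additionally benefiting from `use_hpu_graphs_for_inference=True`.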