fisherman611 committed (verified)
Commit c6f9ba7 · 1 Parent(s): 9144bfc

Create models/mbart50.py

Files changed (1):
  1. models/mbart50.py +125 -0
models/mbart50.py ADDED
@@ -0,0 +1,125 @@
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
from transformers import MBart50Tokenizer, MBartForConditionalGeneration  # type: ignore
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from dotenv import load_dotenv
import wandb
import json
from utils.helper import TextPreprocessor
from utils.trainer import train_model

load_dotenv()


class MBart50Finetuner:
    """Class to handle fine-tuning of mBART50 model for translation tasks."""

    def __init__(self, config_path="config.json"):
        """Initialize with configuration file."""
        with open(config_path, "r") as json_file:
            cfg = json.load(json_file)

        self.args = cfg["mbart50"]["args"]
        self.lora_config = cfg["mbart50"]["lora_config"]
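
        # A minimal sketch of the expected config.json shape: the key names
        # match what is read above, but every value below is a hypothetical
        # example, not the project's actual configuration:
        #
        # {
        #   "mbart50": {
        #     "args": {
        #       "max_len": 128,
        #       "id": null,
        #       "initial_learning_rate": 2e-5,
        #       "model_name": "facebook/mbart-large-50-many-to-many-mmt",
        #       "src_lang": "en_XX",
        #       "tgt_lang": "vi_VN",
        #       "wandb_project": "mbart50-finetune",
        #       "output_dir": "checkpoints/mbart50"
        #     },
        #     "lora_config": {
        #       "r": 16,
        #       "lora_alpha": 32,
        #       "target_modules": ["q_proj", "v_proj"],
        #       "lora_dropout": 0.05
        #     }
        #   }
        # }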

        # Constants
        self.max_len = self.args["max_len"]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.id = self.args["id"]
        self.initial_learning_rate = self.args["initial_learning_rate"]
        self.model_name = self.args["model_name"]
        self.src_lang = self.args["src_lang"]
        self.tgt_lang = self.args["tgt_lang"]
        self.wandb_project = self.args["wandb_project"]
        self.output_dir = self.args["output_dir"]
        self.name = "mbart50"

        self.model = None
        self.tokenizer = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup_wandb(self):
        """Initialize Weights & Biases for experiment tracking."""
        # Reads the WANDB_API key from the environment (loaded from .env
        # by load_dotenv() at import time).
        wandb.login(key=os.environ.get("WANDB_API"), relogin=True)
        wandb.init(project=self.wandb_project, name="mbart50-finetune-lora")

    def load_model_and_tokenizer(self):
        """Load the mBART model and tokenizer."""
        self.tokenizer = MBart50Tokenizer.from_pretrained(self.model_name)
        self.model = MBartForConditionalGeneration.from_pretrained(self.model_name)
        # src_lang / tgt_lang must be mBART-50 language codes such as "en_XX".
        self.tokenizer.src_lang = self.src_lang
        self.tokenizer.tgt_lang = self.tgt_lang

    def load_datasets(self):
        """Load training, validation, and test datasets."""
        data_files = {
            "train": "data/train_cleaned_dataset.csv",
            "test": "data/test_cleaned_dataset.csv",
            "val": "data/val_cleaned_dataset.csv",
        }

        if self.id is not None:
            # Shard the training split into 11 contiguous 200k-row slices.
            # HF split strings follow Python's half-open slicing, so shard i
            # starts at i * 200000; the final shard is open-ended.
            training_parts = [
                f"[{i * 200000 if i > 0 else ''}:{(i + 1) * 200000 if i < 10 else ''}]"
                for i in range(11)
            ]
            self.train_dataset = load_dataset(
                "csv", data_files=data_files, split=f"train{training_parts[self.id]}"
            )
            self.test_dataset = load_dataset("csv", data_files=data_files, split="test")
            self.val_dataset = load_dataset(
                "csv", data_files=data_files, split="val[:20000]"
            )
        else:
            self.train_dataset = load_dataset(
                "csv", data_files=data_files, split="train[:1000000]"
            )
            self.test_dataset = load_dataset(
                "csv", data_files=data_files, split="test[:100000]"
            )
            self.val_dataset = load_dataset(
                "csv", data_files=data_files, split="val[:100000]"
            )

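    # For reference, a worked example of the slicing above: training_parts
    # comes out as
    #   ['[:200000]', '[200000:400000]', ..., '[1800000:2000000]', '[2000000:]']
    # so self.id selects one contiguous 200k-row shard of the training split.
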
    def configure_lora(self):
        """Apply LoRA configuration to the model."""
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            r=self.lora_config["r"],
            lora_alpha=self.lora_config["lora_alpha"],
            target_modules=self.lora_config["target_modules"],
            lora_dropout=self.lora_config["lora_dropout"],
        )
        self.model = get_peft_model(self.model, lora_config)  # type: ignore

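    # Note: for mBART, target_modules would typically name the attention
    # projections defined in transformers' MBartAttention (e.g. "q_proj",
    # "k_proj", "v_proj", "out_proj"); the exact list here comes from
    # config.json.
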
    def finetune(self):
        """Orchestrate the fine-tuning process."""
        self.setup_wandb()
        self.load_model_and_tokenizer()
        self.load_datasets()

        preprocessor = TextPreprocessor(self.tokenizer, self.max_len, name="mbart50")
        tokenized_train_dataset = preprocessor.preprocess_dataset(self.train_dataset)
        tokenized_eval_dataset = preprocessor.preprocess_dataset(self.val_dataset)

        self.configure_lora()
        self.model.print_trainable_parameters()  # type: ignore

        train_model(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            output_dir=self.output_dir,
            initial_learning_rate=self.initial_learning_rate,
            name=self.name,
            val_dataset=self.val_dataset,
        )


if __name__ == "__main__":
    finetuner = MBart50Finetuner()
    finetuner.finetune()
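
To launch a run (assuming, as the code above expects, a config.json at the repository root with an "mbart50" section like the sketch in __init__, the cleaned CSVs under data/, and a WANDB_API key in .env), the script is invoked directly from the repository root: python models/mbart50.py.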