rivapereira123 commited on
Commit
adcccfb
·
verified ·
1 Parent(s): 64d4c08

Create finetune_flan_t5.py

Browse files
Files changed (1) hide show
  1. finetune_flan_t5.py +60 -0
finetune_flan_t5.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments
3
+ from trl import SFTTrainer, DataCollatorForSeq2Seq
4
+ import torch
5
+
6
+ # Load your dataset (from the converted JSONL file)
7
+ dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")
8
+
9
+ # Load tokenizer and model
10
+ model_name = "google/flan-t5-base"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
+
14
+ # Preprocess dataset
15
+ def preprocess(example):
16
+ input_text = example["instruction"]
17
+ target_text = example["output"]
18
+ tokenized = tokenizer(
19
+ input_text,
20
+ max_length=512,
21
+ truncation=True,
22
+ padding="max_length"
23
+ )
24
+ with tokenizer.as_target_tokenizer():
25
+ tokenized["labels"] = tokenizer(
26
+ target_text,
27
+ max_length=128,
28
+ truncation=True,
29
+ padding="max_length"
30
+ )["input_ids"]
31
+ return tokenized
32
+
33
+ tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
34
+
35
+ # Define training arguments
36
+ training_args = TrainingArguments(
37
+ output_dir="./flan-t5-medical",
38
+ per_device_train_batch_size=4,
39
+ gradient_accumulation_steps=2,
40
+ num_train_epochs=3,
41
+ logging_dir="./logs",
42
+ save_strategy="epoch",
43
+ evaluation_strategy="no",
44
+ fp16=torch.cuda.is_available()
45
+ )
46
+
47
+ # Define data collator
48
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
49
+
50
+ # Initialize trainer
51
+ trainer = SFTTrainer(
52
+ model=model,
53
+ args=training_args,
54
+ train_dataset=tokenized_dataset,
55
+ tokenizer=tokenizer,
56
+ data_collator=data_collator,
57
+ )
58
+
59
+ # Start training
60
+ trainer.train()