# 1. Imports
from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,  # explicit seq2seq model class for FLAN-T5
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from trl import SFTTrainer
import torch



# 2. Load and prepare dataset
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

# Combine instruction and response into the single "text" field used for SFT
def format_example(example):
    return {
        "text": f"Instruction: {example['input']}\nResponse: {example['output']}",
        "input": example["input"],
        "output": example["output"]
    }

dataset = dataset.map(format_example)

# 3. Load model and tokenizer
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 4. Configure training
training_args = TrainingArguments(
    output_dir="./flan-t5-medical-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    learning_rate=5e-5,
    logging_dir="./logs",
    save_strategy="epoch",
    evaluation_strategy="no",
    fp16=torch.cuda.is_available(),
    report_to="none",
    remove_unused_columns=False,
    # Conservative dataloader settings (single worker, no pinned memory)
    # to avoid environment-specific dataloader issues
    dataloader_pin_memory=False,
    dataloader_num_workers=0
)

# 5. Initialize trainer with proper config
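# NOTE: passing tokenizer, dataset_text_field and max_seq_length directly to
# SFTTrainer matches older TRL releases; newer TRL versions expect these via
# SFTConfig instead, so pin the trl version this script was written against.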
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args,
    dataset_text_field="text",
    max_seq_length=512,  # Explicitly set to avoid warning
    data_collator=DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding="longest"
    )
)

# 6. Start training
trainer.train()
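
# 7. Save the final model and run a quick sanity check.
# (Sketch: the output path and the sample question below are illustrative
# additions, not part of the original script.)
trainer.save_model("./flan-t5-medical-finetuned/final")
tokenizer.save_pretrained("./flan-t5-medical-finetuned/final")

# Generate a response with the fine-tuned model using the same prompt format
# as training ("Instruction: ...\nResponse:")
sample = "Instruction: What are the common symptoms of anemia?\nResponse:"
inputs = tokenizer(sample, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))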