# Source: Hugging Face Space file view (Space status: Running; file size: 1,730 bytes).
# Commits touching this file: adcccfb, 8a20cad, c6a9cc3, 0bbb8d4, 3f67405, dbbaa64, 9df7d7d.
# 1. Imports
import inspect

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration,  # Using specific model class
    Trainer,
    TrainingArguments,
)
from trl import SFTTrainer
# 2. Load and prepare dataset
# Local JSON Lines file; requesting split="train" yields a single Dataset
# rather than a DatasetDict keyed by split name.
DATA_FILE = "data/med_q_n_a_converted.jsonl"
dataset = load_dataset("json", data_files=DATA_FILE, split="train")
# Build a single prompt/response string for each record.
def format_example(example: dict) -> dict:
    """Build an instruction-style training record from one dataset row.

    Args:
        example: A mapping with ``"input"`` and ``"output"`` fields.

    Returns:
        A dict with a combined ``"text"`` field of the form
        ``"Instruction: <input>\\nResponse: <output>"`` plus the original
        ``"input"`` and ``"output"`` values passed through unchanged.
    """
    instruction = example["input"]
    response = example["output"]
    return {
        "text": f"Instruction: {instruction}\nResponse: {response}",
        "input": instruction,
        "output": response,
    }
# Attach the formatted "text" column; map merges the returned keys into each
# record, so the existing columns stay in place.
dataset = dataset.map(function=format_example)
# 3. Load model and tokenizer
# Both artifacts come from the same Hub checkpoint ID.
model_name = "google/flan-t5-base"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 4. Configure training
# T5-family checkpoints were pretrained in bfloat16, and their activations can
# overflow under fp16 mixed precision (a known source of NaN losses). Prefer
# bf16 on GPUs that support it; fall back to fp16 only on older CUDA devices.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
training_args = TrainingArguments(
    output_dir="./flan-t5-medical-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch of 8 per device
    num_train_epochs=3,
    learning_rate=5e-5,
    logging_dir="./logs",
    save_strategy="epoch",
    # No evaluation is run. `evaluation_strategy` was renamed `eval_strategy`
    # in newer transformers releases (the old keyword was later removed), and
    # "no" is the default under either name, so it is left unset to stay
    # compatible across versions.
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
    report_to="none",
    # remove_unused_columns stays at its default (True): the Trainer then
    # drops columns the model's forward() does not accept (e.g. raw strings
    # such as "text"/"input"/"output"), which the data collator could not
    # convert to tensors.
    # NOTE(review): the original author added the two settings below "to
    # prevent version conflicts"; the reason is not visible here.
    # dataloader_num_workers=0 is already the library default.
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
)
# 5. Initialize trainer
# FLAN-T5 is an encoder-decoder model, so each example must be split into an
# encoder input and decoder `labels`. trl's SFTTrainer instead tokenizes one
# combined text column as a causal-LM sequence and never produces `labels`,
# so T5 gets no decoder inputs and no loss. Tokenize the input/output pairs
# explicitly and train with the plain transformers Trainer.
max_source_length = 512  # encoder-side truncation length
max_target_length = 512  # decoder-side (label) truncation length


def _tokenize_pairs(batch):
    """Tokenize a batch of ``input``/``output`` pairs for seq2seq training.

    Returns the encoder ``input_ids``/``attention_mask`` for ``batch["input"]``
    plus ``labels`` holding the token ids of ``batch["output"]``.
    """
    model_inputs = tokenizer(
        batch["input"], max_length=max_source_length, truncation=True
    )
    targets = tokenizer(
        text_target=batch["output"], max_length=max_target_length, truncation=True
    )
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs


tokenized_dataset = dataset.map(
    _tokenize_pairs,
    batched=True,
    # Keep only the tokenized fields; the collator cannot tensorize raw strings.
    remove_columns=dataset.column_names,
)

# The Trainer keyword for the tokenizer is `tokenizer` in older transformers
# releases and `processing_class` in newer ones; use whichever this install
# accepts so the tokenizer is saved alongside checkpoints.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # Pads each batch to its longest sequence, pads `labels` with -100 (ignored
    # by the loss) and, given the model, derives decoder_input_ids from labels.
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest"),
    **{_tokenizer_kwarg: tokenizer},
)
# 6. Start training
trainer.train()
# Per-epoch saves go to checkpoint-* subfolders; also write the final weights
# and tokenizer to the output directory root so the fine-tuned model can be
# reloaded directly with from_pretrained(output_dir).
trainer.save_model()
tokenizer.save_pretrained(trainer.args.output_dir)