"""
This script sets up a HuggingFace-based training and inference pipeline
for bug-fixing AI using a CodeT5 model. It is designed to be more
robust and flexible than the original.
Key improvements:
- Uses argparse for configuration, making it easy to change settings
via the command line.
- Adds checks to ensure data files exist.
- Implements a compute_metrics function for better model evaluation.
- Defers padding to a data collator, so batches are padded dynamically instead
  of every example being padded to a fixed max length.
- Saves the best-performing model based on evaluation metrics.
- Checks for GPU availability.
"""
import os
import argparse
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset, DatasetDict
from typing import Dict
from evaluate import load
# ========== ARGUMENT PARSING ==========
def parse_args():
"""Parses command-line arguments for the training script."""
parser = argparse.ArgumentParser(description="Fine-tune a Seq2Seq model for code repair.")
parser.add_argument("--model_name", type=str, default="Salesforce/codet5p-220m",
help="Pre-trained model name from HuggingFace.")
parser.add_argument("--output_dir", type=str, default="./aifixcode-model",
help="Directory to save the trained model.")
parser.add_argument("--train_path", type=str, default="./data/train.json",
help="Path to the training data JSON file.")
parser.add_argument("--val_path", type=str, default="./data/val.json",
help="Path to the validation data JSON file.")
parser.add_argument("--epochs", type=int, default=3,
help="Number of training epochs.")
parser.add_argument("--learning_rate", type=float, default=5e-5,
help="Learning rate for the optimizer.")
parser.add_argument("--per_device_train_batch_size", type=int, default=4,
help="Batch size per device for training.")
parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
help="Batch size per device for evaluation.")
parser.add_argument("--push_to_hub", action="store_true",
help="Whether to push the model to the Hugging Face Hub.")
parser.add_argument("--hub_model_id", type=str, default="khulnasoft/aifixcode-model",
help="Hugging Face Hub model ID to push to.")
return parser.parse_args()
# ========== DATA LOADING ==========
def load_json_dataset(train_path: str, val_path: str) -> DatasetDict:
"""Loads and returns a dataset dictionary from JSON files."""
if not os.path.exists(train_path) or not os.path.exists(val_path):
raise FileNotFoundError(f"One or both data files not found: {train_path}, {val_path}")
print("Loading dataset...")
dataset = DatasetDict({
"train": load_dataset("json", data_files=train_path, split="train"),
"validation": load_dataset("json", data_files=val_path, split="train")
})
return dataset
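# The JSON files are expected to contain one record per example with "input" and
# "output" fields (these are the column names consumed by preprocess_function
# below), for example:
#
#   [
#     {"input": "<buggy code snippet>", "output": "<fixed code snippet>"},
#     ...
#   ]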
# ========== DATA PREPROCESSING ==========
def preprocess_function(examples: Dict[str, list], tokenizer) -> Dict[str, list]:
    """Tokenizes a batch of buggy inputs and fixed targets.

    Padding is intentionally not applied here; it is deferred to the data
    collator, which pads each batch dynamically. This is more memory-efficient
    than padding every sequence to a fixed max length.
    """
    inputs = examples["input"]
    targets = examples["output"]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=512, truncation=True)
    return model_inputs
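# Applied with Dataset.map(batched=True), each record becomes a dict of the form
# {"input_ids": [...], "attention_mask": [...], "labels": [...]} (the ids shown
# here are placeholders; actual values depend on the tokenizer). The padding
# itself happens later, per batch, inside DataCollatorForSeq2Seq.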
# ========== METRIC CALCULATION ==========
def build_compute_metrics(tokenizer):
    """Builds a compute_metrics function that closes over the tokenizer."""
    bleu_metric = load("bleu")
    rouge_metric = load("rouge")

    def compute_metrics(eval_pred):
        """Computes BLEU and ROUGE metrics on decoded generations."""
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # Replace -100 in the labels, since the tokenizer cannot decode it
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Compute BLEU score
        bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
        # Compute ROUGE score
        rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
        return {
            "bleu": bleu_result["bleu"],
            "rouge1": rouge_result["rouge1"],
            "rouge2": rouge_result["rouge2"],
            "rougeL": rouge_result["rougeL"],
        }

    return compute_metrics
# ========== MAIN EXECUTION BLOCK ==========
def main():
"""Main function to set up and run the training pipeline."""
args = parse_args()
# Check for GPU availability
if not torch.cuda.is_available():
print("Warning: A GPU is not available. Training will be very slow on CPU.")
# Load model and tokenizer
print(f"Loading model '{args.model_name}' and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
# Load and preprocess dataset
try:
dataset = load_json_dataset(args.train_path, args.val_path)
except FileNotFoundError as e:
print(e)
return
print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
lambda examples: preprocess_function(examples, tokenizer),
batched=True,
remove_columns=dataset["train"].column_names
)
# Training arguments setup
print("Setting up trainer...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=os.path.join(args.output_dir, "checkpoints"),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_strategy="epoch",
        predict_with_generate=True,  # Generate sequences during evaluation so BLEU/ROUGE can be computed
        generation_max_length=512,  # Match the max_length used during tokenization
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        hub_strategy="every_save",
        load_best_model_at_end=True,  # Reload the best checkpoint at the end of training
        metric_for_best_model="rougeL",  # Metric used to select the best checkpoint
        greater_is_better=True,
        report_to="tensorboard"
    )
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    # Initialize the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics(tokenizer)
    )
print("Starting training...")
trainer.train()
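    # Note: if a run is interrupted, it can usually be resumed from the latest
    # checkpoint under <output_dir>/checkpoints via
    # trainer.train(resume_from_checkpoint=True).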
# Save final model
print("Saving final model...")
final_model_dir = os.path.join(args.output_dir, "final")
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
print("Training complete and model saved!")
if __name__ == "__main__":
main()
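# ---- Inference sketch (not executed by this script) ----
# A minimal example, assuming the final model was saved to the default location
# ./aifixcode-model/final, of how the fine-tuned model could be used to propose
# a fix for a buggy snippet:
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#   tokenizer = AutoTokenizer.from_pretrained("./aifixcode-model/final")
#   model = AutoModelForSeq2SeqLM.from_pretrained("./aifixcode-model/final")
#   inputs = tokenizer("def add(a, b): return a - b", return_tensors="pt", truncation=True)
#   output_ids = model.generate(**inputs, max_length=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))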