"""
This script sets up a HuggingFace-based training and inference pipeline
for bug-fixing AI using a CodeT5 model. It is designed to be more
robust and flexible than the original.

Key improvements:
- Uses argparse for configuration, making it easy to change settings
  via the command line.
- Adds checks to ensure data files exist.
- Implements a compute_metrics function for better model evaluation.
- Optimizes data preprocessing with dynamic padding.
- Saves the best-performing model based on evaluation metrics.
- Checks for GPU availability.
"""

import argparse
import os
from typing import Dict

import torch
from datasets import DatasetDict, load_dataset
from evaluate import load
# Seq2SeqTrainer / Seq2SeqTrainingArguments (rather than the generic Trainer)
# let evaluation generate sequences, which the text metrics below require.
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def parse_args():
    """Parses command-line arguments for the training script."""
    parser = argparse.ArgumentParser(description="Fine-tune a Seq2Seq model for code repair.")
    parser.add_argument("--model_name", type=str, default="Salesforce/codet5p-220m",
                        help="Pre-trained model name from HuggingFace.")
    parser.add_argument("--output_dir", type=str, default="./aifixcode-model",
                        help="Directory to save the trained model.")
    parser.add_argument("--train_path", type=str, default="./data/train.json",
                        help="Path to the training data JSON file.")
    parser.add_argument("--val_path", type=str, default="./data/val.json",
                        help="Path to the validation data JSON file.")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Batch size per device for training.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
                        help="Batch size per device for evaluation.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Whether to push the model to the Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default="khulnasoft/aifixcode-model",
                        help="Hugging Face Hub model ID to push to.")
    return parser.parse_args()


def load_json_dataset(train_path: str, val_path: str) -> DatasetDict:
    """Loads and returns a dataset dictionary from JSON files."""
    if not os.path.exists(train_path) or not os.path.exists(val_path):
        raise FileNotFoundError(f"One or both data files not found: {train_path}, {val_path}")

    print("Loading dataset...")
    dataset = DatasetDict({
        "train": load_dataset("json", data_files=train_path, split="train"),
        "validation": load_dataset("json", data_files=val_path, split="train")
    })
    return dataset
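
# Expected data layout (an illustrative sample, not taken from the project's
# actual data; the field names follow the preprocessing step below, which reads
# "input" and "output"). Each file is a JSON array or JSON Lines of records like:
#   {"input": "<buggy code snippet>", "output": "<fixed code snippet>"}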


def preprocess_function(examples: Dict[str, list], tokenizer) -> Dict[str, list]:
    """Tokenizes a batch of input and target code.

    This function uses dynamic padding by default, which is more
    memory-efficient than padding all sequences to a fixed max length.
    """
    inputs = list(examples["input"])
    targets = list(examples["output"])

    model_inputs = tokenizer(inputs, text_target=targets, max_length=512, truncation=True)
    return model_inputs
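
# Note: the tokenizer call above truncates but does not pad. Padding happens per
# batch via DataCollatorForSeq2Seq in main(), which is the "dynamic padding"
# referred to in the module docstring.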


def compute_metrics(eval_pred, tokenizer):
    """Computes BLEU and ROUGE metrics for model evaluation."""
    bleu_metric = load("bleu")
    rouge_metric = load("rouge")

    predictions, labels = eval_pred

    # Labels use -100 for ignored (padding) positions; restore the pad token id
    # so the sequences can be decoded back to text.
    labels = [[item if item != -100 else tokenizer.pad_token_id for item in row] for row in labels]

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
    rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)

    return {
        "bleu": bleu_result["bleu"],
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
    }
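
# BLEU/ROUGE compare decoded text, so evaluation must produce generated token
# ids rather than teacher-forced logits; this is why the trainer below is a
# Seq2SeqTrainer configured with predict_with_generate=True.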


def main():
    """Main function to set up and run the training pipeline."""
    args = parse_args()

    if not torch.cuda.is_available():
        print("Warning: A GPU is not available. Training will be very slow on CPU.")

    print(f"Loading model '{args.model_name}' and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    try:
        dataset = load_json_dataset(args.train_path, args.val_path)
    except FileNotFoundError as e:
        print(e)
        return

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        lambda examples: preprocess_function(examples, tokenizer),
        batched=True,
        remove_columns=dataset["train"].column_names
    )

    print("Setting up trainer...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=os.path.join(args.output_dir, "checkpoints"),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_strategy="epoch",
        # Generate sequences during evaluation so compute_metrics receives token
        # ids that can be decoded to text for BLEU/ROUGE.
        predict_with_generate=True,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        hub_strategy="every_save",
        load_best_model_at_end=True,
        metric_for_best_model="rougeL",
        greater_is_better=True,
        report_to="tensorboard"
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        # compute_metrics needs the tokenizer to decode predictions, so bind it here.
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer)
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    final_model_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)
    print("Training complete and model saved!")


if __name__ == "__main__":
    main()