"""
LoRA Fine-tuning: Add Tool Calling to Synthia-S1-27b

Uses pre-tokenized data from Codyfederer/synthia-tool-calling-tokenized.
Optimized for a single H100 80GB.
"""

import os
from dataclasses import dataclass
from typing import Any, Dict, List

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model
import torch
import trackio
from huggingface_hub import whoami


@dataclass
class DataCollatorForPreTokenized:
    """Data collator for pre-tokenized datasets.

    Expects each feature to carry Python lists of token ids ("input_ids",
    "attention_mask", and optionally "labels") and right-pads the batch to
    its longest sequence.
    """

    pad_token_id: int

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        max_length = max(len(f["input_ids"]) for f in features)

        batch = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }

        for feature in features:
            input_ids = feature["input_ids"]
            attention_mask = feature["attention_mask"]
            labels = feature.get("labels", input_ids.copy())

            padding_length = max_length - len(input_ids)

            # Pad inputs with the pad token, extend the mask with zeros, and pad
            # labels with -100 so padded positions are ignored by the loss.
            batch["input_ids"].append(input_ids + [self.pad_token_id] * padding_length)
            batch["attention_mask"].append(attention_mask + [0] * padding_length)
            batch["labels"].append(labels + [-100] * padding_length)

        return {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}


# Model and data
BASE_MODEL = "Tesslate/Synthia-S1-27b"
OUTPUT_MODEL = "Synthia-S1-27b-tool-calling"
TOKENIZED_DATASET = "Codyfederer/synthia-tool-calling-tokenized"
MAX_SEQ_LENGTH = 4096

# Training hyperparameters
BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 8
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
LORA_R = 64
LORA_ALPHA = 128

print("=" * 60)
print("Tool Calling Fine-tuning for Synthia-S1-27b (H100)")
print("=" * 60)

# Start experiment tracking
trackio.init(project="synthia-tool-calling")

# Resolve the Hub repo to push to from the logged-in user
try:
    username = whoami()["name"]
    hub_model_id = f"{username}/{OUTPUT_MODEL}"
    print(f"Will push to: {hub_model_id}")
except Exception as e:
    print(f"Error getting username: {e}")
    raise

print(f"\nLoading tokenizer from {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    padding_side="right",
)
if tokenizer.pad_token is None:
    # Fall back to the EOS token when the tokenizer defines no pad token
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Vocab size: {len(tokenizer):,}")

print(f"\nLoading pre-tokenized dataset: {TOKENIZED_DATASET}")
tokenized_ds = load_dataset(TOKENIZED_DATASET)

train_dataset = tokenized_ds["train"]
eval_dataset = tokenized_ds.get("test", tokenized_ds.get("validation"))

print(f"Train samples: {len(train_dataset):,}")
if eval_dataset is not None:
    print(f"Eval samples: {len(eval_dataset):,}")

def truncate_example(example):
    """Clip a pre-tokenized example to MAX_SEQ_LENGTH tokens."""
    return {
        "input_ids": example["input_ids"][:MAX_SEQ_LENGTH],
        "attention_mask": example["attention_mask"][:MAX_SEQ_LENGTH],
        "labels": example["labels"][:MAX_SEQ_LENGTH] if "labels" in example else example["input_ids"][:MAX_SEQ_LENGTH],
    }


print(f"Truncating to max_length={MAX_SEQ_LENGTH}...")
train_dataset = train_dataset.map(truncate_example, desc="Truncating train")
if eval_dataset is not None:
    eval_dataset = eval_dataset.map(truncate_example, desc="Truncating eval")

print(f"\nLoading model: {BASE_MODEL}...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
print(f"Model loaded. Parameters: {model.num_parameters():,}")

print(f"\nConfiguring LoRA (r={LORA_R}, alpha={LORA_ALPHA})...")
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    # Adapt all attention and MLP projection matrices
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
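# Note: if gradient checkpointing complains that inputs do not require grad under
# the PEFT wrapper, enabling input grads is a common fix (left commented out here;
# recent transformers/peft versions usually handle this automatically).
# model.enable_input_require_grads()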
print("\nConfiguring training...")
training_args = TrainingArguments(
    output_dir=f"./{OUTPUT_MODEL}",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    optim="adamw_torch",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_grad_norm=1.0,
    eval_strategy="steps" if eval_dataset is not None else "no",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="checkpoint",
    logging_steps=10,
    report_to="trackio",
    run_name=f"synthia-tool-calling-lora-r{LORA_R}",
    bf16=True,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    seed=42,
    # Keep all dataset columns so the custom collator receives them
    remove_unused_columns=False,
)
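# Effective batch size: BATCH_SIZE (4) x GRADIENT_ACCUMULATION (8) = 32 sequences
# per optimizer step on the single H100.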
print("\nInitializing trainer...")
data_collator = DataCollatorForPreTokenized(pad_token_id=tokenizer.pad_token_id)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
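# Note: newer transformers releases deprecate Trainer's `tokenizer=` argument in
# favor of `processing_class=`; `tokenizer=` is kept for compatibility with older
# versions.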
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60 + "\n")
trainer.train()

# Save the LoRA adapter locally, then push it to the Hub
print("\nSaving final model...")
trainer.save_model()
print(f"Pushing to Hub: {hub_model_id}")
trainer.push_to_hub()

print("\n" + "=" * 60)
print("Training complete!")
print(f"Model available at: https://huggingface.co/{hub_model_id}")
print("=" * 60)
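# Sketch (not executed here): loading the pushed LoRA adapter for inference with
# peft. Assumes the adapter repo is the `hub_model_id` built above.
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
#   )
#   tuned = PeftModel.from_pretrained(base, hub_model_id)
#   tuned.eval()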