# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.50.0",
# "datasets>=2.14.0",
# "huggingface_hub",
# ]
# ///
"""
Tokenize Dataset Script: Prepare Tool Calling Dataset for Training
This script tokenizes the nvidia/Nemotron-Agentic-v1 tool_calling dataset
and uploads it to HuggingFace Hub for reuse.
Usage:
uv run tokenize_dataset.py
Can run on CPU - no GPU required!
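
The tokenized dataset can later be reloaded for training, e.g.:

    from datasets import load_dataset
    train_ds = load_dataset("Codyfederer/synthia-tool-calling-tokenized", split="train")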
"""
import json
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi, create_repo, whoami
# ============================================================================
# CONFIGURATION
# ============================================================================
# Model to get tokenizer from
BASE_MODEL = "Tesslate/Synthia-S1-27b"
# Source dataset
DATASET_NAME = "nvidia/Nemotron-Agentic-v1"
DATASET_SPLIT = "tool_calling"
# Output tokenized dataset
TOKENIZED_DATASET_REPO = "Codyfederer/synthia-tool-calling-tokenized"
TOKENIZED_DATASET_PRIVATE = True
# Tokenization settings
MAX_SEQ_LENGTH = 4096
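# Conversations longer than this are truncated during tokenization.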
# ============================================================================
# TOKENIZATION FUNCTIONS
# ============================================================================
def tokenize_conversation(example, tokenizer, max_length):
"""
Tokenize a conversation using the model's chat template.
Returns input_ids, attention_mask, and labels for causal LM training.
"""
messages = example["messages"]
# Apply chat template to get the full text
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
# Tokenize the text
tokenized = tokenizer(
text,
truncation=True,
max_length=max_length,
padding=False,
return_tensors=None,
)
# For causal LM, labels are the same as input_ids
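    # (every token, including user/system turns, contributes to the loss;
    # see the sketch after this function for a response-only alternative)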
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
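# A minimal sketch of response-only loss masking, an alternative to the
# full-sequence labels above (not used by this script). It assumes the chat
# template renders messages[:-1] plus a generation prompt as a strict prefix
# of the full conversation text:
#
#   prompt_ids = tokenizer.apply_chat_template(
#       messages[:-1], tokenize=True, add_generation_prompt=True)
#   labels = list(tokenized["input_ids"])
#   n = min(len(prompt_ids), len(labels))
#   labels[:n] = [-100] * n
#   tokenized["labels"] = labels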
def main():
print("=" * 60)
print("Tokenize Dataset for Tool Calling Training")
print("=" * 60)
# Get HF username
try:
username = whoami()["name"]
print(f"Logged in as: {username}")
except Exception as e:
print(f"ERROR: Not logged in to HF Hub ({e})")
print("Run 'huggingface-cli login' first")
return
# -------------------------------------------------------------------------
# Load Tokenizer
# -------------------------------------------------------------------------
print(f"\nLoading tokenizer from {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
padding_side="right",
)
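    # Many causal-LM tokenizers ship without a pad token; fall back to EOS so
    # padded batches can still be built downstream.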
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Vocab size: {len(tokenizer):,}")
# -------------------------------------------------------------------------
# Load Source Dataset
# -------------------------------------------------------------------------
print(f"\nLoading dataset: {DATASET_NAME} ({DATASET_SPLIT} split)...")
# Download the JSONL file
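    # (fetched directly with hf_hub_download rather than load_dataset, since
    # each record needs custom cleanup before it becomes a Dataset)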
jsonl_file = f"data/{DATASET_SPLIT}.jsonl"
print(f"Downloading {jsonl_file}...")
local_path = hf_hub_download(
repo_id=DATASET_NAME,
filename=jsonl_file,
repo_type="dataset"
)
print(f"Downloaded to: {local_path}")
# Load and process JSONL
print("Loading and processing JSONL file...")
processed_examples = []
skipped = 0
with open(local_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
            if line_num and line_num % 50000 == 0:
                print(f"  Processed {line_num:,} lines...")
            line = line.strip()
            if not line:
                continue  # skip blank lines (e.g. a trailing newline)
            try:
                example = json.loads(line)
messages = example.get("messages", [])
# Convert messages to consistent format
formatted_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
# Handle content that might be a list or complex object
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
if "text" in item:
parts.append(item["text"])
else:
parts.append(json.dumps(item))
else:
parts.append(str(item))
content = "\n".join(parts) if parts else ""
elif isinstance(content, dict):
content = json.dumps(content)
elif content is None:
content = ""
else:
content = str(content)
formatted_messages.append({
"role": role,
"content": content
})
                # Merge consecutive messages with the same role
                merged_messages = []
                for msg in formatted_messages:
                    role = msg["role"]
                    content = msg["content"]
                    # Map tool role to user (some chat templates reject "tool")
                    if role == "tool":
                        role = "user"
                        content = f"[Tool Result]\n{content}"
                    if merged_messages and merged_messages[-1]["role"] == role:
                        merged_messages[-1]["content"] += f"\n\n{content}"
                    else:
                        merged_messages.append({"role": role, "content": content})
                # Skip empty conversations instead of appending a stale list
                if not merged_messages:
                    skipped += 1
                    continue
                # Ensure the conversation starts with a user turn
                if merged_messages[0]["role"] != "user":
                    merged_messages.insert(0, {"role": "user", "content": "[Start]"})
                processed_examples.append({"messages": merged_messages})
            except Exception as e:
                skipped += 1
                if skipped <= 5:  # only report the first few parse errors
                    print(f"  Warning: Skipped line {line_num}: {e}")
print(f"Loaded {len(processed_examples):,} examples (skipped {skipped})")
# Create dataset
dataset = Dataset.from_list(processed_examples)
print(f"Dataset size: {len(dataset):,} examples")
# Create train/eval split
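    # (2% held out for eval; the fixed seed keeps the split reproducible)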
split_dataset = dataset.train_test_split(test_size=0.02, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f"Train samples: {len(train_dataset):,}")
print(f"Eval samples: {len(eval_dataset):,}")
# -------------------------------------------------------------------------
# Tokenize Dataset
# -------------------------------------------------------------------------
print(f"\nTokenizing dataset with max_length={MAX_SEQ_LENGTH}...")
print("This may take a while for large datasets...")
train_dataset = train_dataset.map(
lambda x: tokenize_conversation(x, tokenizer, MAX_SEQ_LENGTH),
remove_columns=["messages"],
num_proc=1, # Use single process to reduce memory
desc="Tokenizing train",
)
eval_dataset = eval_dataset.map(
lambda x: tokenize_conversation(x, tokenizer, MAX_SEQ_LENGTH),
remove_columns=["messages"],
num_proc=1, # Use single process to reduce memory
desc="Tokenizing eval",
)
print(f"Tokenization complete!")
print(f"Train dataset columns: {train_dataset.column_names}")
print(f"Sample input_ids length: {len(train_dataset[0]['input_ids'])}")
# -------------------------------------------------------------------------
# Upload to Hub
# -------------------------------------------------------------------------
print(f"\nUploading TOKENIZED dataset to Hub: {TOKENIZED_DATASET_REPO}")
# Create repo
api = HfApi()
try:
create_repo(
TOKENIZED_DATASET_REPO,
repo_type="dataset",
private=TOKENIZED_DATASET_PRIVATE,
exist_ok=True
)
print(f" Created/verified repo (private={TOKENIZED_DATASET_PRIVATE})")
if TOKENIZED_DATASET_PRIVATE:
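            # create_repo(exist_ok=True) leaves an existing repo's visibility
            # untouched, so enforce it explicitly.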
try:
api.update_repo_visibility(
TOKENIZED_DATASET_REPO,
repo_type="dataset",
private=True
)
except Exception:
pass
except Exception as e:
print(f" Repo note: {e}")
# Reset format for serialization
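    # (defensive: clears any torch/numpy formatting so plain Python lists
    # are serialized by push_to_hub)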
train_dataset.reset_format()
eval_dataset.reset_format()
# Verify data
print(f" Verifying tokenized data...")
print(f" Train columns: {train_dataset.column_names}")
print(f" Sample input_ids type: {type(train_dataset[0]['input_ids'])}")
print(f" Sample input_ids length: {len(train_dataset[0]['input_ids'])}")
print(f" First 10 tokens: {train_dataset[0]['input_ids'][:10]}")
# Push to Hub
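    # (calling push_to_hub twice with different split= values adds both
    # splits to the same dataset repo)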
print(f" Pushing train split ({len(train_dataset):,} examples)...")
train_dataset.push_to_hub(
TOKENIZED_DATASET_REPO,
split="train",
)
print(f" Pushing test split ({len(eval_dataset):,} examples)...")
eval_dataset.push_to_hub(
TOKENIZED_DATASET_REPO,
split="test",
)
print(f"\n" + "=" * 60)
print(f"SUCCESS! Tokenized dataset saved to:")
print(f" https://huggingface.co/datasets/{TOKENIZED_DATASET_REPO}")
print(f"=" * 60)
# Verify upload
print("\nVerifying upload...")
    try:
        verify_ds = load_dataset(TOKENIZED_DATASET_REPO, split="train", streaming=True)
        sample = next(iter(verify_ds))
if "input_ids" in sample:
print(f" VERIFIED: Dataset contains input_ids with {len(sample['input_ids'])} tokens")
else:
print(f" WARNING: input_ids not found in columns: {list(sample.keys())}")
except Exception as ve:
print(f" Could not verify: {ve}")
if __name__ == "__main__":
main()