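"""Single-file SentenceTransformers loss-training example (no src imports).

Trains a SentenceTransformer on five public (query, positive) pair datasets
with a cached (GradCache-style) in-batch-negatives loss, then evaluates on
NanoBEIR. Illustrative invocation (the filename and flag values are examples,
not tuned recommendations):

    python train_st_loss_example.py \
        --model_name answerdotai/ModernBERT-base \
        --loss_type CMNBRL \
        --per_device_train_batch_size 8192 \
        --loss_mini_batch_size 128
"""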
from __future__ import annotations

import argparse
import logging
import os
import time
from pathlib import Path
from typing import cast


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Single-file ST loss training example (no src imports).")
    parser.add_argument(
        "--model_name",
        default="answerdotai/ModernBERT-base",
        help="Sentence-Transformers model name or path.",
    )
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument(
        "--max_train_examples",
        type=int,
        default=-1,
        help="Limit training examples per dataset (use -1 for the full dataset).",
    )
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8192)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=512)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--lr_scheduler_type", default="cosine")
    parser.add_argument("--optim", default="adamw_torch")
    parser.add_argument("--loss_mini_batch_size", type=int, default=128)
    parser.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Softmax temperature; None keeps each loss's default.",
    )
    parser.add_argument("--gather_across_devices", action="store_true")
    # BooleanOptionalAction exposes both --bf16 and --no-bf16; a plain
    # store_true flag with default=True could never be switched off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    parser.add_argument("--dataloader_persistent_workers", action="store_true", default=False)
    parser.add_argument("--no_drop_last", action="store_true", help="Disable drop_last (drop_last defaults to True).")
    parser.add_argument(
        "--batch_sampler",
        choices=["batch_sampler", "no_duplicates"],
        default="no_duplicates",
        help="Batch sampler type for SentenceTransformers.",
    )
    parser.add_argument(
        "--loss_type",
        choices=["CMNRL", "CMNBRL"],
        default="CMNBRL",
        help="Loss type: CMNRL (CachedMultipleNegativesRankingLoss) or "
        "CMNBRL (CachedMultipleNegativesBidirectionalRankingLoss, a GTE-style "
        "bidirectional loss with GradCache).",
    )
    parser.add_argument(
        "--output_root",
        default="output/models/examples",
        help="Root directory for outputs.",
    )
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--no_shuffle", action="store_true")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (debug).")
    parser.add_argument("--resume_from_checkpoint", default=None, help="Resume training from checkpoint.")
    return parser.parse_args()


def build_output_dir(output_root: Path, run_name: str) -> Path:
    """Return output_root / run_name / <timestamp>, e.g. .../20240101_120000."""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    return output_root / run_name / timestamp


def main() -> None:
    args = parse_args()

    # Silence tokenizers' fork warnings in dataloader workers; must be set
    # before the Hugging Face libraries are imported below.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Heavy imports are deferred until after the environment is configured.
    import torch
    from datasets import Dataset, DatasetDict, load_dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )
    from sentence_transformers.evaluation import NanoBEIREvaluator

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger("train_st_loss_example")

    if args.bf16 and (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()):
        logger.warning("bf16 requested but not supported on this device; disabling bf16.")
        args.bf16 = False

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # <= 0 matches the `> 0` guard in the truncation loop below, so a value
    # of 0 is tagged "full" like -1 rather than mislabeled as a cap.
    max_train_tag = "full" if args.max_train_examples <= 0 else f"{args.max_train_examples}"
    data_tag = "pair"
    if args.run_name is None:
        model_tag = args.model_name.rstrip("/").split("/")[-1]
        temp_tag = "tdefault" if args.temperature is None else f"t{args.temperature}".replace(".", "p")
        args.run_name = (
            f"{model_tag}_{args.loss_type}_{args.batch_sampler}_{temp_tag}_{data_tag}"
            f"_bs{args.per_device_train_batch_size}_{max_train_tag}"
        )
    output_dir = build_output_dir(output_root, args.run_name)
    output_dir.mkdir(parents=True, exist_ok=True)
    final_dir = output_dir / "final"

    logger.info("Loading model: %s", args.model_name)
    model = SentenceTransformer(args.model_name)
    model.max_seq_length = args.max_seq_length

    def _load_pair_dataset(dataset_id: str, config: str | None, rename_map: dict[str, str]) -> Dataset:
        """Load a train split and normalize it to (query, positive) columns."""
        ds = load_dataset(dataset_id, config, split="train") if config else load_dataset(dataset_id, split="train")
        ds = cast(Dataset, ds)
        if rename_map:
            column_names = ds.column_names or []
            existing = {k: v for k, v in rename_map.items() if k in column_names}
            if existing:
                ds = ds.rename_columns(existing)
        ds = ds.select_columns(["query", "positive"])
        return ds

    logger.info("Loading datasets (pair only)...")
    train_datasets = DatasetDict(
        {
            "msmarco": _load_pair_dataset(
                "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
                "triplet",
                {"query": "query", "positive": "positive"},
            ),
            "natural_questions": _load_pair_dataset(
                "sentence-transformers/natural-questions",
                "pair",
                {"answer": "positive"},
            ),
            "gooaq": _load_pair_dataset(
                "sentence-transformers/gooaq",
                "pair",
                {"question": "query", "answer": "positive"},
            ),
            "ccnews": _load_pair_dataset(
                "sentence-transformers/ccnews",
                "pair",
                {"title": "query", "article": "positive"},
            ),
            "hotpotqa": _load_pair_dataset(
                "sentence-transformers/hotpotqa",
                "triplet",
                {"anchor": "query", "positive": "positive"},
            ),
        }
    )

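    # Shuffle (unless --no_shuffle) and optionally cap each split; the cap
    # applies per dataset, and the resulting sizes are logged.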
    for name, ds in train_datasets.items():
        if not args.no_shuffle:
            ds = ds.shuffle(seed=args.seed)
        if args.max_train_examples > 0:
            ds = ds.select(range(min(args.max_train_examples, len(ds))))
        train_datasets[name] = ds
        logger.info("Train examples [%s]: %d", name, len(ds))

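    # Map the CLI temperature onto each loss's own parameter. CMNRL scales
    # similarities by `scale` = 1 / temperature (e.g. temperature 0.02 ->
    # scale 50), while CMNBRL takes `temperature` directly.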
    loss_kwargs = {}
    if args.temperature is not None:
        if args.loss_type == "CMNBRL":
            loss_kwargs["temperature"] = args.temperature
        else:
            loss_kwargs["scale"] = 1.0 / args.temperature
    if args.loss_mini_batch_size is not None:
        loss_kwargs["mini_batch_size"] = args.loss_mini_batch_size
    if args.gather_across_devices:
        loss_kwargs["gather_across_devices"] = True

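    # Both cached losses use GradCache-style chunking: the contrastive batch
    # is the full per-device batch, but activations are (re)computed in
    # chunks of mini_batch_size, so memory stays bounded while the pool of
    # in-batch negatives stays large.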
    if args.loss_type == "CMNBRL":
        loss = losses.CachedMultipleNegativesBidirectionalRankingLoss(model=model, **loss_kwargs)
    else:
        loss = losses.CachedMultipleNegativesRankingLoss(model=model, **loss_kwargs)

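    # The "no_duplicates" batch sampler keeps repeated examples out of a
    # batch so in-batch negatives are not accidental positives; drop_last
    # (the default here) keeps every contrastive batch full-size.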
    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        save_total_limit=args.save_total_limit,
        lr_scheduler_type=args.lr_scheduler_type,
        optim=args.optim,
        bf16=args.bf16,
        fp16=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        dataloader_persistent_workers=args.dataloader_persistent_workers,
        dataloader_drop_last=not args.no_drop_last,
        seed=args.seed,
        max_steps=args.max_steps,
        eval_strategy="no",
        report_to=["wandb"],  # requires wandb; set WANDB_MODE=offline to skip the network
        remove_unused_columns=False,
        batch_sampler=args.batch_sampler,
        disable_tqdm=False,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_datasets,
        loss=loss,
    )

    logger.info("Training start. Output: %s", output_dir)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

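    # Post-training sanity check on NanoBEIR (small BEIR subsets); the
    # evaluator's primary_metric is the mean NDCG@10 across its datasets.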
    evaluator = NanoBEIREvaluator(
        ndcg_at_k=[10],
        mrr_at_k=[10],
        accuracy_at_k=[10],
        precision_recall_at_k=[10],
        map_at_k=[10],
        batch_size=args.per_device_eval_batch_size,
        show_progress_bar=False,
        write_csv=False,
    )
    results = evaluator(
        model,
        output_path=str(output_dir / "eval"),
        epoch=0,
        steps=trainer.state.global_step,
    )
    ndcg_key = evaluator.primary_metric
    print(f"NDCG@10: {results[ndcg_key]:.6f} ({ndcg_key})")

    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    model.save(str(final_dir), create_model_card=True)
    logger.info("Saved model to: %s", final_dir)


if __name__ == "__main__":
    main()