Iterative Trainer
Iterative fine-tuning is a training method that enables you to perform custom actions (for example, generation and filtering) between optimization steps. TRL provides an easy-to-use API to fine-tune your models iteratively in just a few lines of code.
Quickstart
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:
from trl import IterativeSFTConfig, IterativeSFTTrainer
# Using a model identifier
trainer = IterativeSFTTrainer(
"facebook/opt-350m",
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
)
# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
trainer = IterativeSFTTrainer(
model,
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
processing_class=tokenizer,
)
Usage
The IterativeSFTTrainer supports two ways of providing input data to the step function:
Using a list of tensors as input:
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
trainer.step(**inputs)
Using a list of strings as input:
inputs = {
"texts": texts,
"texts_labels": texts_labels, # Optional, defaults to texts
}
trainer.step(**inputs)
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence-to-sequence models, you will have to provide your own labels or texts_labels.
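Putting the pieces together, here is a minimal sketch of an iterative generate, filter, and train loop, reusing the model, tokenizer, and trainer objects from the quickstart. The prompts, the number of iterations, and the filtering criterion are illustrative placeholders, not part of the TRL API:
prompts = ["The capital of France is", "The chemical symbol for gold is"]
for _ in range(10):  # run as many iterations as your use case needs
    # 1. Generate candidate completions with the current model
    encoded = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    generations = model.generate(**encoded, max_new_tokens=32, do_sample=True)
    texts = tokenizer.batch_decode(generations, skip_special_tokens=True)
    # 2. Filter the generations with any custom criterion (placeholder: keep non-empty ones)
    kept = [text for text in texts if len(text.strip()) > 0]
    if not kept:
        continue
    # 3. Run one optimization step on the kept samples
    trainer.step(texts=kept)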
Configuration
The IterativeSFTConfig class provides several parameters to customize the training:
from trl import IterativeSFTConfig
config = IterativeSFTConfig(
# Model initialization parameters
model_init_kwargs={"torch_dtype": "bfloat16"},
# Data preprocessing parameters
max_length=512,
truncation_mode="keep_end",
# Training parameters
output_dir="./output",
learning_rate=2e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
max_steps=1000,
save_steps=100,
optim="adamw_torch",
report_to="wandb",
)
Model Initialization
You can control how the model is initialized by passing keyword arguments to model_init_kwargs:
config = IterativeSFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
"device_map": "auto",
"trust_remote_code": True,
}
)
Data Preprocessing
The trainer supports two truncation modes:
- keep_end: Truncates from the start of the sequence
- keep_start: Truncates from the end of the sequence
config = IterativeSFTConfig(
max_length=512,
truncation_mode="keep_end", # or "keep_start"
)
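For intuition, the two modes behave like simple slicing on the token sequence; the snippet below mirrors the documented semantics and is not the trainer's internal code:
tokens = [101, 7592, 2088, 2003, 2307, 102]
max_length = 4
tokens[-max_length:]  # keep_end: drops tokens from the start -> [2088, 2003, 2307, 102]
tokens[:max_length]   # keep_start: drops tokens from the end -> [101, 7592, 2088, 2003]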
Training Optimization
You can optimize CUDA cache usage for more memory-efficient training:
config = IterativeSFTConfig(
optimize_device_cache=True,
)
IterativeSFTTrainer
class trl.IterativeSFTTrainer
< source >( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.iterative_sft_config.IterativeSFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None )
Parameters
- model (Union[str, PreTrainedModel]) — Model to be trained. Can be either:
  - A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
  - A PreTrainedModel object. Only causal language models are supported.
- args (IterativeSFTConfig, optional, defaults to None) — Configuration for this trainer. If None, a default configuration is used.
- data_collator (DataCollator, optional) — Function to use to form a batch from a list of elements of the processed train_dataset or eval_dataset. Will default to default_data_collator if no processing_class is provided, or an instance of DataCollatorWithPadding if the processing_class is a feature extractor or tokenizer.
- eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
- processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin or ProcessorMixin, optional, defaults to None) — Processing class used to process the data. If None, the processing class is loaded from the model's name with from_pretrained.
- optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
- compute_metrics (Callable[[EvalPrediction], dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
- max_length (int, optional, deprecated) — Maximum length of the tokenized sequence. Use args.max_length instead.
- truncation_mode (str, optional, deprecated) — The truncation mode to use. Use args.truncation_mode instead.
- optimize_device_cache (bool, optional, deprecated) — Whether to optimize accelerator cache. Use args.optimize_device_cache instead.
The IterativeSFTTrainer can be used to fine-tune models with methods that require some steps between optimization steps.
create_model_card
< source >( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
Creates a draft of a model card using the information available to the Trainer.
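For example, after training you could draft a card like this; the model name, dataset name, and tags below are arbitrary placeholders:
trainer.create_model_card(
    model_name="opt-350m-iterative-sft",
    dataset_name="my-generated-dataset",
    tags=["iterative-sft"],
)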
step
< source >( input_ids: typing.Optional[list[torch.LongTensor]] = None attention_mask: typing.Optional[list[torch.LongTensor]] = None labels: typing.Optional[list[torch.LongTensor]] = None texts: typing.Optional[list[str]] = None texts_labels: typing.Optional[list[str]] = None ) → dict[str, Any]
Parameters
- input_ids (list[torch.LongTensor]) — List of tensors containing the input_ids (if not provided, texts will be used)
- attention_mask (list[torch.LongTensor], optional) — List of tensors containing the attention_mask
- labels (list[torch.FloatTensor], optional) — List of tensors containing the labels (if set to None, will default to input_ids)
- texts (list[str], optional) — List of strings containing the text input (if not provided, input_ids will directly be used)
- texts_labels (list[str], optional) — List of strings containing the text labels (if set to None, will default to texts)
Returns
dict[str, Any]
A summary of the training statistics
Run an optimization step given a list of input_ids, attention_mask, and labels, or a list of texts and texts_labels.
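The two calling conventions look as follows; the tensors and strings below are dummy data for illustration only:
import torch

# Tokenized inputs: one 1-D LongTensor per sample
input_ids = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7])]
attention_mask = [torch.ones_like(ids) for ids in input_ids]
stats = trainer.step(input_ids=input_ids, attention_mask=attention_mask)

# Raw text inputs: for causal LMs, labels default to the texts themselves
stats = trainer.step(texts=["Hello world", "Iterative fine-tuning with TRL"])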
IterativeSFTConfig
class trl.IterativeSFTConfig
< source >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None 
ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None max_length: typing.Optional[int] = None truncation_mode: str = 'keep_end' optimize_device_cache: bool = False )
Parameters that control the model
- model_init_kwargs (dict[str, Any] or None, optional, defaults to None) — Keyword arguments for from_pretrained, used when the model argument of the IterativeSFTTrainer is provided as a string.
Parameters that control the data preprocessing
- max_length (int or None, optional, defaults to None) — Maximum length of the tokenized sequence. Sequences longer than max_length are truncated.
- truncation_mode (str, optional, defaults to "keep_end") — The truncation mode to use, either "keep_end" or "keep_start".
- optimize_device_cache (bool, optional, defaults to False) — Whether to optimize accelerator cache for slightly more memory-efficient training.
Configuration class for the IterativeSFTTrainer.
This class includes only the parameters that are specific to Iterative SFT training. For a full list of training arguments, please refer to the TrainingArguments documentation. Note that default values in this class may differ from those in TrainingArguments.
Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
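A minimal sketch of such a launcher script; the file name and the flags shown are examples:
# train.py
from transformers import HfArgumentParser
from trl import IterativeSFTConfig

parser = HfArgumentParser(IterativeSFTConfig)
(config,) = parser.parse_args_into_dataclasses()
print(config.output_dir, config.max_length, config.truncation_mode)
It could then be invoked as, for example, python train.py --output_dir ./output --max_length 512 --truncation_mode keep_end.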