# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023  # noqa
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub
from transformers.integrations import is_deepspeed_zero3_enabled

from xtuner.parallel.sequence import (gather_forward_split_backward,
                                      get_sequence_parallel_group,
                                      get_sequence_parallel_world_size,
                                      split_for_sequence_parallel)
from .sft import SupervisedFinetune


def create_reference_model(model):
    if is_deepspeed_zero3_enabled():
        raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
                         'with `create_reference_model()`. Please instantiate '
                         'your reference model directly with '
                         '`AutoModelForCausalLM.from_pretrained()`.')
    parameter_names = [n for n, _ in model.named_parameters()]
    ref_model = deepcopy(model)
    # no layers are shared with the policy, so freeze every parameter of the
    # full copy and return it in eval mode
    for param_name in parameter_names:
        param = ref_model.get_parameter(param_name)
        param.requires_grad = False
    return ref_model.eval()


class DPO(SupervisedFinetune):
    """A general class of DPO and its variants."""

    def __init__(self,
                 llm,
                 ref_llm=None,
                 beta=0.1,
                 loss_type='sigmoid',
                 label_smoothing=0.0,
                 **kwargs):
        super().__init__(llm, **kwargs)
        self.ref_llm = ref_llm
        self.loss_type = loss_type
        self.label_smoothing = label_smoothing
        self.beta = beta

        if not self.use_lora:
            self.ref_llm = create_reference_model(self.llm)

    def _gather_masked_logits(self, logits, labels, mask):
        logits = torch.gather(
            logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        return logits * mask

    def get_logps(
            self,
            all_logits,  # bs, seqlen, vocab_size
            all_ref_logits,  # bs, seqlen, vocab_size
            labels,  # bs, seqlen
    ):
        labels = labels[:, 1:].clone()
        all_logits = all_logits[:, :-1, :]
        all_ref_logits = all_ref_logits[:, :-1, :]

        labels[labels == -100] = 0
        loss_mask = labels != 0
        all_logps = self._gather_masked_logits(all_logits, labels,
                                               loss_mask).sum(-1)
        all_ref_logps = self._gather_masked_logits(all_ref_logits, labels,
                                                   loss_mask).sum(-1)

        if self.loss_type == 'ipo':  # average_log_prob
            all_logps = all_logps / loss_mask.sum(-1)
            all_ref_logps = all_ref_logps / loss_mask.sum(-1)

        # samples are packed pairwise: even indices are chosen, odd rejected
        policy_chosen_logps = all_logps[::2]
        policy_rejected_logps = all_logps[1::2]
        reference_chosen_logps = all_ref_logps[::2]
        reference_rejected_logps = all_ref_logps[1::2]

        return (policy_chosen_logps, policy_rejected_logps,
                reference_chosen_logps, reference_rejected_logps)

    def get_var_len_atten_logps(self, all_logits, all_ref_logits, labels,
                                cu_seqlens, attention_mask):
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        # unpack the packed sequence
        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
        unpacked_ref_logits = torch.split(all_ref_logits, seqlens, dim=1)
        unpacked_labels = torch.split(labels, seqlens, dim=1)
        if attention_mask is not None:
            # A non-None attention_mask indicates that the original sequence,
            # labels, position_ids and cumulative_len were padded for
            # sequence parallel, so the padded segment has to be removed.
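            # Concretely, the padded tokens form the trailing chunk produced
            # by ``torch.split`` above, so dropping the last element of each
            # tuple removes exactly that pad segment.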
            assert False in attention_mask
            unpacked_logits = unpacked_logits[:-1]
            unpacked_ref_logits = unpacked_ref_logits[:-1]
            unpacked_labels = unpacked_labels[:-1]
        assert len(unpacked_logits) % 2 == 0

        def compute_logps(_logits, _labels):
            _labels = _labels[:, 1:].clone()
            _logits = _logits[:, :-1, :]
            _labels[_labels == -100] = 0
            loss_mask = _labels != 0
            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
            logps = logps.sum(-1)
            if self.loss_type == 'ipo':
                logps /= loss_mask.sum(-1)
            return logps

        (policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
         reference_rejected_logps) = [], [], [], []
        for i in range(len(unpacked_logits) // 2):
            chosen = unpacked_logits[2 * i]
            rejected = unpacked_logits[2 * i + 1]
            chosen_ref = unpacked_ref_logits[2 * i]
            rejected_ref = unpacked_ref_logits[2 * i + 1]
            chosen_label = unpacked_labels[2 * i]
            rejected_label = unpacked_labels[2 * i + 1]
            policy_chosen_logps.append(compute_logps(chosen, chosen_label))
            policy_rejected_logps.append(
                compute_logps(rejected, rejected_label))
            reference_chosen_logps.append(
                compute_logps(chosen_ref, chosen_label))
            reference_rejected_logps.append(
                compute_logps(rejected_ref, rejected_label))

        return (torch.stack(policy_chosen_logps),
                torch.stack(policy_rejected_logps),
                torch.stack(reference_chosen_logps),
                torch.stack(reference_rejected_logps))

    @staticmethod
    def _split_for_sequence_parallel(data):
        # attention mask should not be split
        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
        sp_group = get_sequence_parallel_group()
        for key in ARGS_NEED_TO_SPLIT:
            val = data.get(key, None)
            if val is not None:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                data[key] = split_for_sequence_parallel(
                    val, dim=1, sp_group=sp_group)
        return data

    def compute_loss(self, data, data_samples=None):
        # modified from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py  # noqa
        labels = data.pop('labels')

        if get_sequence_parallel_world_size() > 1:
            data = self._split_for_sequence_parallel(data)

        all_logits = self.llm(**data).logits
        with torch.no_grad():
            if self.ref_llm is None:
                # LoRA training: the reference model is the policy itself
                # with its adapters disabled
                with self.llm.disable_adapter():
                    all_ref_logits = self.llm(**data).logits
            else:
                all_ref_logits = self.ref_llm(**data).logits

        if get_sequence_parallel_world_size() > 1:
            all_logits = gather_forward_split_backward(
                all_logits,
                dim=1,
                sp_group=get_sequence_parallel_group(),
                grad_scale='up')
            all_ref_logits = gather_forward_split_backward(
                all_ref_logits,
                dim=1,
                sp_group=get_sequence_parallel_group(),
                grad_scale='up')

        if not self.use_varlen_attn:
            (policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps,
             reference_rejected_logps) = self.get_logps(
                 all_logits, all_ref_logits, labels)
        else:
            message_hub = MessageHub.get_instance('varlen_attn_args')
            rank = dist.get_rank()
            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
            (policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps,
             reference_rejected_logps) = self.get_var_len_atten_logps(
                 all_logits, all_ref_logits, labels, cu_seqlens,
                 data['attention_mask'])

        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
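        # ``logits`` is the preference margin shared by all loss variants:
        #     (log pi(y_w|x) - log pi_ref(y_w|x))
        #     - (log pi(y_l|x) - log pi_ref(y_l|x)),
        # i.e. how much more the policy prefers the chosen response over the
        # rejected one, relative to the reference model. Scaled by
        # ``self.beta`` it is the implicit reward margin of DPO.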
        if self.loss_type == 'sigmoid':
            loss = (-F.logsigmoid(self.beta * logits) *
                    (1 - self.label_smoothing) -
                    F.logsigmoid(-self.beta * logits) * self.label_smoothing)
        elif self.loss_type == 'robust':
            loss = (-F.logsigmoid(self.beta * logits) *
                    (1 - self.label_smoothing) +
                    F.logsigmoid(-self.beta * logits) * self.label_smoothing
                    ) / (1 - 2 * self.label_smoothing)
        elif self.loss_type == 'hinge':
            loss = torch.relu(1 - self.beta * logits)
        elif self.loss_type == 'ipo':
            # eqn (17) of the paper where beta is the regularization
            # parameter for the IPO loss, denoted by tau in the paper.  # noqa
            loss = (logits - 1 / (2 * self.beta))**2
        elif self.loss_type == 'kto_pair':
            # eqn (7) of the HALOs paper
            chosen_KL = (policy_chosen_logps -
                         reference_chosen_logps).mean().clamp(min=0)
            rejected_KL = (policy_rejected_logps -
                           reference_rejected_logps).mean().clamp(min=0)

            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = \
                policy_rejected_logps - reference_rejected_logps
            # As described in the KTO report, the KL term for chosen (rejected)
            # is estimated using the rejected (chosen) half.  # noqa
            loss = torch.cat(
                (
                    1 - F.sigmoid(self.beta *
                                  (chosen_logratios - rejected_KL)),
                    1 - F.sigmoid(self.beta *
                                  (chosen_KL - rejected_logratios)),
                ),
                0,
            )
        elif self.loss_type == 'sppo_hard':
            # In the paper (https://arxiv.org/pdf/2405.00675),
            # SPPO employs a soft probability approach,
            # estimated using the PairRM score. The probability calculation
            # is conducted outside of the trainer class.
            # The version described here is the hard probability version,
            # where P in Equation (4.7) of Algorithm 1 is set to 1 for
            # the winner and 0 for the loser.
            a = policy_chosen_logps - reference_chosen_logps
            b = policy_rejected_logps - reference_rejected_logps
            loss = (a - 0.5 / self.beta)**2 + (b + 0.5 / self.beta)**2
        elif self.loss_type == 'nca_pair':
            chosen_rewards = (policy_chosen_logps -
                              reference_chosen_logps) * self.beta
            rejected_rewards = (policy_rejected_logps -
                                reference_rejected_logps) * self.beta
            loss = (-F.logsigmoid(chosen_rewards) -
                    0.5 * F.logsigmoid(-chosen_rewards) -
                    0.5 * F.logsigmoid(-rejected_rewards))
        else:
            raise ValueError(
                f'Unknown loss type: {self.loss_type}. Should be one of '
                "['sigmoid', 'hinge', 'ipo', 'kto_pair', "
                "'sppo_hard', 'nca_pair', 'robust']")

        # for logging
        chosen_rewards = self.beta * (
            policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (
            policy_rejected_logps - reference_rejected_logps)
        reward_acc = (chosen_rewards > rejected_rewards).float().mean()

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards.mean(),
            'rejected_rewards': rejected_rewards.mean(),
            'reward_acc': reward_acc,
            'reward_margin': (chosen_rewards - rejected_rewards).mean(),
        }
        return loss_dict
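

# Minimal numerical sketch (illustration only, not part of the training API):
# how the 'sigmoid' branch of ``DPO.compute_loss`` maps per-sample sequence
# log-probs to a loss when ``label_smoothing == 0``. The tensors below are
# made-up values standing in for the outputs of ``get_logps``.
if __name__ == '__main__':
    beta = 0.1
    policy_chosen_logps = torch.tensor([-10.0, -12.0])
    policy_rejected_logps = torch.tensor([-11.0, -11.5])
    reference_chosen_logps = torch.tensor([-10.5, -12.5])
    reference_rejected_logps = torch.tensor([-10.8, -11.2])

    # margin of log-ratios, as in compute_loss above
    margin = ((policy_chosen_logps - policy_rejected_logps) -
              (reference_chosen_logps - reference_rejected_logps))
    loss = -F.logsigmoid(beta * margin)

    # implicit rewards used for logging
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps -
                               reference_rejected_logps)
    reward_acc = (chosen_rewards > rejected_rewards).float().mean()
    print(f'loss={loss.mean():.4f}, reward_acc={reward_acc:.2f}')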