duzx16 committed · aea6cef (parent: 0564795)

Implement gradient checkpointing

modeling_chatglm.py  (+41, -20)
@@ -244,7 +244,7 @@ def attention_fn(
         use_cache=False,
 ):
     if layer_past is not None:
-        past_key, past_value = layer_past
+        past_key, past_value = layer_past[0], layer_past[1]
         key_layer = torch.cat((past_key, key_layer), dim=0)
         value_layer = torch.cat((past_value, value_layer), dim=0)

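The commit does not state why the unpacking changes, but indexing with [0] and [1] keeps attention_fn agnostic to how a cache entry is represented: it works for a plain (key, value) tuple as well as for a single stacked tensor, which is the form the get_prompt hunk further down now leaves in place. A small sketch, illustration only (the shapes are made up):

import torch

# Illustration: layer_past may arrive either as a (key, value) tuple from the
# incremental decoding cache, or as one stacked tensor of shape (2, seq, b, nh, head_dim).
seq, b, nh, head_dim = 4, 1, 2, 8
key = torch.randn(seq, b, nh, head_dim)
value = torch.randn(seq, b, nh, head_dim)

for layer_past in ((key, value), torch.stack((key, value), dim=0)):
    past_key, past_value = layer_past[0], layer_past[1]   # works for both representations
    assert past_key.shape == (seq, b, nh, head_dim)
    assert past_value.shape == (seq, b, nh, head_dim)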
@@ -644,7 +644,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     """

     is_parallelizable = False
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLM6BBlock"]
@@ -656,6 +656,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ChatGLMModel):
+            module.gradient_checkpointing = value
+

 CHATGLM_6B_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
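Together with the supports_gradient_checkpointing flag above, this hook is what the standard transformers entry point drives. A usage sketch, assuming this file is loaded as the THUDM/chatglm-6b remote-code model (the model id, dtype and device handling are illustrative, not part of the commit):

from transformers import AutoModel

# Assumed model id; ChatGLM-6B ships its modeling code in the repo, hence trust_remote_code=True.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# In the transformers releases contemporary with this commit, gradient_checkpointing_enable()
# walks the module tree and calls _set_gradient_checkpointing(module, value=True) on each
# submodule, which flips the ChatGLMModel.gradient_checkpointing flag initialised below.
model.gradient_checkpointing_enable()

# Caching is pointless (and now warned about) while checkpointing during training.
model.config.use_cache = False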
@@ -760,6 +764,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
         )
+        self.gradient_checkpointing = False

         def get_layer(layer_id):
             return GLMBlock(
@@ -812,9 +817,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         #seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
         past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-        past_key_values = [(v[0], v[1]) for v in past_key_values]
-        # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers)
-        # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
+        # past_key_values = [(v[0], v[1]) for v in past_key_values]
         return past_key_values

     def get_masks(self, input_ids, device):
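Dropping the list comprehension means each element of past_key_values stays a stacked tensor of shape (2, seq, b, nh, head_dim) rather than being unpacked into a (key, value) tuple, which is what the attention_fn change above relies on. A shape walk-through with assumed sizes (the dimensions below are illustrative, not read from the config):

import torch

# Assumed sizes for illustration only.
batch, pre_seq_len, num_layers, nh, head_dim = 1, 4, 28, 32, 128

# Prefix tensor assumed to enter this block as (batch, pre_seq_len, num_layers * 2, nh, head_dim),
# which is consistent with the permute indices and the "#seq_len, b, nh, hidden_size" comment.
past_key_values = torch.randn(batch, pre_seq_len, num_layers * 2, nh, head_dim)

# permute -> (num_layers * 2, pre_seq_len, batch, nh, head_dim);
# split(2) along dim 0 -> one stacked (key, value) tensor per layer.
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)

assert len(past_key_values) == num_layers
assert past_key_values[0].shape == (2, pre_seq_len, batch, nh, head_dim)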
@@ -877,6 +880,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
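The warning reflects how checkpointing works: the block's forward runs once in the forward pass and is re-executed during backward to rebuild the discarded activations, so anything the forward produces only for caching is both unreliable and wasted. A toy demonstration (the block function is a stand-in, and use_reentrant=True mirrors the PyTorch default of that era):

import torch
from torch.utils.checkpoint import checkpoint

calls = []

def block(x):
    calls.append(1)        # stand-in for "append this layer's key/value to the cache"
    return x * 2

x = torch.ones(2, requires_grad=True)
y = checkpoint(block, x, use_reentrant=True)   # forward pass: activations are not stored
y.sum().backward()                             # backward pass: block is re-executed to rebuild them
print(len(calls))                              # 2 -> cache writes inside the block would run twice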
@@ -926,31 +936,42 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[0]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()

         else:
             attention_mask = attention_mask.to(input_ids.device)

+        if self.training:
+            hidden_states = hidden_states.requires_grad_(True)
+
         for i, layer in enumerate(self.layers):

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-
-            layer_ret = layer(
-                hidden_states,
-                position_ids=position_ids,
-                attention_mask=attention_mask,
-                layer_id=torch.tensor(i),
-                layer_past=past_key_values[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions
-            )
+            layer_past = past_key_values[i]
+
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    position_ids,
+                    attention_mask,
+                    torch.tensor(i),
+                    layer_past,
+                    use_cache,
+                    output_attentions
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    position_ids=position_ids,
+                    attention_mask=attention_mask,
+                    layer_id=torch.tensor(i),
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions
+                )

             hidden_states = layer_ret[0]

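Two details of this hunk are easy to miss. First, checkpoint forwards positional arguments to the wrapped callable, which is why the keyword arguments of the regular call appear as positional ones in the checkpointed branch. Second, hidden_states.requires_grad_(True) guards against the classic failure mode of the reentrant checkpoint implementation: if none of the tensors passed to checkpoint requires grad (for example when the embedding that produced hidden_states is frozen), no gradient flows back through the checkpointed block at all. A toy reproduction with stand-in Linear modules (assumptions, not the real GLMBlock):

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(4, 4)   # stand-in for a checkpointed GLMBlock
head = torch.nn.Linear(4, 2)    # stand-in for the non-checkpointed output head

x = torch.randn(3, 4)           # hidden states from a frozen embedding: requires_grad is False

h = checkpoint(block, x, use_reentrant=True)   # warns that no input requires grad
head(h).sum().backward()
print(block.weight.grad)        # None: nothing flowed back through the checkpointed segment

x = x.requires_grad_(True)      # what the commit does to hidden_states before the layer loop
h = checkpoint(block, x, use_reentrant=True)
head(h).sum().backward()
print(block.weight.grad is not None)   # True: the recomputed forward now participates in autograd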