flash_attention_utils_backward_compat

#2
by itlevy - opened
NOTICE DELETED
@@ -1,5 +0,0 @@
1
- Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
-
3
- NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.
4
-
5
- Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
 
 
 
 
 
 
README.md CHANGED
@@ -8,9 +8,9 @@ tags:
8
  - llama-3
9
  - pytorch
10
  license: other
11
- license_name: nvidia-open-model-license
12
  license_link: >-
13
- https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf
14
  ---
15
 
16
  # Llama-3_1-Nemotron-51B-instruct
@@ -22,8 +22,7 @@ Llama-3_1-Nemotron-51B-instruct is a model which offers a great tradeoff between
22
 
23
 
24
  ## License
25
- Your use of this model is governed by the [NVIDIA Open Model License](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).
26
- Additional Information: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Llama.
27
 
28
  ## How was the model developed
29
 
@@ -33,7 +32,6 @@ The KD step included 40 billion tokens consisting of a mixture of 3 datasets - F
33
  Links to [NIM](https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct), [blog](https://developer.nvidia.com/blog/advancing-the-accuracy-efficiency-frontier-with-llama-3-1-nemotron-51b/) and [huggingface](https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct)
34
 
35
 
36
-
37
  This results in a final model that is aligned for human chat preferences.
38
 
39
  **Model Developers:** NVIDIA
 
8
  - llama-3
9
  - pytorch
10
  license: other
11
+ license_name: nvidia-ai-foundation-models-community-license
12
  license_link: >-
13
+ https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/
14
  ---
15
 
16
  # Llama-3_1-Nemotron-51B-instruct
 
22
 
23
 
24
  ## License
25
+ [NVIDIA AI Foundation Models Community License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/). Additional Information: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Llama.
 
26
 
27
  ## How was the model developed
28
 
 
32
  Links to [NIM](https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct), [blog](https://developer.nvidia.com/blog/advancing-the-accuracy-efficiency-frontier-with-llama-3-1-nemotron-51b/) and [huggingface](https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct)
33
 
34
 
 
35
  This results in a final model that is aligned for human chat preferences.
36
 
37
  **Model Developers:** NVIDIA
modeling_decilm.py CHANGED
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers import GenerationConfig
28
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
31
  add_start_docstrings,
@@ -385,6 +385,7 @@ class DeciLMAttention(nn.Module):
385
  **kwargs,
386
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
387
  bsz, q_len, _ = hidden_states.size()
 
388
  if self.config.pretraining_tp > 1:
389
  key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
390
  query_slices = self.q_proj.weight.split(
@@ -496,6 +497,7 @@ class DeciLMFlashAttention2(DeciLMAttention):
496
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
497
  "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
498
  )
 
499
  output_attentions = False
500
 
501
  bsz, q_len, _ = hidden_states.size()
@@ -833,13 +835,10 @@ class DeciLMPreTrainedModel(PreTrainedModel):
833
  module.weight.data[module.padding_idx].zero_()
834
 
835
  def _prepare_generation_config(
836
- self,
837
- generation_config: Optional[GenerationConfig],
838
- *args,
839
- **kwargs,
840
  ) -> tuple[GenerationConfig, dict]:
841
  # DeciLM-specific code
842
- generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
843
  generation_config.cache_implementation = "variable"
844
  NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
845
  return generation_config, model_kwargs
@@ -1134,7 +1133,7 @@ class DeciLMModel(DeciLMPreTrainedModel):
1134
  return causal_mask
1135
 
1136
 
1137
- class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1138
  _tied_weights_keys = ["lm_head.weight"]
1139
 
1140
  def __init__(self, config):
@@ -1314,50 +1313,6 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1314
  )
1315
  return model_inputs
1316
 
1317
- def _maybe_initialize_input_ids_for_generation(
1318
- self,
1319
- inputs: Optional[torch.Tensor] = None,
1320
- bos_token_id: Optional[torch.Tensor] = None,
1321
- model_kwargs: Optional[dict[str, torch.Tensor]] = None,
1322
- ) -> torch.LongTensor:
1323
- """
1324
- Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
1325
- """
1326
- input_ids = super()._maybe_initialize_input_ids_for_generation(
1327
- inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
1328
- if (
1329
- "inputs_embeds" in model_kwargs
1330
- and input_ids is not None
1331
- and input_ids.shape[1] == 0
1332
- ):
1333
- batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
1334
- input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
1335
- return input_ids
1336
-
1337
- def generate(
1338
- self,
1339
- inputs: Optional[torch.Tensor] = None,
1340
- *args,
1341
- **kwargs,
1342
- ) -> Union[GenerateOutput, torch.LongTensor]:
1343
- """
1344
- Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
1345
- """
1346
- only_passed_inputs_embeds = (
1347
- "inputs_embeds" in kwargs and
1348
- "input_ids" not in kwargs and
1349
- inputs is None
1350
- )
1351
- if only_passed_inputs_embeds:
1352
- input_sequence_length = kwargs["inputs_embeds"].shape[1]
1353
-
1354
- generation_output = super().generate(inputs=inputs, *args, **kwargs)
1355
-
1356
- if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
1357
- generation_output = generation_output[:, input_sequence_length:]
1358
-
1359
- return generation_output
1360
-
1361
 
1362
  @add_start_docstrings(
1363
  """
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers import GenerationConfig
28
+ from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
31
  add_start_docstrings,
 
385
  **kwargs,
386
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
387
  bsz, q_len, _ = hidden_states.size()
388
+
389
  if self.config.pretraining_tp > 1:
390
  key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
391
  query_slices = self.q_proj.weight.split(
 
497
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
498
  "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
499
  )
500
+
501
  output_attentions = False
502
 
503
  bsz, q_len, _ = hidden_states.size()
 
835
  module.weight.data[module.padding_idx].zero_()
836
 
837
  def _prepare_generation_config(
838
+ self, generation_config: Optional[GenerationConfig], **kwargs: dict
 
 
 
839
  ) -> tuple[GenerationConfig, dict]:
840
  # DeciLM-specific code
841
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
842
  generation_config.cache_implementation = "variable"
843
  NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
844
  return generation_config, model_kwargs
 
1133
  return causal_mask
1134
 
1135
 
1136
+ class DeciLMForCausalLM(DeciLMPreTrainedModel):
1137
  _tied_weights_keys = ["lm_head.weight"]
1138
 
1139
  def __init__(self, config):
 
1313
  )
1314
  return model_inputs
1315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1316
 
1317
  @add_start_docstrings(
1318
  """
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py CHANGED
@@ -15,18 +15,12 @@
15
 
16
  import inspect
17
  import os
18
- from typing import Optional, Tuple, Union
19
-
20
 
21
  import torch
22
  import torch.nn.functional as F
23
 
24
- from functools import lru_cache
25
- import importlib.metadata
26
- import importlib.util
27
- from packaging import version
28
-
29
- from transformers.utils import is_flash_attn_2_available
30
 
31
 
32
  if is_flash_attn_2_available():
@@ -38,46 +32,6 @@ if is_flash_attn_2_available():
38
  raise "Unable to import flash_attn"
39
 
40
 
41
- def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
42
- # Check if the package spec exists and grab its version to avoid importing a local directory
43
- package_exists = importlib.util.find_spec(pkg_name) is not None
44
- package_version = "N/A"
45
- if package_exists:
46
- try:
47
- # Primary method to get the package version
48
- package_version = importlib.metadata.version(pkg_name)
49
- except importlib.metadata.PackageNotFoundError:
50
- # Fallback method: Only for "torch" and versions containing "dev"
51
- if pkg_name == "torch":
52
- try:
53
- package = importlib.import_module(pkg_name)
54
- temp_version = getattr(package, "__version__", "N/A")
55
- # Check if the version contains "dev"
56
- if "dev" in temp_version:
57
- package_version = temp_version
58
- package_exists = True
59
- else:
60
- package_exists = False
61
- except ImportError:
62
- # If the package can't be imported, it's not available
63
- package_exists = False
64
- else:
65
- # For packages other than "torch", don't attempt the fallback and set as not available
66
- package_exists = False
67
- if return_version:
68
- return package_exists, package_version
69
- else:
70
- return package_exists
71
-
72
-
73
- @lru_cache()
74
- def is_flash_attn_greater_or_equal(library_version: str):
75
- if not _is_package_available("flash_attn"):
76
- return False
77
-
78
- return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
79
-
80
-
81
  def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
82
  """
83
  Retrieves indexing data required to repad unpadded (ragged) tensors.
 
15
 
16
  import inspect
17
  import os
18
+ from typing import Optional, Tuple
 
19
 
20
  import torch
21
  import torch.nn.functional as F
22
 
23
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
 
 
 
 
 
24
 
25
 
26
  if is_flash_attn_2_available():
 
32
  raise "Unable to import flash_attn"
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
36
  """
37
  Retrieves indexing data required to repad unpadded (ragged) tensors.
variable_cache.py CHANGED
@@ -32,21 +32,17 @@ class VariableCache(Cache_4_44_2, Cache):
32
  The cache of each layer is allocated to the same gpu as the layer itself.
33
  """
34
 
35
- def __init__(
36
- self,
37
- *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
38
- config: DeciLMConfig,
39
- batch_size: int = None,
40
- max_cache_len: int = None,
41
- dtype: torch.dtype = torch.float32,
42
- max_batch_size: Optional[int] = None,
43
- **kwargs: Any,
44
- ) -> None:
45
  Cache_4_44_2.__init__(self)
46
 
47
- self.config = deepcopy(config)
48
- self.max_batch_size = batch_size or max_batch_size
49
- self.batch_size = self.max_batch_size
50
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
51
  self.dtype = dtype
52
 
@@ -83,7 +79,6 @@ class VariableCache(Cache_4_44_2, Cache):
83
  if attention_config.no_op or attention_config.replace_with_linear:
84
  return None
85
  config = deepcopy(self.config)
86
- config.num_hidden_layers = 1
87
  config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
88
  return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)
89
 
 
32
  The cache of each layer is allocated to the same gpu as the layer itself.
33
  """
34
 
35
+ def __init__(self,
36
+ config: DeciLMConfig,
37
+ max_batch_size: int,
38
+ max_cache_len: int | None,
39
+ device: torch.device | str | None = None,
40
+ dtype: torch.dtype | None = None,
41
+ ):
 
 
 
42
  Cache_4_44_2.__init__(self)
43
 
44
+ self.config = config
45
+ self.max_batch_size = max_batch_size
 
46
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
47
  self.dtype = dtype
48
 
 
79
  if attention_config.no_op or attention_config.replace_with_linear:
80
  return None
81
  config = deepcopy(self.config)
 
82
  config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
83
  return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)
84