Update modeling_nemotron_h.py
Browse files- modeling_nemotron_h.py +1 -0
modeling_nemotron_h.py
CHANGED
|
@@ -1112,6 +1112,7 @@ class NemotronHPreTrainedModel(PreTrainedModel):
|
|
| 1112 |
_no_split_modules = ["NemotronHBlock"]
|
| 1113 |
supports_gradient_checkpointing = True
|
| 1114 |
_is_stateful = True
|
|
|
|
| 1115 |
|
| 1116 |
def _init_weights(self, module):
|
| 1117 |
"""Initialize the weights."""
|
|
|
|
| 1112 |
_no_split_modules = ["NemotronHBlock"]
|
| 1113 |
supports_gradient_checkpointing = True
|
| 1114 |
_is_stateful = True
|
| 1115 |
+
_supports_flash_attn_2 = True
|
| 1116 |
|
| 1117 |
def _init_weights(self, module):
|
| 1118 |
"""Initialize the weights."""
|