✨ gradient checkpointing
Browse files
- modeling_mpt.py +4 -1
modeling_mpt.py
CHANGED
|
@@ -33,7 +33,10 @@ class MPTPreTrainedModel(PreTrainedModel):
|
|
| 33 |
base_model_prefix = "model"
|
| 34 |
supports_gradient_checkpointing = True
|
| 35 |
_no_split_modules = ["MPTBlock"]
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
class MPTModel(MPTPreTrainedModel):
|
| 39 |
def __init__(self, config: MPTConfig):
|
|
|
|
| 33 |
base_model_prefix = "model"
|
| 34 |
supports_gradient_checkpointing = True
|
| 35 |
_no_split_modules = ["MPTBlock"]
|
| 36 |
+
|
| 37 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 38 |
+
if isinstance(module, MPTModel):
|
| 39 |
+
module.gradient_checkpointing = value
|
| 40 |
|
| 41 |
class MPTModel(MPTPreTrainedModel):
|
| 42 |
def __init__(self, config: MPTConfig):
|