Enable mambav2 compat
Browse files- modeling_caduceus.py +29 -16
modeling_caduceus.py
CHANGED
|
@@ -2,21 +2,29 @@
|
|
| 2 |
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
import math
|
| 6 |
from functools import partial
|
| 7 |
from typing import Optional, Tuple, Union
|
| 8 |
|
| 9 |
import torch
|
| 10 |
-
from mamba_ssm.modules.mamba_simple import Mamba
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from torch import nn
|
| 12 |
from torch.nn import functional as F
|
| 13 |
from transformers import PreTrainedModel
|
| 14 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
|
| 15 |
|
| 16 |
try:
|
| 17 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 18 |
except ImportError:
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from .configuration_caduceus import CaduceusConfig
|
| 22 |
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
|
|
@@ -54,13 +62,24 @@ def create_block(
|
|
| 54 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 55 |
)
|
| 56 |
block_cls = RCPSMambaBlock if rcps else Block
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
block.layer_idx = layer_idx
|
| 65 |
return block
|
| 66 |
|
|
@@ -497,12 +516,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
|
|
| 497 |
|
| 498 |
# Initialize weights and apply final processing
|
| 499 |
self.post_init()
|
| 500 |
-
self.init_scorer()
|
| 501 |
-
|
| 502 |
-
def init_scorer(self, initializer_range=0.02):
|
| 503 |
-
initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
|
| 504 |
-
if self.config.initializer_cfg is not None else initializer_range
|
| 505 |
-
self.score.weight.data.normal_(std=initializer_range)
|
| 506 |
|
| 507 |
def get_input_embeddings(self):
|
| 508 |
return self.caduceus.backbone.embeddings.word_embeddings
|
|
|
|
| 2 |
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import inspect
|
| 6 |
import math
|
| 7 |
from functools import partial
|
| 8 |
from typing import Optional, Tuple, Union
|
| 9 |
|
| 10 |
import torch
|
| 11 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
| 12 |
+
try:
|
| 13 |
+
from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
|
| 14 |
+
except ImportError:
|
| 15 |
+
from mamba_ssm.modules.block import Block # mambav2 file structure
|
| 16 |
from torch import nn
|
| 17 |
from torch.nn import functional as F
|
| 18 |
from transformers import PreTrainedModel
|
| 19 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
|
| 20 |
|
| 21 |
try:
|
| 22 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
|
| 23 |
except ImportError:
|
| 24 |
+
try:
|
| 25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
|
| 26 |
+
except ImportError:
|
| 27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 28 |
|
| 29 |
from .configuration_caduceus import CaduceusConfig
|
| 30 |
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
|
|
|
|
| 62 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 63 |
)
|
| 64 |
block_cls = RCPSMambaBlock if rcps else Block
|
| 65 |
+
# mambav2 compatibility
|
| 66 |
+
if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
|
| 67 |
+
block = block_cls(
|
| 68 |
+
d_model,
|
| 69 |
+
mixer_cls,
|
| 70 |
+
mlp_cls=nn.Identity,
|
| 71 |
+
norm_cls=norm_cls,
|
| 72 |
+
fused_add_norm=fused_add_norm,
|
| 73 |
+
residual_in_fp32=residual_in_fp32,
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
block = block_cls(
|
| 77 |
+
d_model,
|
| 78 |
+
mixer_cls,
|
| 79 |
+
norm_cls=norm_cls,
|
| 80 |
+
fused_add_norm=fused_add_norm,
|
| 81 |
+
residual_in_fp32=residual_in_fp32,
|
| 82 |
+
)
|
| 83 |
block.layer_idx = layer_idx
|
| 84 |
return block
|
| 85 |
|
|
|
|
| 516 |
|
| 517 |
# Initialize weights and apply final processing
|
| 518 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
def get_input_embeddings(self):
|
| 521 |
return self.caduceus.backbone.embeddings.word_embeddings
|