Ensure weights are tied for BiMamba (if applicable) when loading via `from_pretrained`
modeling_caduceus.py (+31 -1)
@@ -360,6 +360,24 @@ class Caduceus(CaduceusPreTrainedModel):
         factory_kwargs = {"device": device, "dtype": dtype}
         self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
 
+    def maybe_weight_tie_mamba(self):
+        if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
+            if getattr(self.config, 'rcps', False):
+                for layer in self.backbone.layers:
+                    layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
+                    layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
+                    layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
+                    layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
+            else:
+                for layer in self.backbone.layers:
+                    layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
+                    layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
+                    layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
+                    layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
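The tying above works by parameter aliasing: assigning one module's `nn.Parameter` to another makes both modules hold the same tensor object, so a single optimizer step updates every view of it. A minimal sketch of the pattern, with plain `Linear` layers standing in for the `in_proj`/`out_proj` of the forward and reverse Mamba blocks (sizes and names are illustrative, not Caduceus internals):

import torch
import torch.nn as nn

fwd = nn.Linear(16, 32)  # stands in for mamba_fwd.in_proj
rev = nn.Linear(16, 32)  # stands in for mamba_rev.in_proj

# Assigning the nn.Parameter itself aliases the tensor; no copy is made.
rev.weight = fwd.weight
rev.bias = fwd.bias

assert rev.weight is fwd.weight             # same object, not an equal copy
with torch.no_grad():
    fwd.weight.add_(1.0)
assert torch.equal(rev.weight, fwd.weight)  # updates are visible through both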
@@ -431,8 +449,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
             raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
         self.lm_head = new_embeddings
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
     def tie_weights(self):
         """Tie weights, accounting for RCPS."""
+        self.maybe_weight_tie_mamba()
         if self.config.rcps:
             self.lm_head.set_weight(self.get_input_embeddings().weight)
         else:
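Routing the tie through tie_weights, rather than only through module construction, is what makes it hold after loading: PreTrainedModel.from_pretrained calls tie_weights() once the state dict has been loaded, so the reverse projections are re-aliased onto the forward ones even if loading replaced the parameters. A hedged usage sketch; the model id is illustrative, and trust_remote_code=True is needed because modeling_caduceus.py ships with the checkpoint repo:

from transformers import AutoModelForMaskedLM

# from_pretrained runs roughly: __init__ -> load state dict -> tie_weights(),
# so the override above restores the BiMamba tie after loading.
model = AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16",  # illustrative id
    trust_remote_code=True,
)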
@@ -445,7 +467,7 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
     def set_decoder(self, decoder):
         """Set decoder (backbone) for the model."""
         self.caduceus = decoder
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -536,6 +558,13 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
             return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+        super().tie_weights()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
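A quick way to confirm the tie survived from_pretrained is an identity check, since tying shares the tensor object rather than copying values. A sketch against the non-RCPS attribute path used above; under rcps=True the projections live one level deeper, under mixer.submodule:

layer = model.caduceus.backbone.layers[0]  # head models expose the backbone as `.caduceus`
assert layer.mixer.mamba_rev.in_proj.weight is layer.mixer.mamba_fwd.in_proj.weight
assert layer.mixer.mamba_rev.out_proj.weight is layer.mixer.mamba_fwd.out_proj.weight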
@@ -543,6 +572,7 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
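The added **kwargs lets the classification forward absorb keyword arguments it does not consume instead of raising TypeError, which helps when a caller such as transformers.Trainer forwards extra dataset columns. An illustrative call; the extra column name is hypothetical:

outputs = model(
    input_ids=batch["input_ids"],
    labels=batch["labels"],
    special_tokens_mask=batch.get("special_tokens_mask"),  # hypothetical extra key, ignored via **kwargs
)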