from typing import Optional, Tuple
import torch
from torch.nn import Module, Embedding, LayerNorm, Dropout, ModuleList, Linear, functional, Parameter
from transformers import DebertaV2PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
from transformers.models.deberta_v2.modeling_deberta_v2 import build_relative_position, scaled_size_sqrt, build_rpos
from .configuration import QiDeBERTaConfig
class QiDeBERTaEmbeddings(Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(
self,
pad_token_id: int,
d_model: int,
vocab_size: int,
layer_norm_eps: float,
hidden_dropout_prob: float,
):
super().__init__()
self.word_embeddings = Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=pad_token_id)
self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
self.dropout = Dropout(p=hidden_dropout_prob)
def forward(self, input_ids: torch.Tensor, mask: torch.Tensor):
inputs_embeds = self.word_embeddings(input_ids)
embeddings = self.LayerNorm(inputs_embeds)
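        # Align the attention mask with the embedding tensor and zero out embeddings at padding positions.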
if mask.dim() != embeddings.dim():
if mask.dim() == 4:
mask = mask.squeeze(1).squeeze(1)
mask = mask.unsqueeze(2)
mask = mask.to(embeddings.dtype)
return self.dropout(embeddings * mask), inputs_embeds
class QiDeBERTaDisentangledSelfAttention(Module):
"""
Disentangled self-attention module
"""
def __init__(
self,
num_heads: int,
d_model: int,
share_att_key: bool,
relative_attention: bool,
max_position_embeddings: int,
hidden_dropout_prob: float,
attention_probs_dropout_prob: float,
pos_att_type: Optional[list] = None,
position_buckets: int = -1,
max_relative_positions: int = -1,
):
super().__init__()
self.num_attention_heads = num_heads
self.attention_head_size = d_model // num_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
self.key_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
self.value_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
self.share_att_key = share_att_key
self.pos_att_type = pos_att_type if pos_att_type is not None else []
self.relative_attention = relative_attention
if self.relative_attention:
self.position_buckets = position_buckets
self.max_relative_positions = max_relative_positions
if self.max_relative_positions < 1:
self.max_relative_positions = max_position_embeddings
self.pos_ebd_size = self.max_relative_positions
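            # When position bucketing is enabled, relative distances are folded into a smaller set of buckets,
            # so the relative-position table only needs `position_buckets` entries per direction.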
if self.position_buckets > 0:
self.pos_ebd_size = self.position_buckets
self.pos_dropout = Dropout(p=hidden_dropout_prob)
if not self.share_att_key:
if "c2p" in self.pos_att_type:
self.pos_key_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
if "p2c" in self.pos_att_type:
                self.pos_query_proj = Linear(in_features=d_model, out_features=self.all_head_size, bias=True)
self.dropout = Dropout(p=attention_probs_dropout_prob)
@staticmethod
def transpose_for_scores(x, attention_heads) -> torch.Tensor:
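        """Reshape [batch, seq_len, all_head_size] into [batch * heads, seq_len, head_size] for batched matmul."""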
new_x_shape = x.size()[:-1] + (attention_heads, -1)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
self,
hidden_states,
attention_mask,
output_attentions=False,
relative_pos=None,
rel_embeddings=None,
):
"""
Call the module
Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module, usually the output of the previous layer; they serve as the Q, K and V
                in *Attention(Q, K, V)*.
            attention_mask (`torch.BoolTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size and *N* is the maximum
                sequence length; element [i, j] = *1* means the *i*-th token in the input can attend to the *j*-th
                token.
            output_attentions (`bool`, *optional*):
                Whether to return the attention matrix.
relative_pos (`torch.LongTensor`):
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
values ranging in [*-max_relative_positions*, *max_relative_positions*].
rel_embeddings (`torch.FloatTensor`):
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
\\text{max_relative_positions}\\), *hidden_size*].
"""
query_layer = self.transpose_for_scores(self.query_proj(hidden_states), self.num_attention_heads)
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1
if "c2p" in self.pos_att_type:
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
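        # scaled_size_sqrt returns sqrt(head_size * scale_factor), i.e. sqrt(3 * head_size) when both c2p and p2c are enabled.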
scale = scaled_size_sqrt(query_layer, scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
rel_att = self.disentangled_attention_bias(
query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
)
        if rel_att is not None:
            attention_scores = attention_scores + rel_att
        attention_scores = attention_scores.view(
            -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
        )
attention_mask = attention_mask.bool()
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
# bsz x height x length x dimension
attention_probs = functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context_layer = torch.bmm(
attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
)
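        # Merge the attention heads back: [batch * heads, seq_len, head_size] -> [batch, seq_len, heads * head_size]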
context_layer = (
context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
.permute(0, 2, 1, 3)
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(new_context_layer_shape)
return (context_layer, attention_probs) if output_attentions else (context_layer, None)
def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
if relative_pos is None:
relative_pos = build_relative_position(
query_layer,
key_layer,
bucket_size=self.position_buckets,
max_position=self.max_relative_positions,
)
if relative_pos.dim() == 2:
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
elif relative_pos.dim() == 3:
relative_pos = relative_pos.unsqueeze(1)
# bsz x height x query x key
elif relative_pos.dim() != 4:
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
att_span = self.pos_ebd_size
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
if self.share_att_key:
pos_query_layer = self.transpose_for_scores(
self.query_proj(rel_embeddings), self.num_attention_heads
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
query_layer.size(0) // self.num_attention_heads, 1, 1
)
else:
if "c2p" in self.pos_att_type:
pos_key_layer = self.transpose_for_scores(
self.pos_key_proj(rel_embeddings), self.num_attention_heads
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
if "p2c" in self.pos_att_type:
pos_query_layer = self.transpose_for_scores(
self.pos_query_proj(rel_embeddings), self.num_attention_heads
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
score = 0
# content->position
if "c2p" in self.pos_att_type:
scale = scaled_size_sqrt(pos_key_layer, scale_factor)
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
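            # Shift relative distances from [-att_span, att_span) into valid gather indices [0, 2 * att_span).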
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_att = torch.gather(
c2p_att,
dim=-1,
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
)
score += c2p_att / scale.to(dtype=c2p_att.dtype)
# position->content
if "p2c" in self.pos_att_type:
scale = scaled_size_sqrt(pos_query_layer, scale_factor)
r_pos = build_rpos(
query_layer,
key_layer,
relative_pos,
self.max_relative_positions,
self.position_buckets,
)
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
p2c_att = torch.gather(
p2c_att,
dim=-1,
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
).transpose(-1, -2)
score += p2c_att / scale.to(dtype=p2c_att.dtype)
return score
class QiDeBERTaSelfOutput(Module):
def __init__(
self,
d_model: int,
layer_norm_eps: float,
hidden_dropout_prob: float,
):
super().__init__()
self.dense = Linear(in_features=d_model, out_features=d_model)
self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
self.dropout = Dropout(p=hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class QiDeBERTaAttention(Module):
def __init__(
self,
num_heads: int,
d_model: int,
share_att_key: bool,
relative_attention: bool,
max_position_embeddings: int,
hidden_dropout_prob: float,
attention_probs_dropout_prob: float,
layer_norm_eps: float,
pos_att_type: Optional[list] = None,
position_buckets: int = -1,
max_relative_positions: int = -1,
):
super().__init__()
self.self = QiDeBERTaDisentangledSelfAttention(
num_heads=num_heads,
d_model=d_model,
share_att_key=share_att_key,
relative_attention=relative_attention,
max_position_embeddings=max_position_embeddings,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
pos_att_type=pos_att_type,
position_buckets=position_buckets,
max_relative_positions=max_relative_positions,
)
self.output = QiDeBERTaSelfOutput(
d_model=d_model,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
)
def forward(
self,
hidden_states,
attention_mask,
output_attentions: bool = False,
relative_pos=None,
rel_embeddings=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
self_output, att_matrix = self.self(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
attention_output = self.output(hidden_states=self_output, input_tensor=hidden_states)
return (attention_output, att_matrix) if output_attentions else (attention_output, None)
class QiDeBERTaIntermediate(Module):
def __init__(
self,
d_model: int,
d_ff: int,
):
super().__init__()
self.dense = Linear(in_features=d_model, out_features=d_ff)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = functional.gelu(hidden_states)
return hidden_states
class QiDeBERTaOutput(Module):
def __init__(
self,
d_ff: int,
d_model: int,
layer_norm_eps: float,
hidden_dropout_prob: float,
):
super().__init__()
self.dense = Linear(in_features=d_ff, out_features=d_model)
self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps)
self.dropout = Dropout(p=hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class QiDeBERTaLayer(Module):
def __init__(
self,
num_heads: int,
d_model: int,
d_ff: int,
share_att_key: bool,
relative_attention: bool,
max_position_embeddings: int,
hidden_dropout_prob: float,
attention_probs_dropout_prob: float,
layer_norm_eps: float,
pos_att_type: Optional[list] = None,
position_buckets: int = -1,
max_relative_positions: int = -1,
):
super().__init__()
self.attention = QiDeBERTaAttention(
num_heads=num_heads,
d_model=d_model,
share_att_key=share_att_key,
relative_attention=relative_attention,
max_position_embeddings=max_position_embeddings,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
layer_norm_eps=layer_norm_eps,
pos_att_type=pos_att_type,
position_buckets=position_buckets,
max_relative_positions=max_relative_positions,
)
self.intermediate = QiDeBERTaIntermediate(
d_model=d_model,
d_ff=d_ff
)
self.output = QiDeBERTaOutput(
d_ff=d_ff,
d_model=d_model,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
)
def forward(
self,
hidden_states,
attention_mask,
relative_pos=None,
rel_embeddings=None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
attention_output, att_matrix = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return (layer_output, att_matrix) if output_attentions else (layer_output, None)
class QiDeBERTaEncoder(Module):
"""Modified BertEncoder with relative position bias support"""
def __init__(
self,
num_layers: int,
num_heads: int,
d_model: int,
d_ff: int,
share_att_key: bool,
relative_attention: bool,
max_position_embeddings: int,
hidden_dropout_prob: float,
attention_probs_dropout_prob: float,
layer_norm_eps: float,
pos_att_type: Optional[list] = None,
position_buckets: int = -1,
max_relative_positions: int = -1,
):
super().__init__()
self.layer = ModuleList([
QiDeBERTaLayer(
num_heads=num_heads,
d_model=d_model,
d_ff=d_ff,
share_att_key=share_att_key,
relative_attention=relative_attention,
max_position_embeddings=max_position_embeddings,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
layer_norm_eps=layer_norm_eps,
pos_att_type=pos_att_type,
position_buckets=position_buckets,
max_relative_positions=max_relative_positions,
)
for _ in range(num_layers)
])
self.max_relative_positions = max_position_embeddings
self.position_buckets = position_buckets
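        # Assumes bucketed relative positions (position_buckets > 0); the table has 2 * position_buckets entries covering both directions.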
pos_ebd_size = position_buckets * 2
self.rel_embeddings = Embedding(num_embeddings=pos_ebd_size, embedding_dim=d_model)
self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps, elementwise_affine=True)
self.gradient_checkpointing = False
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings.weight
if rel_embeddings is not None:
rel_embeddings = self.LayerNorm(rel_embeddings)
return rel_embeddings
@staticmethod
def get_attention_mask(attention_mask):
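        """
        Expand a [batch, seq_len] padding mask into a [batch, 1, seq_len, seq_len] self-attention mask:
        entry (i, j) is non-zero only when both token i and token j are real (non-padding) tokens.
        """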
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
return attention_mask
def get_rel_pos(self, hidden_states):
relative_pos = build_relative_position(
hidden_states,
hidden_states,
bucket_size=self.position_buckets,
max_position=self.max_relative_positions,
)
return relative_pos
def forward(
self,
hidden_states,
attention_mask,
output_attentions: bool = True,
):
attention_mask = self.get_attention_mask(attention_mask)
relative_pos = self.get_rel_pos(hidden_states)
        all_hidden_states: Tuple[torch.Tensor, ...] = (hidden_states,)
all_attentions = ()
next_kv = hidden_states
rel_embeddings = self.get_rel_embedding()
for i, layer_module in enumerate(self.layer):
if self.gradient_checkpointing and self.training:
output_states, attn_weights = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
relative_pos,
rel_embeddings,
output_attentions,
)
else:
output_states, attn_weights = layer_module(
hidden_states=next_kv,
attention_mask=attention_mask,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attn_weights,)
all_hidden_states = all_hidden_states + (output_states,)
next_kv = output_states
return BaseModelOutput(
last_hidden_state=output_states,
hidden_states=all_hidden_states,
attentions=all_attentions if output_attentions else None
)
class QiDeBERTaBase(DebertaV2PreTrainedModel):
VERSION = '1.1.0'
config_class = QiDeBERTaConfig
base_model_prefix = 'qideberta'
_encoder_layer_path = ''
_embedding_layer_path = ''
def freeze_encoder_layers(self, freeze_layers: Optional[int] = None):
"""
Freeze the first `freeze_layers` layers of the encoder
:param freeze_layers:
:return:
"""
# 以点分割
encoder_layer = self
for attr in self._encoder_layer_path.split("."):
# 获取属性
encoder_layer = getattr(encoder_layer, attr)
if not isinstance(encoder_layer, QiDeBERTaEncoder):
raise ValueError(f"Encoder layer is not instance of QiDeBERTaEncoder")
if freeze_layers is not None and freeze_layers > 0:
if freeze_layers >= len(encoder_layer.layer):
                # Freeze all layers
                encoder_layer.requires_grad_(requires_grad=False)
            else:
                # Freeze the first `freeze_layers` layers
                for param in encoder_layer.layer[:freeze_layers].parameters():
                    param.requires_grad = False
                # Unfreeze the remaining layers
for param in encoder_layer.layer[freeze_layers:].parameters():
param.requires_grad = True
else:
encoder_layer.requires_grad_(requires_grad=True)
def freeze_encoder_embed_layer(self, freeze: bool = True):
"""
Freeze the embedding layer
:param freeze:
:return:
"""
embedding_layer = self
for attr in self._embedding_layer_path.split("."):
embedding_layer = getattr(embedding_layer, attr)
        if not isinstance(embedding_layer, QiDeBERTaEmbeddings):
            raise ValueError("Embedding layer is not an instance of QiDeBERTaEmbeddings")
        embedding_layer.requires_grad_(requires_grad=not freeze)
class QiDeBERTa(QiDeBERTaBase):
_encoder_layer_path = 'encoder'
_embedding_layer_path = 'embeddings'
def __init__(self, config: QiDeBERTaConfig):
super(QiDeBERTa, self).__init__(config=config)
self.embeddings = QiDeBERTaEmbeddings(
pad_token_id=config.pad_token_id,
d_model=config.d_model,
vocab_size=config.vocab_size,
layer_norm_eps=config.layer_norm_eps,
hidden_dropout_prob=config.hidden_dropout_prob,
)
self.encoder = QiDeBERTaEncoder(
num_layers=config.num_layers,
num_heads=config.num_heads,
max_position_embeddings=config.max_position_embeddings,
position_buckets=config.position_buckets,
d_model=config.d_model,
d_ff=config.d_ff,
layer_norm_eps=config.layer_norm_eps,
share_att_key=config.share_att_key,
relative_attention=config.relative_attention,
hidden_dropout_prob=config.hidden_dropout_prob,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
pos_att_type=config.pos_att_type,
max_relative_positions=config.max_relative_positions,
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> Embedding:
return self.embeddings.word_embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = True,
) -> BaseModelOutput:
"""
Forward pass of the model
:param input_ids: Token indices of input sequence tokens in the vocabulary. (batch_size, sequence_length)
:param attention_mask: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
(batch_size, sequence_length)
        :param output_attentions: Whether to return the attention matrices of every encoder layer.
        :return: A `BaseModelOutput` containing the last hidden state, all hidden states and (optionally) attentions.
"""
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
device = input_ids.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
embedding_output, token_embeddings = self.embeddings(
input_ids=input_ids,
mask=attention_mask,
)
encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
return BaseModelOutput(
last_hidden_state=encoder_outputs.last_hidden_state,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions if output_attentions else None,
)
class QiDeBERTaMLMHead(Module):
def __init__(
self,
d_model: int,
vocab_size: int,
layer_norm_eps: float,
):
super().__init__()
self.dense = Linear(in_features=d_model, out_features=d_model)
self.LayerNorm = LayerNorm(normalized_shape=d_model, eps=layer_norm_eps, elementwise_affine=True)
self.bias = Parameter(torch.zeros(vocab_size))
@staticmethod
def _init_weights(module):
"""Initialize the weights."""
if isinstance(module, Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
def _initialize_weights(self, module):
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor, word_embeddings: Embedding):
hidden_states = self.dense(hidden_states)
hidden_states = functional.gelu(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
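        # Project onto the vocabulary by reusing the (tied) input word-embedding matrix as the output projection.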
hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
return hidden_states
class QiDeBERTaClassificationHead(Module):
def __init__(
self,
d_model: int,
num_labels: int,
hidden_dropout_prob: float,
):
super().__init__()
self.dropout = Dropout(p=hidden_dropout_prob)
self.classifier = Linear(in_features=d_model, out_features=num_labels)
@staticmethod
def _init_weights(module):
"""Initialize the weights."""
if isinstance(module, Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
def _initialize_weights(self, module):
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor):
dropped = self.dropout(hidden_states)
logits = self.classifier(dropped)
return logits
class QiDeBERTaCSC(QiDeBERTaBase):
_tied_weights_keys = ["mlm_head.weight", "qideberta.embeddings.word_embeddings.weight"]
_encoder_layer_path = 'qideberta.encoder'
_embedding_layer_path = 'qideberta.embeddings'
task_head = ['mlm_head']
def __init__(self, config: QiDeBERTaConfig):
super().__init__(config)
self.qideberta = QiDeBERTa(config=config)
        # Masked language modeling task head
self.mlm_head = QiDeBERTaMLMHead(
d_model=config.d_model,
vocab_size=config.vocab_size,
layer_norm_eps=config.layer_norm_eps,
)
self.post_init()
def get_output_embeddings(self):
return self.qideberta.embeddings.word_embeddings
def set_output_embeddings(self, new_embeddings):
self.qideberta.embeddings.word_embeddings = new_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> MaskedLMOutput:
outputs = self.qideberta(
input_ids=input_ids,
attention_mask=attention_mask,
)
        # Compute the MLM logits over the vocabulary
mlm_logits = self.mlm_head(hidden_states=outputs.last_hidden_state, word_embeddings=self.get_output_embeddings())
return MaskedLMOutput(
logits=mlm_logits, # [B, L, V] where V is the vocabulary size
hidden_states=outputs.hidden_states, # [B, L, H]
attentions=outputs.attentions, # [B, H, L, L]
)
class QiDeBERTaCSCForMaskedLM(QiDeBERTaCSC):
    """Alias of `QiDeBERTaCSC`."""
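

# The smoke test below is an illustrative usage sketch only: the checkpoint path is a placeholder and the
# tokenizer choice is an assumption, not something defined in this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/QiDeBERTa-CSC"  # placeholder: replace with the actual model repo id or local directory
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = QiDeBERTaCSCForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    # Chinese spelling correction: feed a (possibly misspelled) sentence and read per-token predictions.
    inputs = tokenizer("今天天气怎么样", return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    predictions = outputs.logits.argmax(dim=-1)  # [batch, seq_len]
    print(tokenizer.convert_ids_to_tokens(predictions[0].tolist()))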