Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| class BertAD(nn.Module): | |
| def __init__(self): | |
| super(BertAD, self).__init__() | |
| model_config = transformers.AutoConfig.from_pretrained('./model') | |
| model_config.update({"output_hidden_states":True}) | |
| self.bert = transformers.BertModel(model_config) | |
| self.layer = nn.Linear(768, 2) | |
| def forward(self, ids, mask, token_type): | |
| output = self.bert(input_ids = ids, | |
| attention_mask = mask, | |
| token_type_ids = token_type) | |
| logits = self.layer(output[0]) | |
| start_logits, end_logits = logits.split(1, dim=-1) | |
| start_logits = start_logits.squeeze(-1) | |
| end_logits = end_logits.squeeze(-1) | |
| return start_logits, end_logits |