# emro-misty / misty_emro.py
import torch
from transformers import RobertaModel


class RobertaClass(torch.nn.Module):
    """RoBERTa encoder with a mean-pooled classification head."""

    def __init__(self, num_classes):
        super(RobertaClass, self).__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(self.roberta.config.hidden_size, num_classes)
        # Label smoothing softens one-hot targets to regularize the classifier.
        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Mean-pool the token embeddings over the sequence dimension
        # (note: this averages over padding positions as well).
        pooled_output = torch.mean(outputs.last_hidden_state, dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return loss, logits
        else:
            return logits
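

# A minimal usage sketch, not part of the original file: the tokenizer
# pairing, num_classes value, and example sentences below are assumptions
# for illustration, not from the upstream repo.
if __name__ == "__main__":
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # assumed to match the encoder
    model = RobertaClass(num_classes=3)  # num_classes is hypothetical
    model.eval()

    # Tokenize a small batch; padding/truncation produce rectangular tensors.
    batch = tokenizer(
        ["an example sentence", "another example"],
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Without labels, forward() returns only the logits.
    with torch.no_grad():
        logits = model(batch['input_ids'], batch['attention_mask'])
    print(logits.shape)  # torch.Size([2, 3]) -> (batch_size, num_classes)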