import torch
import torch.nn as nn
from transformers import BertModel


class MultimodalClassifier(nn.Module):
    """Late-fusion classifier over BERT text features and precomputed image features."""

    def __init__(self, text_hidden_size=768, image_feat_size=2048, num_classes=5):
        super().__init__()
        # Pretrained BERT encoder for the text branch.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Project the pooled BERT output down to a 256-d text embedding.
        self.text_fc = nn.Sequential(
            nn.Linear(text_hidden_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Project the precomputed image feature vector to a 256-d image embedding.
        self.image_fc = nn.Sequential(
            nn.Linear(image_feat_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Fuse the concatenated 512-d text+image embedding and classify.
        self.fusion_fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def forward(self, input_ids, attention_mask, image_vector):
        text_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Linear projection of BERT's pooled [CLS] output.
        text_feat = self.text_fc[0](text_output.pooler_output)
        # BatchNorm1d cannot normalize a single-sample batch in training mode,
        # so skip the BatchNorm layer (index 1) when the batch size is 1.
        if text_feat.size(0) > 1:
            text_feat = self.text_fc[1:](text_feat)
        else:
            text_feat = self.text_fc[2:](text_feat)
        # Same single-sample guard for the image branch.
        image_feat = self.image_fc[0](image_vector)
        if image_feat.size(0) > 1:
            image_feat = self.image_fc[1:](image_feat)
        else:
            image_feat = self.image_fc[2:](image_feat)
        # Late fusion: concatenate the two 256-d embeddings along the feature dim.
        fused = torch.cat((text_feat, image_feat), dim=1)
        logits = self.fusion_fc(fused)
        return logits
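
# Minimal usage sketch, assuming a `transformers` tokenizer and a precomputed
# 2048-d image feature (e.g. from a ResNet-50 backbone); the random
# `image_vector` below is an illustrative placeholder, not part of the model.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = MultimodalClassifier(num_classes=5)
    model.eval()  # disables dropout for deterministic inference

    encoded = tokenizer(
        ["a photo of a dog playing fetch"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    image_vector = torch.randn(1, 2048)  # stand-in for real CNN features

    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"], image_vector)
    print(logits.shape)  # torch.Size([1, 5])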