import torch
import torch.nn as nn
from transformers import BertModel


class MultimodalClassifier(nn.Module):
    """Late-fusion classifier combining BERT text features with precomputed image features.

    Text is encoded with a pretrained ``bert-base-uncased`` model (pooler
    output), images are expected as precomputed feature vectors (e.g. from a
    ResNet backbone, default 2048-dim).  Each modality is projected to a
    256-dim space, the projections are concatenated, and an MLP head produces
    class logits.

    Args:
        text_hidden_size: Hidden size of the BERT pooler output (768 for
            bert-base).
        image_feat_size: Dimensionality of the incoming image feature vector.
        num_classes: Number of output classes (logits).
    """

    def __init__(self, text_hidden_size=768, image_feat_size=2048, num_classes=5):
        super().__init__()
        # NOTE: downloads/loads pretrained weights at construction time.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Per-modality projection branches: Linear -> BatchNorm -> ReLU -> Dropout.
        # Layer order matters: _project indexes into these Sequentials.
        self.text_fc = nn.Sequential(
            nn.Linear(text_hidden_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.image_fc = nn.Sequential(
            nn.Linear(image_feat_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Fusion head over the concatenated 256+256 projections.
        self.fusion_fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def _project(self, branch, features):
        """Apply one projection branch, bypassing BatchNorm only when it would fail.

        ``BatchNorm1d`` cannot compute batch statistics from a single sample
        in training mode, so for a training batch of size 1 we apply the
        Linear layer (``branch[0]``) and then skip past BatchNorm straight to
        ReLU/Dropout (``branch[2:]``).  In eval mode BatchNorm uses running
        statistics and handles any batch size, so the full branch is applied.

        This fixes the original implementation, which skipped BatchNorm for
        batch size 1 even during inference, making a model's prediction for a
        sample depend on the size of the batch it arrived in.
        """
        projected = branch[0](features)  # Linear projection to 256 dims
        if self.training and projected.size(0) == 1:
            return branch[2:](projected)  # skip BatchNorm1d; keep ReLU + Dropout
        return branch[1:](projected)

    def forward(self, input_ids, attention_mask, image_vector):
        """Compute class logits for a batch of (text, image) pairs.

        Args:
            input_ids: BERT token id tensor, shape (batch, seq_len).
            attention_mask: BERT attention mask, shape (batch, seq_len).
            image_vector: Precomputed image features,
                shape (batch, image_feat_size).

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        text_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = self._project(self.text_fc, text_output.pooler_output)
        image_feat = self._project(self.image_fc, image_vector)
        fused = torch.cat((text_feat, image_feat), dim=1)
        return self.fusion_fc(fused)