import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_bbb_model import BBBConfig


class BBBModelForSequenceClassification(PreTrainedModel):
    """GAT encoder over molecular graphs fused with a global feature vector,
    followed by a fully connected head that emits a single logit
    (classification) or a single scalar (regression)."""

    config_class = BBBConfig

    def __init__(self, config: BBBConfig):
        super().__init__(config)
        self.config = config
        self.activation = nn.LeakyReLU()

        # Stack of GAT layers with per-layer batch norm. The first layer maps
        # the raw node features (input_dim); subsequent layers consume the
        # concatenated multi-head output of the previous layer.
        self.gats = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(config.gnn_layers):
            in_dim = config.input_dim if i == 0 else config.gnn_hidden * config.num_heads
            self.gats.append(
                GATConv(
                    in_dim,
                    config.gnn_hidden,
                    heads=config.num_heads,
                    concat=True,
                    dropout=config.dropout,
                )
            )
            self.bns.append(nn.BatchNorm1d(config.gnn_hidden * config.num_heads))

        # Project the graph-level (super-node) embedding and the global
        # feature vector into a common space before concatenation.
        self.proj_gnn = nn.Sequential(
            nn.Linear(config.gnn_hidden * config.num_heads, config.proj_dim),
            nn.LeakyReLU(),
            nn.Dropout(config.dropout),
            nn.BatchNorm1d(config.proj_dim),
        )
        self.proj_feat = nn.Sequential(
            nn.Linear(config.num_features, config.proj_dim),
            nn.LeakyReLU(),
            nn.Dropout(config.dropout),
            nn.BatchNorm1d(config.proj_dim),
        )

        # Fully connected head on the concatenated projections, ending in a
        # single output unit.
        layers = []
        input_dim = config.proj_dim * 2
        for output_dim in config.neurons_fc:
            layers.append(nn.Linear(input_dim, output_dim))
            layers.append(nn.LeakyReLU())
            layers.append(nn.Dropout(config.dropout))
            layers.append(nn.BatchNorm1d(output_dim))
            input_dim = output_dim
        layers.append(nn.Linear(input_dim, 1))
        self.fc = nn.Sequential(*layers)

    def _init_weights(self, module):
        # Weight-initialization hook used by the Hugging Face machinery when
        # modules are initialized (e.g. via post_init or from_pretrained).
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=1.0)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1.0)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        x: torch.Tensor = None,
        edge_index: torch.Tensor = None,
        batch: torch.Tensor = None,
        features: torch.Tensor = None,
        labels: torch.Tensor = None,
        # standard inputs expected by the Hugging Face API
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if x is None or edge_index is None or batch is None or features is None:
            raise ValueError("You have to specify x, edge_index, batch, and features")

        # Message passing over the node features.
        for gat, bn in zip(self.gats, self.bns):
            x = gat(x, edge_index)
            x = bn(x)
            x = self.activation(x)

        # Graph-level readout: the last node of each graph is its super node.
        super_nodes = []
        for i in range(batch.max().item() + 1):
            mask = batch == i
            node_idx = mask.nonzero(as_tuple=True)[0]
            super_node_idx = node_idx[-1]
            super_nodes.append(x[super_node_idx])
        x_super_nodes = torch.stack(super_nodes, dim=0)

        # Fuse the graph embedding with the global feature vector.
        gnn_proj = self.proj_gnn(x_super_nodes)
        feat_proj = self.proj_feat(features)
        combined = torch.cat([gnn_proj, feat_proj], dim=1)
        logits = self.fc(combined)

        loss = None
        if labels is not None:
            if self.config.task == "regression":
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.squeeze(-1), labels.squeeze(-1))
            elif self.config.task == "classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.float().unsqueeze(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
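

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition).
# Assumes BBBConfig is a standard PretrainedConfig subclass that accepts the
# keyword arguments below (all dimensions are made-up example values), and
# that this file is run as part of its package (e.g. `python -m <package>.<module>`)
# so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BBBConfig(
        input_dim=64,
        gnn_hidden=128,
        gnn_layers=3,
        num_heads=4,
        dropout=0.1,
        proj_dim=256,
        num_features=200,
        neurons_fc=[512, 128],
        task="classification",
    )
    model = BBBModelForSequenceClassification(config)
    model.eval()

    # Two toy graphs with 5 and 7 nodes; by the readout convention in forward,
    # the last node of each graph acts as its super node.
    x = torch.randn(12, config.input_dim)
    edge_index = torch.tensor([[0, 1, 5, 6], [1, 2, 6, 7]], dtype=torch.long)
    batch = torch.tensor([0] * 5 + [1] * 7, dtype=torch.long)
    features = torch.randn(2, config.num_features)

    with torch.no_grad():
        out = model(x=x, edge_index=edge_index, batch=batch, features=features)
    print(out.logits.shape)  # expected: torch.Size([2, 1])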