| import torch.nn as nn | |
| from transformers import AutoModel, PreTrainedModel | |
class LlamaClassificationModel(PreTrainedModel):
    """Sequence classifier: a pretrained transformer backbone plus a linear head.

    The backbone's final hidden states are summed over the non-padding
    positions (as marked by ``attention_mask``) and projected to
    ``config.num_labels`` logits. When ``labels`` are supplied, the loss is
    ``BCEWithLogitsLoss``, i.e. an independent sigmoid per label
    (multi-label classification).

    ``config`` must provide ``model_path`` (forwarded to
    ``AutoModel.from_pretrained``), ``hidden_size`` and ``num_labels``.
    """

    # ``PreTrainedModel.base_model`` is a read-only property returning
    # ``getattr(self, self.base_model_prefix, self)``. A submodule stored under
    # the name ``base_model`` is shadowed by that property; with the default
    # empty prefix the lookup fell back to ``self``, so ``forward`` called
    # itself forever. Registering the backbone as ``model`` and declaring the
    # prefix makes ``self.base_model`` resolve to the backbone as intended.
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)  # also stores ``self.config``
        # NOTE(review): ``config`` is passed through to AutoModel; this
        # presumably requires a config type AutoModel can map to a backbone
        # class (e.g. a LlamaConfig carrying the extra fields) — confirm.
        self.model = AutoModel.from_pretrained(config.model_path, config=config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone, sum-pool over real tokens, and classify.

        Args:
            input_ids: Token ids forwarded to the backbone; presumably
                (batch, seq_len) — TODO confirm.
            attention_mask: 1 for real tokens and 0 for padding, assumed to be
                (batch, seq_len) like ``input_ids``. If ``None``, every
                position contributes to the pooled sum.
            labels: Optional targets for ``BCEWithLogitsLoss``; they must
                match ``logits`` in shape (presumably multi-hot
                (batch, num_labels)).

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` (output of the linear head).
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        if attention_mask is None:
            summed_representation = hidden_states.sum(dim=1)
        else:
            # Zero out padding positions so they do not leak into the pooled
            # vector; unsqueeze so the mask broadcasts over the hidden dim.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            summed_representation = (hidden_states * mask).sum(dim=1)
        # NOTE(review): a sum grows with sequence length; if mean pooling was
        # intended, divide by mask.sum(dim=1).clamp(min=1) — confirm.

        logits = self.classifier(summed_representation)

        loss = None
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels.float())
        return {"loss": loss, "logits": logits}