import torch.nn as nn
from transformers import AutoModel, PreTrainedModel


class LlamaClassificationModel(PreTrainedModel):
    """Sequence classifier: a pretrained backbone plus a linear head over pooled states.

    The backbone is loaded with ``AutoModel.from_pretrained(config.model_path)``.
    Its final hidden states are summed over the non-padding positions (as marked
    by ``attention_mask``) and projected to ``config.num_labels`` logits. When
    labels are given, the loss is ``BCEWithLogitsLoss``, i.e. each label is an
    independent binary target.
    """

    # ``PreTrainedModel.base_model`` is a read-only property returning
    # ``getattr(self, self.base_model_prefix, self)``. Storing the backbone under
    # the attribute name ``base_model`` (with the default empty prefix) made that
    # property return ``self``, so calling it re-entered ``forward`` forever.
    # Keep the backbone in ``self.model`` and point the prefix at it;
    # ``instance.base_model`` still resolves to the backbone for callers.
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the backbone and the classification head.

        Args:
            config: A ``PretrainedConfig`` providing ``model_path`` (passed to
                ``AutoModel.from_pretrained``), ``hidden_size`` and ``num_labels``.
        """
        # PreTrainedModel.__init__ already stores ``config`` on ``self.config``.
        super().__init__(config)
        self.model = AutoModel.from_pretrained(config.model_path, config=config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone, pool over real tokens, and classify.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Mask forwarded to the backbone and used to exclude
                padding positions from pooling. ``None`` pools every position.
            labels: Optional targets. They must match the shape of ``logits``
                (``BCEWithLogitsLoss`` requires it) and are converted to float.

        Returns:
            dict with ``"loss"`` (``None`` when no labels are given) and
            ``"logits"`` of shape ``(batch, config.num_labels)``.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # (batch, seq_len, hidden_size) per the transformers BaseModelOutput contract.
        hidden = outputs.last_hidden_state
        if attention_mask is not None:
            # Zero padding positions so the pooled vector does not depend on
            # how much padding the batch carries.
            hidden = hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = hidden.sum(dim=1)
        # The head is created in the default dtype, while the backbone may run
        # in reduced precision; align dtypes before the matmul.
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))

        loss = None
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            # Compute the loss in float32 for numerical stability.
            loss = loss_fn(logits.float(), labels.float())
        return {"loss": loss, "logits": logits}