from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveClinicalConfig(PretrainedConfig):
    model_type = "contrastive_clinical"

    def __init__(
        self,
        model_name: str = "Simonlee711/Clinical_ModernBERT",
        dropout_rate: float = 0.15,
        entity_token_id: int = 50368,
        vocab_size: int = 50370,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.dropout_rate = dropout_rate
        self.entity_token_id = entity_token_id
        # vocab_size includes the added [ENTITY] special token
        self.vocab_size = vocab_size


class ContrastiveClinicalModel(PreTrainedModel):
    config_class = ContrastiveClinicalConfig

    def __init__(self, config):
        super().__init__(config)
        # Import here to avoid circular imports
        from transformers import AutoConfig

        # Create the base config with the expanded vocab size
        base_config = AutoConfig.from_pretrained(config.model_name)
        base_config.vocab_size = config.vocab_size

        # Load the encoder with the updated config; ignore_mismatched_sizes
        # tolerates the resized embedding matrix saved with the checkpoint
        self.encoder = AutoModel.from_pretrained(
            config.model_name,
            config=base_config,
            ignore_mismatched_sizes=True,
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.model_name = config.model_name
        self.entity_token_id = getattr(config, 'entity_token_id', None)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pos_input_ids=None,
        pos_attention_mask=None,
        neg_input_ids=None,
        neg_attention_mask=None,
        labels=None,
        return_dict=None,
        **kwargs
    ):
        # Encode anchor
        anchor_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        anchor_emb = self._mean_pooling(anchor_output.last_hidden_state, attention_mask)

        if pos_input_ids is not None:
            # Training mode: encode positive and negative examples
            pos_output = self.encoder(input_ids=pos_input_ids, attention_mask=pos_attention_mask)
            pos_emb = self._mean_pooling(pos_output.last_hidden_state, pos_attention_mask)

            neg_output = self.encoder(input_ids=neg_input_ids, attention_mask=neg_attention_mask)
            neg_emb = self._mean_pooling(neg_output.last_hidden_state, neg_attention_mask)

            # Apply dropout and L2-normalize
            anchor_emb = F.normalize(self.dropout(anchor_emb), p=2, dim=1)
            pos_emb = F.normalize(self.dropout(pos_emb), p=2, dim=1)
            neg_emb = F.normalize(self.dropout(neg_emb), p=2, dim=1)

            # Compute triplet loss
            loss = self._triplet_loss(anchor_emb, pos_emb, neg_emb, margin=1.0)

            return {
                'loss': loss,
                'anchor_embeddings': anchor_emb,
                'positive_embeddings': pos_emb,
                'negative_embeddings': neg_emb,
            }
        else:
            # Inference mode: return the normalized anchor embedding
            anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
            return BaseModelOutput(last_hidden_state=anchor_emb.unsqueeze(1))

    def _mean_pooling(self, last_hidden_state, attention_mask):
        """Mean pooling over non-padding tokens (including [ENTITY] tokens)."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _triplet_loss(self, anchor, positive, negative, margin=1.0):
        """Compute triplet margin loss on L2 distances."""
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)
        loss = F.relu(pos_dist - neg_dist + margin)
        return loss.mean()

    def encode(self, input_ids, attention_mask):
        """Return L2-normalized embeddings for inference; inputs may contain [ENTITY] tokens."""
        with torch.no_grad():
            output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = self._mean_pooling(output.last_hidden_state, attention_mask)
            return F.normalize(embeddings, p=2, dim=1)


# Register the model with the Auto classes
from transformers import AutoConfig, AutoModel

AutoConfig.register("contrastive_clinical", ContrastiveClinicalConfig)
AutoModel.register(ContrastiveClinicalConfig, ContrastiveClinicalModel)
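

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# the base tokenizer for "Simonlee711/Clinical_ModernBERT" and that an [ENTITY]
# special token is added to mark entity spans, as the config's entity_token_id
# and expanded vocab_size suggest; the example texts are toy inputs, and the id
# the tokenizer assigns to [ENTITY] may differ from the one stored in the config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Simonlee711/Clinical_ModernBERT")
    tokenizer.add_special_tokens({"additional_special_tokens": ["[ENTITY]"]})

    config = ContrastiveClinicalConfig()
    model = ContrastiveClinicalModel(config).eval()

    def tokenize(texts):
        return tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Inference: mean-pooled, L2-normalized sentence embeddings
    batch = tokenize(["[ENTITY] myocardial infarction [ENTITY] ruled out"])
    embeddings = model.encode(batch["input_ids"], batch["attention_mask"])
    print(embeddings.shape)  # (1, hidden_size)

    # Training-style call: anchor / positive / negative triplets
    anchor = tokenize(["[ENTITY] MI [ENTITY] suspected"])
    positive = tokenize(["[ENTITY] myocardial infarction [ENTITY] suspected"])
    negative = tokenize(["[ENTITY] fracture [ENTITY] of the left wrist"])
    outputs = model(
        input_ids=anchor["input_ids"],
        attention_mask=anchor["attention_mask"],
        pos_input_ids=positive["input_ids"],
        pos_attention_mask=positive["attention_mask"],
        neg_input_ids=negative["input_ids"],
        neg_attention_mask=negative["attention_mask"],
    )
    print(outputs["loss"].item())  # triplet margin loss for this toy batch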