πŸ₯ Clinical Contrastive ModernBERT with [ENTITY] Token Support

This is a custom contrastive learning model designed for clinical text, with built-in support for the [ENTITY] token used to anonymize sensitive patient information.

🎯 Key Features

  • βœ… [ENTITY] Token Support: Anonymize patient names, IDs, locations
  • βœ… Contrastive Learning: Trained with triplet loss on clinical text
  • βœ… Clinical Domain: Optimized for medical/clinical language
  • βœ… Custom Architecture: Specialized contrastive model class
  • βœ… Attention-Masked Pooling: Proper handling of special tokens

πŸ“Š Model Details

  • Base Model: Simonlee711/Clinical_ModernBERT
  • Architecture: ContrastiveClinicalModel with triplet loss
  • Training: Triplet loss with margin=1.0
  • Vocabulary Size: 50,370 tokens
  • [ENTITY] Token ID: 50368
  • Max Sequence Length: 8192 tokens
  • Hidden Size: 768
  • Layers: 22

πŸš€ Quick Start

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Load model (trust_remote_code=True required for custom model)
tokenizer = AutoTokenizer.from_pretrained("nikhil061307/contrastive-learning-bert-added-token-v5")
model = AutoModel.from_pretrained("nikhil061307/contrastive-learning-bert-added-token-v5", trust_remote_code=True)

def get_clinical_embeddings(texts, max_length=256):
    """Get embeddings for clinical texts with [ENTITY] support.

    max_length defaults to 256 for speed; the model supports up to 8192 tokens.
    """
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    
    # Use the model's custom encode method (returns L2-normalized embeddings)
    with torch.no_grad():
        embeddings = model.encode(inputs['input_ids'], inputs['attention_mask'])
    
    return embeddings

# Example with [ENTITY] token for anonymization
clinical_texts = [
    "Patient [ENTITY] presents with chest pain and shortness of breath.",
    "Patient [ENTITY] reports severe headache lasting 3 days.",
    "Patient [ENTITY] diagnosed with acute myocardial infarction."
]

embeddings = get_clinical_embeddings(clinical_texts)
print(f"Embeddings shape: {embeddings.shape}")

# Calculate similarities (embeddings are L2-normalized, so dot products
# are cosine similarities)
similarity_matrix = torch.mm(embeddings, embeddings.t())
print(f"Similarity between first two texts: {similarity_matrix[0, 1]:.4f}")

⚠️ Important Usage Notes

  1. Trust Remote Code: Always pass trust_remote_code=True when loading the model (the custom class is loaded from this repo)
  2. Custom Architecture: This uses a specialized ContrastiveClinicalModel class
  3. [ENTITY] Token: Token ID 50368 is preserved from training
  4. L2 Normalization: Embeddings are automatically L2 normalized (verified in the snippet below)
  5. Attention Masking: Properly handles padding and special tokens
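
A minimal check of the normalization claim, assuming the `embeddings` tensor and the `F` import from the Quick Start above:

# Norms should all be ~1.0 if embeddings are L2-normalized
print(embeddings.norm(p=2, dim=1))

# For unit-norm vectors, cosine similarity equals the dot product
cos = F.cosine_similarity(embeddings[0], embeddings[1], dim=0)
dot = embeddings[0] @ embeddings[1]
print(f"cosine={cos:.4f}, dot={dot:.4f}")  # should match closely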

🎯 Training Details

  • Training Method: Triplet loss contrastive learning
  • Loss Function: Triplet loss with margin=1.0
  • Pooling Strategy: Attention-masked mean pooling
  • Dropout Rate: 0.15 (training only)
  • Normalization: L2 normalization on embeddings
  • Special Tokens: Handles [ENTITY], [PAD], [CLS], [SEP]

πŸ”’ Privacy & Compliance

This model is designed to help with healthcare data privacy by:

  • Supporting entity anonymization with [ENTITY] tokens
  • Maintaining semantic similarity despite anonymization
  • Enabling analysis of de-identified clinical text
  • Preserving medical meaning while protecting patient privacy

Note: Always ensure compliance with relevant healthcare privacy regulations (HIPAA, GDPR, etc.) when processing medical data.
