---
tags:
- target-identification
- argumentation
- contrastive-learning
license: mit
language:
- en
base_model:
- answerdotai/ModernBERT-base
pipeline_tag: text-classification
---
This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration.

---
## Model Description
This is a dual-encoder retrieval model built on top of `answerdotai/ModernBERT-base`. It performs target identification: given a `claim`, it retrieves the most relevant theses, each returned with its debate title and associated arguments.
You can adjust the output through three parameters of `retrieve_theses`: `top_k` sets how many theses are returned, `num_args` caps the number of arguments returned per thesis, and `top_level_only` keeps only the arguments whose `target_type` is `thesis`.
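
As a rough illustration of how these three parameters shape the output, the sketch below re-implements the ranking and filtering steps on hypothetical toy data, using only the Python standard library. The 2-D vectors, the thesis labels, the `text` field, and the `"argument"` value of `target_type` are made up for illustration; only the parameter names and the `"thesis"` value come from the usage code in this card. It does not load the model.

```python
import math

# Hypothetical toy index: 2-D vectors stand in for the stored thesis embeddings.
TOY_EMBEDDINGS = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
TOY_METADATA = [
    {"thesis": "T0", "arguments": [
        {"text": "a0", "target_type": "thesis"},
        {"text": "a1", "target_type": "argument"},  # assumed non-thesis value
    ]},
    {"thesis": "T1", "arguments": [{"text": "b0", "target_type": "thesis"}]},
    {"thesis": "T2", "arguments": [
        {"text": "c0", "target_type": "thesis"},
        {"text": "c1", "target_type": "thesis"},
    ]},
]


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def toy_retrieve(query, top_k=3, num_args=5, top_level_only=True):
    # Rank every stored vector by similarity to the query, best first.
    sims = [cosine(query, emb) for emb in TOY_EMBEDDINGS]
    ranked = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
    results = []
    for idx in ranked[:top_k]:  # top_k limits the number of theses
        arguments = TOY_METADATA[idx]["arguments"]
        if top_level_only:  # keep only arguments that target the thesis itself
            arguments = [a for a in arguments if a["target_type"] == "thesis"]
        results.append({
            "thesis": TOY_METADATA[idx]["thesis"],
            "arguments": arguments[:num_args],  # num_args caps arguments per thesis
        })
    return results


hits = toy_retrieve([1.0, 0.1], top_k=2, num_args=1)
print([(h["thesis"], [a["text"] for a in h["arguments"]]) for h in hits])
# prints [('T0', ['a0']), ('T2', ['c0'])]
```

With `top_level_only=False`, the same call would also consider `a1` for `T0`, although `num_args=1` would still return only the first argument.
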
## How to use
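To run inference, load the tokenizer with the `transformers` library and the model weights through `PyTorchModelHubMixin`, then rank the precomputed thesis embeddings against your claim. The example also requires `scikit-learn` and `numpy`. The following code demonstrates how to make a prediction: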
```python
import pickle

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer


class DualEncoderThesisModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base")

    def forward(self, input_ids_a, attention_mask_a, input_ids_b, attention_mask_b):
        # Encode arguments; the first token's hidden state is the embedding
        output_a = self.encoder(input_ids=input_ids_a, attention_mask=attention_mask_a).last_hidden_state
        emb_a = output_a[:, 0]
        # Encode theses with the same encoder
        output_b = self.encoder(input_ids=input_ids_b, attention_mask=attention_mask_b).last_hidden_state
        emb_b = output_b[:, 0]
        return emb_a, emb_b


model_name = "azza1625/target-identification"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DualEncoderThesisModel.from_pretrained(model_name)
model.to(device)  # keep the model on the same device as the inputs
model.eval()

# Download the precomputed thesis embeddings and their metadata.
# pickle can execute arbitrary code, so only load files you trust.
embeddings_path = hf_hub_download(
    repo_id=model_name,
    filename="retrieval_data_random_negatives_10_train_data.pkl",
)
with open(embeddings_path, "rb") as f:
    embeddings_metadata = pickle.load(f)


@torch.no_grad()
def retrieve_theses(claim, top_k=3, num_args=5, top_level_only=True, device="cpu"):
    # `device` must match the device the model was moved to above.
    stored_embeddings = embeddings_metadata["embeddings"]
    metadata = embeddings_metadata["metadata"]

    # Embed the claim using the first token's hidden state.
    enc = tokenizer(claim, return_tensors="pt", truncation=True, padding="max_length", max_length=1024).to(device)
    query_embedding = model.encoder(**enc).last_hidden_state[:, 0].cpu().numpy()

    # Rank stored theses by cosine similarity, highest first.
    sims = cosine_similarity(query_embedding, stored_embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]

    results = []
    for idx in top_indices:
        arguments = metadata[idx]["arguments"]
        if top_level_only:
            # Keep only arguments that target the thesis directly.
            arguments = [arg for arg in arguments if arg["target_type"] == "thesis"]
        results.append({
            "thesis": metadata[idx]["thesis"],
            "debate_title": metadata[idx]["debate_title"],
            "arguments": arguments[:num_args],
        })
    return results


claim = "A fetus or embryo is not a person; therefore, abortion should not be considered murder."
theses = retrieve_theses(claim, device=device)
for thesis in theses:
    print(f"{thesis['thesis']} | {thesis['debate_title']}")
```