In [1]:
import os, random, itertools, math, torch
from torch.utils.data import DataLoader
from transformers import (
 AutoTokenizer, AutoModelForMaskedLM,
 get_cosine_schedule_with_warmup
)
from torch.optim import AdamW
from datasets import load_dataset
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
model_id = "johnowhitaker/modernbert-diffusion"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Special-token ids used to lay out "[CLS] prompt [SEP] [MASK]*n [SEP]" inputs.
SEP_ID = tokenizer.sep_token_id
CLS_ID = tokenizer.cls_token_id
MASK_ID = tokenizer.mask_token_id
# Inference only: eval() switches off training-mode layers and returns the module itself,
# so it can be chained; as an assignment the cell prints nothing.
model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device).eval()

In [3]:
# Single forward pass: predict every position at once and take the argmax everywhere.
# The output below shows this one-shot decode gives a repetitive answer.
prompt = "User: Which is the best programming language? " + tokenizer.sep_token + " Assistant:"
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
ans_len = 12
# Layout: [CLS] prompt [SEP] [MASK] * ans_len [SEP]
ids = [CLS_ID, *prompt_ids, SEP_ID, *([MASK_ID] * ans_len), SEP_ID]
with torch.no_grad():
    outs = model(input_ids=torch.tensor([ids], device=device)).logits
print(outs.shape)
# Greedy token at every position (prompt positions included), batch row 0.
out_ids = outs.argmax(dim=-1)[0].tolist()
print(tokenizer.decode(out_ids))

torch.Size([1, 28, 50368])
[CLS]User: Which is the best programming language? 
 Assistant: Python, Python,,,,,, is Python..[SEP]


In [24]:
# In a loop, keeping the most confident: every forward pass commits exactly one
# still-masked slot, the one whose top prediction has the highest probability.
prompt = "User: Which is the best programming language? " + tokenizer.sep_token + " Assistant:"
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
ans_len = 32
ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]
for step in range(ans_len):
    ids_t = torch.tensor(ids)
    mask_locs = (ids_t == MASK_ID).nonzero(as_tuple=True)[0]
    with torch.no_grad():
        outs = model(input_ids=ids_t.unsqueeze(0).to(device)).logits
    # Softmax only the rows that are still masked; the other positions are never chosen.
    out_probs = torch.softmax(outs[0][mask_locs], dim=-1)
    # Never commit [MASK] itself: the slot would stay masked and the fixed
    # ans_len step budget would leave holes in the answer.
    out_probs[:, MASK_ID] = 0
    max_probs, max_toks = out_probs.max(dim=-1)
    best = max_probs.argmax().item()  # index into mask_locs of the most confident slot
    ids[mask_locs[best].item()] = max_toks[best].item()
print(tokenizer.decode(ids))

[CLS]User: Which is the best programming language? [SEP] Assistant:[SEP] is the best programming language?

A: Python is the best programming language. It is simple, powerful, and has a wide range of useful features.[SEP]


In [25]:
# Wrapping that in a function
def sample(q, ans_len=32):
    """Answer ``q`` by iterative, confidence-ordered unmasking.

    The input is laid out as ``[CLS] "User: q [SEP] Assistant:" [SEP] [MASK]*ans_len [SEP]``.
    Each of the ``ans_len`` steps runs one forward pass and commits the single
    most confident prediction among the answer slots that are still masked,
    so the whole answer is filled after exactly ``ans_len`` passes.

    Args:
        q: the user question, inserted into the prompt template.
        ans_len: number of answer tokens to generate.

    Returns:
        ``tokenizer.decode`` of the full sequence (prompt and special tokens included).
    """
    prompt = f"User: {q} " + tokenizer.sep_token + " Assistant:"
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID] * ans_len + [SEP_ID]
    # Answer slots sit between the prompt's closing [SEP] and the final [SEP].
    ans_start = len(prompt_ids) + 2
    ans_end = ans_start + ans_len
    for _step in range(ans_len):
        ids_t = torch.tensor(ids)
        # Search only the answer span: a literal "[MASK]" typed in `q` is encoded
        # as MASK_ID too and must not consume unmasking steps.
        mask_locs = (ids_t[ans_start:ans_end] == MASK_ID).nonzero(as_tuple=True)[0] + ans_start
        with torch.no_grad():
            logits = model(input_ids=ids_t.unsqueeze(0).to(device)).logits[0]
        probs = torch.softmax(logits[mask_locs], dim=-1)
        # Never commit [MASK] itself, or the slot stays masked and the answer has holes.
        probs[:, MASK_ID] = 0
        conf, tok = probs.max(dim=-1)
        best = conf.argmax().item()  # index into mask_locs of the most confident slot
        ids[mask_locs[best].item()] = tok[best].item()
    return tokenizer.decode(ids)

In [26]:
# Factual prompt, default 32-token answer budget.
sample(q="Tell me a fun fact about cows")

"[CLS]User: Tell me a fun fact about cows [SEP] Assistant:[SEP], here's a fun fact about cows:\n\nThe fact is that cows are the most intelligent animals in the world. They can think and make decisions.[SEP]"

In [30]:
# Creative prompt, default 32-token answer budget.
sample(q="Tell me a funny joke about lemons")

'[CLS]User: Tell me a funny joke about lemons [SEP] Assistant:[SEP]\'s a funny joke about lemons: "I have a lemonade stand, and I\'m going to sell lemons."\n Assistant: That\'s funny.[SEP]'

In [31]:
# Opinion prompt, default 32-token answer budget.
sample(q="Which OS is best?")

"[CLS]User: Which OS is best? [SEP] Assistant:[SEP], I don't know. I haven't used them personally. I'm sure there are some that are better than others, but I can't tell you.[SEP]"