{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "Q4qAMMPkQhfY" }, "outputs": [], "source": [ "import os, random, itertools, math, torch\n", "from torch.utils.data import DataLoader\n", "from transformers import (\n", " AutoTokenizer, AutoModelForMaskedLM,\n", " get_cosine_schedule_with_warmup\n", ")\n", "from torch.optim import AdamW\n", "from datasets import load_dataset\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "code", "source": [
 "# Load the tokenizer and the masked-LM weights for the checkpoint.\n",
 "model_id = \"johnowhitaker/modernbert-diffusion\"\n",
 "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
 "\n",
 "# Special-token ids reused when assembling inputs by hand.\n",
 "SEP_ID = tokenizer.sep_token_id\n",
 "CLS_ID = tokenizer.cls_token_id\n",
 "MASK_ID = tokenizer.mask_token_id\n",
 "\n",
 "model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device)\n",
 "model.eval();  # inference only; the trailing ';' hides the module repr"
 ], "metadata": { "id": "e4kbDTS3Qo_a" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [
 "# Single forward pass: every [MASK] slot is predicted at once.\n",
 "prompt = \"User: Which is the best programming language? \" + tokenizer.sep_token + \" Assistant:\"\n",
 "prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
 "ans_len = 12\n",
 "# Layout: [CLS] prompt [SEP] <ans_len masks> [SEP]\n",
 "ids = [CLS_ID, *prompt_ids, SEP_ID, *([MASK_ID] * ans_len), SEP_ID]\n",
 "input_ids = torch.tensor([ids]).to(device)\n",
 "with torch.no_grad():\n",
 "    outs = model(input_ids=input_ids).logits\n",
 "print(outs.shape)\n",
 "# Greedy decode: highest-scoring vocabulary entry at each position.\n",
 "out_ids = outs[0].argmax(dim=-1).tolist()\n",
 "print(tokenizer.decode(out_ids))"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Y7ZwaE3IQzJT", "outputId": "bd8a6d10-41c3-4531-d244-32094e71b1d3" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([1, 28, 50368])\n", "[CLS]User: Which is the best programming language? 
\n", " Assistant: Python, Python,,,,,, is Python..[SEP]\n" ] } ] }, { "cell_type": "code", "source": [
 "# In a loop, keeping the most confident prediction at each step.\n",
 "# Sequence layout: [CLS] prompt [SEP] <ans_len x [MASK]> [SEP]\n",
 "prompt = \"User: Which is the best programming language? \" + tokenizer.sep_token + \" Assistant:\"\n",
 "prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
 "ans_len = 32\n",
 "ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]\n",
 "# Index of the first answer slot (after [CLS], the prompt and one [SEP]).\n",
 "# Only answer slots are decoded, so a MASK_ID that happens to appear in the\n",
 "# prompt ids cannot use up decoding steps.\n",
 "ans_start = len(prompt_ids) + 2\n",
 "for _ in range(ans_len):\n",
 "    with torch.no_grad():\n",
 "        outs = model(input_ids=torch.tensor([ids], device=device)).logits\n",
 "    mask_locs = [p for p in range(ans_start, ans_start + ans_len) if ids[p] == MASK_ID]\n",
 "    # Softmax over the still-masked rows only, instead of allocating a\n",
 "    # zero-filled copy of the full (seq_len, vocab) matrix on every step.\n",
 "    mask_probs = torch.softmax(outs[0, mask_locs], dim=-1)\n",
 "    # A slot must never be filled with [MASK] again: that would waste the\n",
 "    # step and leave an unresolved mask in the final text.\n",
 "    mask_probs[:, MASK_ID] = 0.0\n",
 "    max_probs, max_toks = mask_probs.max(dim=-1)\n",
 "    best = int(max_probs.argmax())  # plain int: safe to index Python lists\n",
 "    ids[mask_locs[best]] = int(max_toks[best])\n",
 "print(tokenizer.decode(ids))"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wadlDG2DUUjX", "outputId": "06317b7c-7f71-4621-e0b6-c173df0839b7" }, "execution_count": 24, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[CLS]User: Which is the best programming language? [SEP] Assistant:[SEP] is the best programming language?\n", "\n", "A: Python is the best programming language. 
It is simple, powerful, and has a wide range of useful features.[SEP]\n" ] } ] }, { "cell_type": "code", "source": [
 "# Wrapping that in a function\n",
 "def sample(q, ans_len=32):\n",
 "    \"\"\"Answer `q` by iteratively unmasking `ans_len` answer tokens.\n",
 "\n",
 "    The input is laid out as [CLS] prompt [SEP] <ans_len x [MASK]> [SEP].\n",
 "    Each step runs one forward pass and commits only the most confident\n",
 "    still-masked answer slot. Relies on the notebook-level `tokenizer`,\n",
 "    `model`, `device`, `CLS_ID`, `SEP_ID` and `MASK_ID`.\n",
 "\n",
 "    Args:\n",
 "        q: Question text inserted after \"User: \".\n",
 "        ans_len: Number of answer slots to fill (one forward pass each).\n",
 "\n",
 "    Returns:\n",
 "        The decoded full sequence: prompt, answer and special tokens.\n",
 "    \"\"\"\n",
 "    prompt = f\"User: {q} \" + tokenizer.sep_token + \" Assistant:\"\n",
 "    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
 "    ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]\n",
 "    # Index of the first answer slot (after [CLS], the prompt and one [SEP]).\n",
 "    # Only answer slots are decoded, so a MASK_ID that happens to appear in\n",
 "    # the prompt ids cannot use up decoding steps.\n",
 "    ans_start = len(prompt_ids) + 2\n",
 "    for _ in range(ans_len):\n",
 "        with torch.no_grad():\n",
 "            outs = model(input_ids=torch.tensor([ids], device=device)).logits\n",
 "        mask_locs = [p for p in range(ans_start, ans_start + ans_len) if ids[p] == MASK_ID]\n",
 "        # Softmax over the still-masked rows only, instead of allocating a\n",
 "        # zero-filled copy of the full probability matrix on every step.\n",
 "        mask_probs = torch.softmax(outs[0, mask_locs], dim=-1)\n",
 "        # A slot must never be filled with [MASK] again: that would waste\n",
 "        # the step and leave an unresolved mask in the returned text.\n",
 "        mask_probs[:, MASK_ID] = 0.0\n",
 "        max_probs, max_toks = mask_probs.max(dim=-1)\n",
 "        best = int(max_probs.argmax())  # plain int: safe to index Python lists\n",
 "        ids[mask_locs[best]] = int(max_toks[best])\n",
 "    return tokenizer.decode(ids)"
 ], "metadata": { "id": "FAj0rtmhYcjF" }, "execution_count": 25, "outputs": [] }, { "cell_type": "code", "source": [ "sample(\"Tell me a fun fact about cows\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "id": "HAS20X0oZhw5", "outputId": "4f157101-1652-4c25-b67e-b957512bf632" }, "execution_count": 26, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"[CLS]User: Tell me a fun fact about cows [SEP] Assistant:[SEP], here's a fun fact about cows:\\n\\nThe fact is that cows are the most intelligent animals in the world. 
They can think and make decisions.[SEP]\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 26 } ] }, { "cell_type": "code", "source": [
 "# The cell's last expression is displayed: the full decoded sequence.\n",
 "sample(q=\"Tell me a funny joke about lemons\")"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "id": "f0S3ZQLNUUnU", "outputId": "ddfc0e47-bbb1-496b-8177-5d796b8bd9af" }, "execution_count": 30, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'[CLS]User: Tell me a funny joke about lemons [SEP] Assistant:[SEP]\\'s a funny joke about lemons: \"I have a lemonade stand, and I\\'m going to sell lemons.\"\\n Assistant: That\\'s funny.[SEP]'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 30 } ] }, { "cell_type": "code", "source": [
 "# Same call with a different question; ans_len stays at its default.\n",
 "sample(q=\"Which OS is best?\")"
 ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "id": "KugOpLPHaQSA", "outputId": "43767abf-5a3e-48e0-c14b-b180f7ba9a14" }, "execution_count": 31, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"[CLS]User: Which OS is best? [SEP] Assistant:[SEP], I don't know. I haven't used them personally. I'm sure there are some that are better than others, but I can't tell you.[SEP]\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 31 } ] } ] }