{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "429b26f3-8c61-46cc-b5fc-284add4d018f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import os\n",
    "\n",
    "# Set before importing torch so the restriction is in place before any CUDA initialisation.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "import torch\n",
    "from bs4 import BeautifulSoup\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a927511-78a0-42d5-861d-9e7af50ff000",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scrape the arXiv category taxonomy page into a {tag: name} mapping and save it as JSON.\n",
    "page = requests.get('https://arxiv.org/category_taxonomy', timeout=30)\n",
    "page.raise_for_status()  # fail loudly instead of silently parsing an error page\n",
    "# Name the parser explicitly: otherwise bs4 guesses one (and warns), which can differ per machine.\n",
    "soup = BeautifulSoup(page.content, 'html.parser')\n",
    "\n",
    "tag_to_name = {}\n",
    "# The first <h4> is skipped - presumably a non-category heading; verify against the page.\n",
    "for tag_html in soup.find_all('h4')[1:]:\n",
    "    tag, name = tag_html.text.split(maxsplit=1)\n",
    "    # Drop the first and last character of the name (presumably the surrounding parentheses).\n",
    "    tag_to_name[tag] = name[1:-1]\n",
    "\n",
    "with open('tag_to_name.json', 'w', encoding='utf-8') as fout:\n",
    "    json.dump(tag_to_name, fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19b75e52-15c0-472e-b737-72c5eea896ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give every tag a contiguous integer id, following the insertion order of tag_to_name.\n",
    "tag_to_label = {tag: label for label, tag in enumerate(tag_to_name)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fec2865f-2992-4b3e-9202-8e9b8c5a7da1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_labels(row):\n",
    "    \"\"\"Derive label columns from a row's serialized tag list.\n",
    "\n",
    "    row['tag'] must be a string holding a Python literal: a list of dicts,\n",
    "    each with a 'term' key. Terms that are not keys of the module-level\n",
    "    tag_to_label mapping are dropped.\n",
    "\n",
    "    Returns a dict with 'label_ids' (values looked up in tag_to_label) and\n",
    "    'label_tags' (the matching terms), both in the order the terms appear.\n",
    "\n",
    "    Raises ValueError or SyntaxError if row['tag'] is not a plain literal.\n",
    "    \"\"\"\n",
    "    # literal_eval only accepts plain literals, so a crafted 'tag' string in the\n",
    "    # data file cannot execute code the way eval() would.\n",
    "    tag_list = ast.literal_eval(row['tag'])\n",
    "    label_tags = [\n",
    "        tag_dict['term'] for tag_dict in tag_list if tag_dict['term'] in tag_to_label\n",
    "    ]\n",
    "    label_ids = [tag_to_label[term] for term in label_tags]\n",
    "    return {'label_ids': label_ids, 'label_tags': label_tags}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "81dff335-093f-4a59-93b5-27d7c57aac9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-60d1f0f90275ae1e\n",
      "Found cached dataset json (/root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
     ]
    },
    {
"name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-66945521f8e38136.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-5298549794823409.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-6c93a706327f5678.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-ff58b61d0d461ac4.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-259b966b550351dc.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0ed2baf297a3db.arrow\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-845944d2885d6a34.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8ec43ba6cf3d3eba.arrow\n" ] } ],
   "source": [
    "# Load the JSON dump, add label_ids/label_tags via add_labels, then drop the listed columns.\n",
    "dataset = (\n",
    "    load_dataset(\"json\", data_files=\"arxivData.json\", split=\"train\")\n",
    "    .map(add_labels, num_proc=8)\n",
    "    .remove_columns(['author', 'day', 'id', 'link', 'month', 'tag', 'year'])\n",
    ")"
   ] }, { "cell_type": "code", "execution_count": 6, "id": "c9a6ab6a-6a47-4377-a9d9-044c3a395ef3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | summary | \n", "title | \n", "label_ids | \n", "label_tags | \n", "
|---|---|---|---|---|
| 0 | \n", "We propose an architecture for VQA which utili... | \n", "Dual Recurrent Attention Units for Visual Ques... | \n", "[0, 5, 7, 28, 152] | \n", "[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] | \n", "
| 1 | \n", "In a physical neural system, where storage and... | \n", "A Theory of Local Learning, the Learning Chann... | \n", "[22, 28, 152] | \n", "[cs.LG, cs.NE, stat.ML] | \n", "
| 2 | \n", "One way to approach end-to-end autonomous driv... | \n", "Query-Efficient Imitation Learning for End-to-... | \n", "[22, 0, 34] | \n", "[cs.LG, cs.AI, cs.RO] | \n", "
| Step | \n", "Training Loss | \n", "Validation Loss | \n", "
|---|---|---|
| 100 | \n", "4.286100 | \n", "2.809958 | \n", "
| 200 | \n", "2.365700 | \n", "2.110714 | \n", "
| 300 | \n", "2.023600 | \n", "2.046348 | \n", "
| 400 | \n", "2.020400 | \n", "1.982979 | \n", "
| 500 | \n", "1.927300 | \n", "1.915667 | \n", "
| 600 | \n", "1.919500 | \n", "1.927610 | \n", "
| 700 | \n", "1.834600 | \n", "1.929402 | \n", "
| 800 | \n", "1.840800 | \n", "1.861055 | \n", "
| 900 | \n", "1.823900 | \n", "1.819358 | \n", "
| 1000 | \n", "1.757100 | \n", "1.798097 | \n", "
| 1100 | \n", "1.746500 | \n", "1.779167 | \n", "
| 1200 | \n", "1.775000 | \n", "1.774340 | \n", "
| 1300 | \n", "1.698500 | \n", "1.764457 | \n", "
| 1400 | \n", "1.684200 | \n", "1.741629 | \n", "
| 1500 | \n", "1.763000 | \n", "1.680664 | \n", "
| 1600 | \n", "1.678400 | \n", "1.712918 | \n", "
| 1700 | \n", "1.669800 | \n", "1.710484 | \n", "
| 1800 | \n", "1.665000 | \n", "1.698851 | \n", "
| 1900 | \n", "1.645200 | \n", "1.663767 | \n", "
| 2000 | \n", "1.667600 | \n", "1.674545 | \n", "
| 2100 | \n", "1.602300 | \n", "1.680639 | \n", "
| 2200 | \n", "1.651800 | \n", "1.667343 | \n", "
| 2300 | \n", "1.622600 | \n", "1.659117 | \n", "
| 2400 | \n", "1.616900 | \n", "1.645381 | \n", "
| 2500 | \n", "1.600900 | \n", "1.642603 | \n", "
| 2600 | \n", "1.590200 | \n", "1.657698 | \n", "
| 2700 | \n", "1.646300 | \n", "1.644075 | \n", "
| 2800 | \n", "1.602600 | \n", "1.626339 | \n", "
| 2900 | \n", "1.596800 | \n", "1.646950 | \n", "
| 3000 | \n", "1.547200 | \n", "1.622913 | \n", "
| 3100 | \n", "1.563500 | \n", "1.611651 | \n", "
| 3200 | \n", "1.583500 | \n", "1.608005 | \n", "
| 3300 | \n", "1.565800 | \n", "1.626086 | \n", "
| 3400 | \n", "1.531000 | \n", "1.626902 | \n", "
| 3500 | \n", "1.566100 | \n", "1.607745 | \n", "
| 3600 | \n", "1.555100 | \n", "1.594658 | \n", "
| 3700 | \n", "1.597600 | \n", "1.597994 | \n", "
| 3800 | \n", "1.497600 | \n", "1.590335 | \n", "
| 3900 | \n", "1.522300 | \n", "1.588875 | \n", "
| 4000 | \n", "1.506600 | \n", "1.572686 | \n", "
| 4100 | \n", "1.497900 | \n", "1.602122 | \n", "
| 4200 | \n", "1.534100 | \n", "1.576102 | \n", "
| 4300 | \n", "1.517400 | \n", "1.578320 | \n", "
| 4400 | \n", "1.518500 | \n", "1.588920 | \n", "
| 4500 | \n", "1.510200 | \n", "1.596100 | \n", "
| 4600 | \n", "1.441100 | \n", "1.576099 | \n", "
| 4700 | \n", "1.511000 | \n", "1.575001 | \n", "
| 4800 | \n", "1.487700 | \n", "1.579319 | \n", "
| 4900 | \n", "1.491300 | \n", "1.591276 | \n", "
| 5000 | \n", "1.474700 | \n", "1.572709 | \n", "
"
],
"text/plain": [
"