Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/perturb.py +359 -0
- data/perturb.sh +35 -0
- data/perturb_llama.py +361 -0
- data/perturb_model.sh +40 -0
- data/perturb_qwen.py +361 -0
- data/tag.py +153 -0
- data/tag_1.py +166 -0
- data/tag_distributed.py +106 -0
- data/tag_single.py +140 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1000.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-10000.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-11500.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1500.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-2000.csv +2 -0
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/childes_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/gutenberg_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/switchboard_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/childes_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/gutenberg_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/switchboard_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/childes_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/gutenberg_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/switchboard_unaffected.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED - File without changes
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED - File without changes
data/perturb.py
ADDED
@@ -0,0 +1,359 @@
# perturb.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

from utils import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from glob import glob
import numpy as np
import itertools
import json
import os
import tqdm
import argparse
import pytest


def lines_equivalent_3pres(file1_path, file2_path):
    """Compare lines of two files after splitting them."""
    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            res1 = [i for i in line1.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            res2 = [i for i in line2.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Yj: combination tests over the perturbation pairs related to third-person singular/plural agreement

test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_3pres)  # Yj: generate the parameter combinations used in the tests

# Yj: used by test functions such as test_3pres_all_equivalent to generate the test combinations, covering the different perturbation strategies.
# Yj: distinguishes the affected and unaffected test subsets so the effect of the perturbation can be compared.


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)  # the test function runs once for every parameter set in test_data
def test_3pres_all_equivalent(split, genre, perturbation_pair):  # Yj: the genres are the different corpora, listed in utils.py

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"  # Yj: the development set is similar to a validation set

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    # Yj: compare the two files at the two paths
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"


def lines_equivalent_reversal(rev_path, ident_path):
    """Compare lines of reversal file and identity file after splitting them."""
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Get REV marker index
            marker_index = line1_tokens.index(str(marker_rev_token))

            # Make sure tokens up to and including the marker are all the same
            if line1_tokens[:marker_index+1] != line2_tokens[:marker_index+1]:
                return False

            # Make sure reversal of rest of string is equal to identity
            line1_tokens_rev = line1_tokens[marker_index+1:].copy()
            line1_tokens_rev.reverse()
            if line1_tokens_rev != line2_tokens[marker_index+1:]:
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]
# Yj: combination tests over the reversal perturbation pairs

test_data = itertools.product(
    ["100M", "dev", "test_affected"], GENRES.keys(), perturbation_pairs_reversal)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def lines_equivalent_determiner_swap(det_path, ident_path):
    """Compare lines of determiner-swap file and identity file after splitting them."""
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting token sets
            line1_tokens = set(line1.split())
            line2_tokens = set(line2.split())
            if line1_tokens != line2_tokens:
                print(line1.split())
                print(line2.split())
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_determiner_swap)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def flatten_list(l):
    """Function to flatten a nested list."""
    return list(itertools.chain.from_iterable(l))


def process_line(line):
    """
    Process a given line from the dataset, apply transformations to its sentences,
    and categorize them into affected or unaffected based on the transformation.

    Parameters:
    - line (dict): A dictionary representing a line from the dataset, which contains
      sentence annotations.

    Returns:
    - tuple: A tuple containing three lists:
        1. new_lines_affected (list of str): Sentences that were affected by the transformation.
        2. new_lines_unaffected (list of str): Sentences that were not affected by the transformation.
        3. sents_unaffected (list of str): The original text of the unaffected sentences.

    Note:
    - The transformation functions (`perturbation_function`, `affect_function`, `filter_function`)
      are expected to be available in the global scope.
    """

    new_lines_affected = []
    new_lines_unaffected = []
    sents_unaffected = []

    # Apply transformation to each sentence on line
    for sent in line["sent_annotations"]:  # Yj: unclear why the annotations are used here rather than the raw text?

        tokens = perturbation_function(sent)
        if len([tok for tok in tokens if tok not in MARKER_TOKEN_IDS]) <= 1:
            continue

        token_line = " ".join([str(tok) for tok in tokens])

        # Check if sent is affected
        if affect_function(sent):

            # Check if this affected sentence should be filtered or not
            if filter_function(sent):
                new_lines_affected.append(token_line + "\n")

        else:  # Unaffected sentences
            new_lines_unaffected.append(token_line + "\n")
            sents_unaffected.append(sent["sent_text"] + "\n")

    return new_lines_affected, new_lines_unaffected, sents_unaffected


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    gpt2_tokenizer = PERTURBATIONS[args.perturbation_type]['gpt2_tokenizer']

    if babylm_dataset == "test":  # Yj: why is babylm_dataset "test" here? BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            f = open(file)
            data = json.load(f)
            f.close()

            # Perturb data iteratively
            results = []
            for line in tqdm.tqdm(data):
                results.append(process_line(line))

            new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                *results)
            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)
            unaffected_sents = flatten_list(unaffected_sents)

            # Name new file
            new_file_affected = os.path.basename(
                file).replace(json_ext, "_affected.test")
            new_file_unaffected = os.path.basename(
                file).replace(json_ext, "_unaffected.test")
            file_unaffected_sents = os.path.basename(
                file).replace(json_ext, "_unaffected_sents.test")

            # Create directory
            data_write_directory = f"{BABYLM_DATA_PATH}/babylm_data_perturbed"
            directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
            if not os.path.exists(directory_affected):
                os.makedirs(directory_affected)
            directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
            if not os.path.exists(directory_unaffected):
                os.makedirs(directory_unaffected)
            directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
            if not os.path.exists(directory_unaffected_sents):
                os.makedirs(directory_unaffected_sents)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

    else:
        # Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            f = open(file)
            data = json.load(f)
            f.close()

            # Perturb data iteratively
            results = []
            for line in tqdm.tqdm(data):
                results.append(process_line(line))

            new_lines_affected, new_lines_unaffected, _ = zip(
                *results)

            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = os.path.basename(file).replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = os.path.basename(file).replace(json_ext, ".test")

                # Print strings for unittest
                new_lines_decoded = [gpt2_tokenizer.decode(
                    [int(tok) for tok in line.split()]) + "\n" for line in new_lines]
                new_lines_with_strings = []
                for tokens, line in list(zip(new_lines, new_lines_decoded)):
                    new_lines_with_strings.append(tokens)
                    new_lines_with_strings.append(line)
                new_lines = new_lines_with_strings

            else:
                new_file = os.path.basename(file).replace(json_ext, ".train")  # '10M'/'100M' are the training splits

            # Create directory and write file
            directory = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            if not os.path.exists(directory):
                os.makedirs(directory)
            write_file(directory, new_file, new_lines)
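
A minimal sketch of the PERTURBATIONS entry shape that perturb.py assumes (hypothetical helper and field names; the real definitions live in utils.py and its model-specific variants, and are not part of this commit):

# Hypothetical sketch only; perturb.py reads exactly these four keys from each PERTURBATIONS[name] dict.
from transformers import GPT2TokenizerFast

gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def reverse_full_perturbation(sent):
    # Placeholder: a real perturbation returns the (possibly reordered) token ids of one
    # sentence annotation; the "sent_tokens" field name used here is an assumption.
    return list(reversed(sent["sent_tokens"]))

PERTURBATIONS = {
    "reverse_full": {
        "perturbation_function": reverse_full_perturbation,  # sentence annotation -> list of token ids
        "affect_function": lambda sent: True,                # is this sentence changed by the perturbation?
        "filter_function": lambda sent: True,                # keep this affected sentence in test_affected?
        "gpt2_tokenizer": gpt2_tokenizer,                    # decodes token ids for the unittest split
    },
}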
data/perturb.sh
ADDED
@@ -0,0 +1,35 @@
#!/bin/sh
# perturb.sh
# author: Julie Kallini

echo "
-------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------
"
echo "Perturbation type: $1"
echo "Train set: $2"


# Create perturbed dataset for all splits
echo "
-------------------------------------------------------------------------------
Creating perturbed dataset for all splits
-------------------------------------------------------------------------------
"

cd ../data

echo "python3 perturb.py $1 $2"
python3 perturb.py $1 $2
echo "
python3 perturb.py $1 dev"
python3 perturb.py $1 dev
echo "
python3 perturb.py $1 test"
python3 perturb.py $1 test
echo "
python3 perturb.py $1 unittest"
python3 perturb.py $1 unittest

cd ..
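
Usage note (not part of the committed file): perturb.sh takes the perturbation name and the training split as positional arguments and then regenerates the dev, test, and unittest splits for the same perturbation, so a typical invocation would be something like `sh perturb.sh reverse_full 100M`, run from a directory that sits next to data/ (the script begins with `cd ../data`).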
data/perturb_llama.py
ADDED
@@ -0,0 +1,361 @@
# perturb_llama.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from glob import glob
import numpy as np
import itertools
import json
import os
import tqdm
import argparse
import pytest

MODEL_NAME = "Llama-3.2-3B"

def lines_equivalent_3pres(file1_path, file2_path):
    """Compare lines of two files after splitting them."""
    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            res1 = [i for i in line1.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            res2 = [i for i in line2.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Yj: combination tests over the perturbation pairs related to third-person singular/plural agreement

test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_3pres)  # Yj: generate the parameter combinations used in the tests

# Yj: used by test functions such as test_3pres_all_equivalent to generate the test combinations, covering the different perturbation strategies.
# Yj: distinguishes the affected and unaffected test subsets so the effect of the perturbation can be compared.


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)  # the test function runs once for every parameter set in test_data
def test_3pres_all_equivalent(split, genre, perturbation_pair):  # Yj: the genres are the different corpora, listed in utils.py

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"  # Yj: the development set is similar to a validation set

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    # Yj: compare the two files at the two paths
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"


def lines_equivalent_reversal(rev_path, ident_path):
    """Compare lines of reversal file and identity file after splitting them."""
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Get REV marker index
            marker_index = line1_tokens.index(str(marker_rev_token))

            # Make sure tokens up to and including the marker are all the same
            if line1_tokens[:marker_index+1] != line2_tokens[:marker_index+1]:
                return False

            # Make sure reversal of rest of string is equal to identity
            line1_tokens_rev = line1_tokens[marker_index+1:].copy()
            line1_tokens_rev.reverse()
            if line1_tokens_rev != line2_tokens[marker_index+1:]:
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]
# Yj: combination tests over the reversal perturbation pairs

test_data = itertools.product(
    ["100M", "dev", "test_affected"], GENRES.keys(), perturbation_pairs_reversal)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def lines_equivalent_determiner_swap(det_path, ident_path):
    """Compare lines of determiner-swap file and identity file after splitting them."""
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting token sets
            line1_tokens = set(line1.split())
            line2_tokens = set(line2.split())
            if line1_tokens != line2_tokens:
                print(line1.split())
                print(line2.split())
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_determiner_swap)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def flatten_list(l):
    """Function to flatten a nested list."""
    return list(itertools.chain.from_iterable(l))


def process_line(line):
    """
    Process a given line from the dataset, apply transformations to its sentences,
    and categorize them into affected or unaffected based on the transformation.

    Parameters:
    - line (dict): A dictionary representing a line from the dataset, which contains
      sentence annotations.

    Returns:
    - tuple: A tuple containing three lists:
        1. new_lines_affected (list of str): Sentences that were affected by the transformation.
        2. new_lines_unaffected (list of str): Sentences that were not affected by the transformation.
        3. sents_unaffected (list of str): The original text of the unaffected sentences.

    Note:
    - The transformation functions (`perturbation_function`, `affect_function`, `filter_function`)
      are expected to be available in the global scope.
    """

    new_lines_affected = []
    new_lines_unaffected = []
    sents_unaffected = []

    # Apply transformation to each sentence on line
    for sent in line["sent_annotations"]:  # Yj: unclear why the annotations are used here rather than the raw text?

        tokens = perturbation_function(sent)
        if len([tok for tok in tokens if tok not in MARKER_TOKEN_IDS]) <= 1:
            continue

        token_line = " ".join([str(tok) for tok in tokens])

        # Check if sent is affected
        if affect_function(sent):

            # Check if this affected sentence should be filtered or not
            if filter_function(sent):
                new_lines_affected.append(token_line + "\n")

        else:  # Unaffected sentences
            new_lines_unaffected.append(token_line + "\n")
            sents_unaffected.append(sent["sent_text"] + "\n")

    return new_lines_affected, new_lines_unaffected, sents_unaffected


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    llama_tokenizer = PERTURBATIONS[args.perturbation_type]['llama_tokenizer']

    if babylm_dataset == "test":  # Yj: why is babylm_dataset "test" here? BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            f = open(file)
            data = json.load(f)
            f.close()

            # Perturb data iteratively
            results = []
            for line in tqdm.tqdm(data):
                results.append(process_line(line))

            new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                *results)
            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)
            unaffected_sents = flatten_list(unaffected_sents)

            # Name new file
            new_file_affected = os.path.basename(
                file).replace(json_ext, "_affected.test")
            new_file_unaffected = os.path.basename(
                file).replace(json_ext, "_unaffected.test")
            file_unaffected_sents = os.path.basename(
                file).replace(json_ext, "_unaffected_sents.test")

            # Create directory
            data_write_directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}"
            directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
            if not os.path.exists(directory_affected):
                os.makedirs(directory_affected)
            directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
            if not os.path.exists(directory_unaffected):
                os.makedirs(directory_unaffected)
            directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
            if not os.path.exists(directory_unaffected_sents):
                os.makedirs(directory_unaffected_sents)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

    else:
        # Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            f = open(file)
            data = json.load(f)
            f.close()

            # Perturb data iteratively
            results = []
            for line in tqdm.tqdm(data):
                results.append(process_line(line))

            new_lines_affected, new_lines_unaffected, _ = zip(
                *results)

            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = os.path.basename(file).replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = os.path.basename(file).replace(json_ext, ".test")

                # Print strings for unittest
                new_lines_decoded = [llama_tokenizer.decode(
                    [int(tok) for tok in line.split()]) + "\n" for line in new_lines]
                new_lines_with_strings = []
                for tokens, line in list(zip(new_lines, new_lines_decoded)):
                    new_lines_with_strings.append(tokens)
                    new_lines_with_strings.append(line)
                new_lines = new_lines_with_strings

            else:
                new_file = os.path.basename(file).replace(json_ext, ".train")  # '10M'/'100M' are the training splits

            # Create directory and write file
            directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            print("directory:", directory)
            if not os.path.exists(directory):
                os.makedirs(directory)
            write_file(directory, new_file, new_lines)
data/perturb_model.sh
ADDED
@@ -0,0 +1,40 @@
#!/bin/bash

# Define your perturbations and BabyLM splits
PERTURBATIONS=("hop_control" "hop_tokens4" "hop_words4" "reverse_control" "reverse_partial" "reverse_full" "shuffle_control"
    "shuffle_nondeterministic" "shuffle_deterministic21" "shuffle_deterministic57" "shuffle_deterministic84" "shuffle_local3"
    "shuffle_local5" "shuffle_local10" "shuffle_even_odd")

# BABYLM_SPLITS=("100M" "10M" "dev" "test" "unittest") # Add more splits as needed
BABYLM_SPLITS=("dev")

# Specify the GPUs to use
SPECIFIED_GPUS=(1 2 3 4 5 6 7) # Set these to the GPUs you want to use

# Store PIDs and GPU mapping to track running processes
declare -A GPU_PROCESS_MAP

# Iterate over all combinations of perturbations and splits
for perturbation in "${PERTURBATIONS[@]}"; do
    for split in "${BABYLM_SPLITS[@]}"; do

        # Check for a free GPU
        while true; do
            for gpu in "${SPECIFIED_GPUS[@]}"; do
                # Check if there's no process associated with this GPU
                if ! ps -p ${GPU_PROCESS_MAP[$gpu]} > /dev/null 2>&1; then
                    # Run the Python perturbation script on the available GPU
                    CUDA_VISIBLE_DEVICES=$gpu python perturb_llama.py "$perturbation" "$split" &
                    GPU_PROCESS_MAP[$gpu]=$!
                    echo "Running on GPU $gpu: Perturbation=$perturbation, Split=$split, PID=$!"
                    break 2 # Break out of the loops once a GPU is assigned
                fi
            done
            sleep 1 # Wait a second before checking again
        done
    done
done

# Wait for all processes to finish
wait
echo "All tasks completed."
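
Design note (not part of the committed file): perturb_model.sh is a simple polling scheduler. It walks the perturbation/split grid, hands each job to the first GPU in SPECIFIED_GPUS whose last recorded PID has exited, and sleeps one second between checks. As written it always launches perturb_llama.py, so a Qwen run would need that command changed to perturb_qwen.py.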
data/perturb_qwen.py
ADDED
@@ -0,0 +1,361 @@
# perturb_qwen.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

from utils_qwen import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from glob import glob
import numpy as np
import itertools
import json
import os
import tqdm
import argparse
import pytest

MODEL_NAME = "Qwen2.5-7B"

def lines_equivalent_3pres(file1_path, file2_path):
    """Compare lines of two files after splitting them."""
    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            res1 = [i for i in line1.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            res2 = [i for i in line2.split() if int(
                i) not in (marker_sg_token, marker_pl_token)]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Yj: combination tests over the perturbation pairs related to third-person singular/plural agreement

test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_3pres)  # Yj: generate the parameter combinations used in the tests

# Yj: used by test functions such as test_3pres_all_equivalent to generate the test combinations, covering the different perturbation strategies.
# Yj: distinguishes the affected and unaffected test subsets so the effect of the perturbation can be compared.


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)  # the test function runs once for every parameter set in test_data
def test_3pres_all_equivalent(split, genre, perturbation_pair):  # Yj: the genres are the different corpora, listed in utils.py

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"  # Yj: the development set is similar to a validation set

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    # Yj: compare the two files at the two paths
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"


def lines_equivalent_reversal(rev_path, ident_path):
    """Compare lines of reversal file and identity file after splitting them."""
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting lists
            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Get REV marker index
            marker_index = line1_tokens.index(str(marker_rev_token))

            # Make sure tokens up to and including the marker are all the same
            if line1_tokens[:marker_index+1] != line2_tokens[:marker_index+1]:
                return False

            # Make sure reversal of rest of string is equal to identity
            line1_tokens_rev = line1_tokens[marker_index+1:].copy()
            line1_tokens_rev.reverse()
            if line1_tokens_rev != line2_tokens[marker_index+1:]:
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]
# Yj: combination tests over the reversal perturbation pairs

test_data = itertools.product(
    ["100M", "dev", "test_affected"], GENRES.keys(), perturbation_pairs_reversal)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def lines_equivalent_determiner_swap(det_path, ident_path):
    """Compare lines of determiner-swap file and identity file after splitting them."""
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting token sets
            line1_tokens = set(line1.split())
            line2_tokens = set(line2.split())
            if line1_tokens != line2_tokens:
                print(line1.split())
                print(line2.split())
                return False

        # Check if one file has more lines than the other
        if file1.readline() or file2.readline():
            return False

    return True


perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_determiner_swap)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):

    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def flatten_list(l):
    """Function to flatten a nested list."""
    return list(itertools.chain.from_iterable(l))


def process_line(line):
    """
    Process a given line from the dataset, apply transformations to its sentences,
    and categorize them into affected or unaffected based on the transformation.

    Parameters:
    - line (dict): A dictionary representing a line from the dataset, which contains
      sentence annotations.

    Returns:
    - tuple: A tuple containing three lists:
        1. new_lines_affected (list of str): Sentences that were affected by the transformation.
        2. new_lines_unaffected (list of str): Sentences that were not affected by the transformation.
        3. sents_unaffected (list of str): The original text of the unaffected sentences.

    Note:
    - The transformation functions (`perturbation_function`, `affect_function`, `filter_function`)
      are expected to be available in the global scope.
    """

    new_lines_affected = []
    new_lines_unaffected = []
    sents_unaffected = []

    # Apply transformation to each sentence on line
    for sent in line["sent_annotations"]:  # Yj: unclear why the annotations are used here rather than the raw text?

        tokens = perturbation_function(sent)
        if len([tok for tok in tokens if tok not in MARKER_TOKEN_IDS]) <= 1:
            continue

        token_line = " ".join([str(tok) for tok in tokens])

        # Check if sent is affected
        if affect_function(sent):

            # Check if this affected sentence should be filtered or not
            if filter_function(sent):
                new_lines_affected.append(token_line + "\n")

        else:  # Unaffected sentences
            new_lines_unaffected.append(token_line + "\n")
            sents_unaffected.append(sent["sent_text"] + "\n")

    return new_lines_affected, new_lines_unaffected, sents_unaffected


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    qwen_tokenizer = PERTURBATIONS[args.perturbation_type]['qwen_tokenizer']

    if babylm_dataset == "test":  # Yj: why is babylm_dataset "test" here? BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            f = open(file)
            data = json.load(f)
            f.close()

            # Perturb data iteratively
            results = []
            for line in tqdm.tqdm(data):
                results.append(process_line(line))

            new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                *results)
            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)
            unaffected_sents = flatten_list(unaffected_sents)

            # Name new file
            new_file_affected = os.path.basename(
|
289 |
+
file).replace(json_ext, "_affected.test")
|
290 |
+
new_file_unaffected = os.path.basename(
|
291 |
+
file).replace(json_ext, "_unaffected.test")
|
292 |
+
file_unaffected_sents = os.path.basename(
|
293 |
+
file).replace(json_ext, "_unaffected_sents.test")
|
294 |
+
|
295 |
+
# Create directory
|
296 |
+
data_write_directory = f"{BABYLM_DATA_PATH}/Qwen_perturbed_data/{MODEL_NAME}"
|
297 |
+
directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
|
298 |
+
if not os.path.exists(directory_affected):
|
299 |
+
os.makedirs(directory_affected)
|
300 |
+
directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
|
301 |
+
if not os.path.exists(directory_unaffected):
|
302 |
+
os.makedirs(directory_unaffected)
|
303 |
+
directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
|
304 |
+
if not os.path.exists(directory_unaffected_sents):
|
305 |
+
os.makedirs(directory_unaffected_sents)
|
306 |
+
|
307 |
+
# Write files
|
308 |
+
write_file(directory_affected,
|
309 |
+
new_file_affected, new_lines_affected)
|
310 |
+
write_file(directory_unaffected,
|
311 |
+
new_file_unaffected, new_lines_unaffected)
|
312 |
+
write_file(directory_unaffected_sents,
|
313 |
+
file_unaffected_sents, unaffected_sents)
|
314 |
+
|
315 |
+
else:
|
316 |
+
# Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
|
317 |
+
# Iterate over files and do transform
|
318 |
+
for file in babylm_data:
|
319 |
+
print(file)
|
320 |
+
f = open(file)
|
321 |
+
data = json.load(f)
|
322 |
+
f.close()
|
323 |
+
|
324 |
+
# Perturb data iteratively
|
325 |
+
results = []
|
326 |
+
for line in tqdm.tqdm(data):
|
327 |
+
results.append(process_line(line))
|
328 |
+
|
329 |
+
new_lines_affected, new_lines_unaffected, _ = zip(
|
330 |
+
*results)
|
331 |
+
|
332 |
+
new_lines_affected = flatten_list(new_lines_affected)
|
333 |
+
new_lines_unaffected = flatten_list(new_lines_unaffected)
|
334 |
+
|
335 |
+
# Combine affected and unaffected sentences
|
336 |
+
new_lines = new_lines_unaffected + new_lines_affected
|
337 |
+
|
338 |
+
# Name new file
|
339 |
+
if babylm_dataset == "dev":
|
340 |
+
new_file = os.path.basename(file).replace(json_ext, ".dev")
|
341 |
+
elif babylm_dataset == 'unittest':
|
342 |
+
new_file = os.path.basename(file).replace(json_ext, ".test")
|
343 |
+
|
344 |
+
# Print strings for unittest
|
345 |
+
new_lines_decoded = [qwen_tokenizer.decode(
|
346 |
+
[int(tok) for tok in line.split()]) + "\n" for line in new_lines]
|
347 |
+
new_lines_with_strings = []
|
348 |
+
for tokens, line in list(zip(new_lines, new_lines_decoded)):
|
349 |
+
new_lines_with_strings.append(tokens)
|
350 |
+
new_lines_with_strings.append(line)
|
351 |
+
new_lines = new_lines_with_strings
|
352 |
+
|
353 |
+
else:
|
354 |
+
new_file = os.path.basename(file).replace(json_ext, ".train") # '10M 100M' is training set
|
355 |
+
|
356 |
+
# Create directory and write file
|
357 |
+
directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
|
358 |
+
print("directory:", directory)
|
359 |
+
if not os.path.exists(directory):
|
360 |
+
os.makedirs(directory)
|
361 |
+
write_file(directory, new_file, new_lines)
|
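Editor's note: a minimal, standalone sketch of the set-based comparison that lines_equivalent_determiner_swap relies on. The token-ID strings below are invented purely for illustration; the point is that swapping two determiners within a line leaves the set of whitespace-separated tokens unchanged even though their order differs.

    # Hypothetical token-ID lines (made-up IDs): the determiner-swap output moves
    # determiners around, but the set of tokens per line should stay the same.
    det_line = "1996 3899 2003 1037 4937"    # e.g. "the dog is a cat"
    ident_line = "1037 3899 2003 1996 4937"  # determiners swapped: "a dog is the cat"

    assert set(det_line.split()) == set(ident_line.split())  # same token sets -> equivalent
    assert det_line.split() != ident_line.split()            # order differs, lists do not match
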
data/tag.py
ADDED
@@ -0,0 +1,153 @@
# tag.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json


test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = list(zip(test_original_files, test_json_files))


@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):

    # Read lines of file and remove all whitespace
    original_file = open(original_file)
    original_data = "".join(original_file.readlines())
    original_data = "".join(original_data.split())

    json_file = open(json_file)
    json_lines = json.load(json_file)
    json_data = ""
    for line in json_lines:
        for sent in line["sent_annotations"]:
            json_data += sent["sent_text"]
    json_data = "".join(json_data.split())

    # Test equivalence
    assert (original_data == json_data)


def __get_constituency_parse(sent, nlp):

    # Try parsing the doc
    try:
        parse_doc = nlp(sent.text)
    except:
        return None

    # Get set of constituency parse trees
    parse_trees = [str(sent.constituency) for sent in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize, pos, lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza parser
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    # BATCH_SIZE = 5000
    BATCH_SIZE = 100

    # Iterate over BabyLM files
    for file in args.path:

        print(file.name)
        lines = file.readlines()

        # Strip lines and join text
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over lines in file and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):
            # Tokenize text with stanza
            doc = nlp1(text)

            # Iterate over sents in the line and track annotations
            sent_annotations = []
            for sent in doc.sentences:

                # Iterate over words in sent and track annotations
                word_annotations = []
                for token, word in zip(sent.tokens, sent.words):
                    wa = {
                        'id': word.id,
                        'text': word.text,
                        'lemma': word.lemma,
                        'upos': word.upos,
                        'xpos': word.xpos,
                        'feats': word.feats,
                        'start_char': token.start_char,
                        'end_char': token.end_char
                    }
                    word_annotations.append(wa)  # Track word annotation

                # Get constituency parse if needed
                if args.parse:
                    constituency_parse = __get_constituency_parse(sent, nlp2)
                    sa = {
                        'sent_text': sent.text,
                        'constituency_parse': constituency_parse,
                        'word_annotations': word_annotations,
                    }
                else:
                    sa = {
                        'sent_text': sent.text,
                        'word_annotations': word_annotations,
                    }
                sent_annotations.append(sa)  # Track sent annotation

            la = {
                'sent_annotations': sent_annotations
            }
            line_annotations.append(la)  # Track line annotation

        # Write annotations to file as a JSON
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)

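Editor's note: a short sketch of how the JSON that tag.py writes can be read back. The structure (one entry per line, each with sent_annotations holding sent_text, word_annotations, and an optional constituency_parse) comes straight from the code above; the file path and printed values are hypothetical.

    import json

    # Hypothetical path; real outputs sit next to the source text and end in ".json" or "_parsed.json"
    with open("babylm_data/babylm_dev/example_parsed.json") as f:
        line_annotations = json.load(f)

    first_sent = line_annotations[0]["sent_annotations"][0]
    print(first_sent["sent_text"])                    # raw sentence text
    print(first_sent["word_annotations"][0]["upos"])  # e.g. "DET"
    print(first_sent.get("constituency_parse"))       # "(ROOT (S ...))" when tagged with --parse
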
data/tag_1.py
ADDED
@@ -0,0 +1,166 @@
# tag.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json
from transformers import AutoTokenizer

# Define the function to chunk text
def chunk_text(text, tokenizer, max_length=512):
    tokens = tokenizer(text)['input_ids']
    chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

# Test case for checking equivalence of original and parsed files
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = list(zip(test_original_files, test_json_files))

@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):

    # Read lines of file and remove all whitespace
    original_file = open(original_file)
    original_data = "".join(original_file.readlines())
    original_data = "".join(original_data.split())

    json_file = open(json_file)
    json_lines = json.load(json_file)
    json_data = ""
    for line in json_lines:
        for sent in line["sent_annotations"]:
            json_data += sent["sent_text"]
    json_data = "".join(json_data.split())

    # Test equivalence
    assert (original_data == json_data)

# Constituency parsing function
def __get_constituency_parse(sent, nlp):

    # Try parsing the doc
    try:
        parse_doc = nlp(sent.text)
    except:
        return None

    # Get set of constituency parse trees
    parse_trees = [str(sent.constituency) for sent in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse

# Main function
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize, pos, lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza parser
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    BATCH_SIZE = 100

    # Tokenizer for splitting long text
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Iterate over BabyLM files
    for file in args.path:

        print(file.name)
        lines = file.readlines()

        # Strip lines and join text
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over lines in file and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):
            # Split the text into chunks if it exceeds the max length
            text_chunks = chunk_text(text, tokenizer)

            # Iterate over each chunk
            for chunk in text_chunks:
                # Tokenize text with stanza
                doc = nlp1(chunk)

                # Iterate over sentences in the line and track annotations
                sent_annotations = []
                for sent in doc.sentences:

                    # Iterate over words in the sentence and track annotations
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)  # Track word annotation

                    # Get constituency parse if needed
                    if args.parse:
                        constituency_parse = __get_constituency_parse(sent, nlp2)
                        sa = {
                            'sent_text': sent.text,
                            'constituency_parse': constituency_parse,
                            'word_annotations': word_annotations,
                        }
                    else:
                        sa = {
                            'sent_text': sent.text,
                            'word_annotations': word_annotations,
                        }
                    sent_annotations.append(sa)  # Track sent annotation

                la = {
                    'sent_annotations': sent_annotations
                }
                line_annotations.append(la)  # Track line annotation

        # Write annotations to file as a JSON
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)

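Editor's note: the chunk_text helper above is what tag_1.py adds over tag.py: it caps each Stanza input at 512 subword tokens of the bert-base-uncased tokenizer. A small usage sketch (the helper is copied here for self-containment; the sample text is invented):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def chunk_text(text, tokenizer, max_length=512):
        tokens = tokenizer(text)['input_ids']
        chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
        return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

    long_text = " ".join(["the cat sat on the mat"] * 200)  # invented sample text, well over 512 subwords
    chunks = chunk_text(long_text, tokenizer)
    print(len(chunks))  # several chunks, each decoded from at most 512 subword IDs
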
data/tag_distributed.py
ADDED
@@ -0,0 +1,106 @@
# Files can be processed on different GPUs; each file is handled by one GPU.
import torch
import torch.distributed as dist
import sys
sys.path.append("..")

import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json
from transformers import AutoTokenizer

def chunk_text(text, tokenizer, max_length=512):
    tokens = tokenizer(text)['input_ids']
    chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

def init_distributed_mode():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # pin this process to the GPU matching its rank
    return rank

def run_on_gpu(rank, args, tokenizer, nlp1, nlp2):
    print(f"Running on Rank {rank}, using GPU {torch.cuda.current_device()}")
    print(f"Rank {rank}, GPU {torch.cuda.current_device()} started")
    files_per_gpu = len(args.path) // dist.get_world_size()
    start_idx = rank * files_per_gpu
    end_idx = start_idx + files_per_gpu if rank != dist.get_world_size() - 1 else len(args.path)
    gpu_files = args.path[start_idx:end_idx]

    for file in gpu_files:
        print(f"GPU {rank}: Processing {file.name}")
        lines = file.readlines()

        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE] for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        line_annotations = []
        for text in tqdm.tqdm(text_batches, desc=f"GPU {rank}"):
            text_chunks = chunk_text(text, tokenizer)
            for chunk in text_chunks:
                doc = nlp1(chunk)
                sent_annotations = []
                for sent in doc.sentences:
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)

                    sa = {
                        'sent_text': sent.text,
                        'word_annotations': word_annotations
                    }
                    if args.parse:
                        sa['constituency_parse'] = __get_constituency_parse(sent, nlp2)

                    sent_annotations.append(sa)
                line_annotations.append({'sent_annotations': sent_annotations})

        # Parentheses are needed so only the extension (not the whole filename) is conditional
        json_filename = os.path.splitext(file.name)[0] + ('_parsed.json' if args.parse else '.json')
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)

def __get_constituency_parse(sent, nlp):
    try:
        parse_doc = nlp(sent.text)
    except:
        return None
    parse_trees = [str(sent.constituency) for sent in parse_doc.sentences]
    return "(ROOT " + " ".join(parse_trees) + ")"

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")
    args = parser.parse_args()

    rank = init_distributed_mode()

    BATCH_SIZE = 1000
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    nlp1 = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma', package="default_accurate", use_gpu=True)

    nlp2 = None
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en', processors='tokenize,pos,constituency', package="default_accurate", use_gpu=True)

    run_on_gpu(rank, args, tokenizer, nlp1, nlp2)

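Editor's note: run_on_gpu assigns every rank len(path) // world_size files and hands any remainder to the last rank. A self-contained sketch of that arithmetic, with invented file names and a 3-process world size:

    files = ["f0.txt", "f1.txt", "f2.txt", "f3.txt", "f4.txt", "f5.txt", "f6.txt"]  # invented
    world_size = 3  # invented

    for rank in range(world_size):
        files_per_gpu = len(files) // world_size
        start_idx = rank * files_per_gpu
        end_idx = start_idx + files_per_gpu if rank != world_size - 1 else len(files)
        print(rank, files[start_idx:end_idx])
    # rank 0 -> ['f0.txt', 'f1.txt'], rank 1 -> ['f2.txt', 'f3.txt'], rank 2 -> ['f4.txt', 'f5.txt', 'f6.txt']
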
data/tag_single.py
ADDED
@@ -0,0 +1,140 @@
# A single file can be split into smaller shards and processed on different GPUs.
import torch
import torch.distributed as dist
import sys
sys.path.append("..")

import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json
from transformers import AutoTokenizer

def chunk_text(text, tokenizer, max_length=512):
    tokens = tokenizer(text)['input_ids']
    chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

def init_distributed_mode():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # pin this process to the GPU matching its rank
    return rank

def process_single_file(file, rank, tokenizer, nlp1, nlp2):
    print(f"GPU {rank}: Processing {file.name}")
    lines = file.readlines()

    # Split the work across GPUs by line count
    num_lines = len(lines)
    num_gpus = dist.get_world_size()

    lines_per_gpu = (num_lines + num_gpus - 1) // num_gpus
    start_idx = rank * lines_per_gpu
    end_idx = min(start_idx + lines_per_gpu, num_lines)
    gpu_lines = lines[start_idx:end_idx]

    line_batches = [gpu_lines[i:i + BATCH_SIZE] for i in range(0, len(gpu_lines), BATCH_SIZE)]
    text_batches = [" ".join(l) for l in line_batches]

    line_annotations = []
    for text in tqdm.tqdm(text_batches, desc=f"GPU {rank}"):
        text_chunks = chunk_text(text, tokenizer)
        for chunk in text_chunks:
            doc = nlp1(chunk)
            sent_annotations = []
            for sent in doc.sentences:
                word_annotations = []
                for token, word in zip(sent.tokens, sent.words):
                    wa = {
                        'id': word.id,
                        'text': word.text,
                        'lemma': word.lemma,
                        'upos': word.upos,
                        'xpos': word.xpos,
                        'feats': word.feats,
                        'start_char': token.start_char,
                        'end_char': token.end_char
                    }
                    word_annotations.append(wa)

                sa = {
                    'sent_text': sent.text,
                    'word_annotations': word_annotations
                }
                if args.parse:
                    sa['constituency_parse'] = __get_constituency_parse(sent, nlp2)

                sent_annotations.append(sa)
            line_annotations.append({'sent_annotations': sent_annotations})

    # Temporarily store this GPU's output
    temp_filename = os.path.splitext(file.name)[0] + f'_rank{rank}.json'
    with open(temp_filename, "w") as outfile:
        json.dump(line_annotations, outfile, indent=4)

    return temp_filename

def merge_files(temp_files, output_file):
    merged_data = []
    for file in temp_files:
        with open(file, "r") as infile:
            data = json.load(infile)
        merged_data.extend(data)
        os.remove(file)  # delete the temporary file

    with open(output_file, "w") as outfile:
        json.dump(merged_data, outfile, indent=4)

def run_on_gpu(rank, args, tokenizer, nlp1, nlp2):
    print(f"Running on Rank {rank}, using GPU {torch.cuda.current_device()}")

    temp_files = []
    if len(args.path) == 1:
        temp_files.append(process_single_file(args.path[0], rank, tokenizer, nlp1, nlp2))
        dist.barrier()  # wait for all processes to finish their shard
        if rank == 0:
            # Merge the per-rank files. Their names follow the deterministic
            # "<base>_rank{r}.json" pattern, so rank 0 reconstructs the full list
            # rather than merging only its own temp file.
            base = os.path.splitext(args.path[0].name)[0]
            all_temp_files = [f"{base}_rank{r}.json" for r in range(dist.get_world_size())]
            final_output = base + '_merged.json'
            merge_files(all_temp_files, final_output)
    else:
        files_per_gpu = len(args.path) // dist.get_world_size()
        start_idx = rank * files_per_gpu
        end_idx = start_idx + files_per_gpu if rank != dist.get_world_size() - 1 else len(args.path)
        gpu_files = args.path[start_idx:end_idx]

        for file in gpu_files:
            process_single_file(file, rank, tokenizer, nlp1, nlp2)

def __get_constituency_parse(sent, nlp):
    try:
        parse_doc = nlp(sent.text)
    except:
        return None
    parse_trees = [str(sent.constituency) for sent in parse_doc.sentences]
    return "(ROOT " + " ".join(parse_trees) + ")"

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")
    args = parser.parse_args()

    rank = init_distributed_mode()

    BATCH_SIZE = 1000
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    nlp1 = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma', package="default_accurate", use_gpu=True)

    nlp2 = None
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en', processors='tokenize,pos,constituency', package="default_accurate", use_gpu=True)

    run_on_gpu(rank, args, tokenizer, nlp1, nlp2)

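Editor's note: a minimal sketch of the per-rank shard naming and merge step used by tag_single.py, with invented shard contents and filenames. It mirrors merge_files: the per-rank JSON lists are concatenated into one file and the temporary shards are removed.

    import json, os

    # Invented shards following the script's "<base>_rank{r}.json" pattern
    shards = {"corpus_rank0.json": [{"sent_annotations": []}],
              "corpus_rank1.json": [{"sent_annotations": []}]}
    for name, data in shards.items():
        with open(name, "w") as f:
            json.dump(data, f)

    merged = []
    for name in ["corpus_rank0.json", "corpus_rank1.json"]:
        with open(name) as f:
            merged.extend(json.load(f))
        os.remove(name)  # temp shards are deleted after merging, as in merge_files

    with open("corpus_merged.json", "w") as f:
        json.dump(merged, f, indent=4)
    print(len(merged))  # 2 line annotations in the merged output
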
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1000.csv
ADDED
@@ -0,0 +1,2 @@
Perplexity
12239897.0
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-10000.csv
ADDED
@@ -0,0 +1,2 @@
Perplexity
100086656.0
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-11500.csv
ADDED
@@ -0,0 +1,2 @@
Perplexity
86934072.0
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1500.csv
ADDED
@@ -0,0 +1,2 @@
Perplexity
221982080.0
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-2000.csv
ADDED
@@ -0,0 +1,2 @@
Perplexity
389647168.0