# utils_qwen.py
# Author: Yaning
# Adapted from utils.py by Julie Kallini
from collections import deque
from string import punctuation
from transformers import AutoTokenizer, AddedToken
from functools import partial
from numpy.random import default_rng
from nltk.tree import ParentedTree
import torch

##############################################################################
# CONSTANTS
##############################################################################

BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
# Yj: used to select the dataset split during argument parsing and data
# loading; affects dataset preprocessing, e.g. generating the train, dev,
# test, and unittest sets.
SEEDS = [21, 57, 84]
CHECKPOINTS = list(range(50, 501, 50))
GENRES = {
    "aochildes": "CHILDES",
    "bnc_spoken": "British National Corpus (BNC)",
    "cbt": "Children’s Book Test",
    "children_stories": "Children’s Stories Text Corpus",
    "gutenberg": "Standardized Project Gutenberg Corpus",
    "open_subtitles": "OpenSubtitles",
    "qed": "QCRI Educational Domain Corpus",
    "simple_wikipedia": "Simple Wikipedia",
    "switchboard": "Switchboard Dialog Act Corpus",
    "wikipedia": "Wikipedia"
}
CHECKPOINT_WRITE_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_models"
CHECKPOINT_READ_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_models"
# BABYLM_DATA_PATH = "/nlp/scr3/nlp/llms-in-llms/babylm_data"
BABYLM_DATA_PATH = "."
MARKER_HOP_SING = "🅂"
MARKER_HOP_PLUR = "🄿"
MARKER_REV = "🅁"
BOS_TOKEN = "<BOS_TOKEN>"
PART_TOKENS = set(["n't", "'ll", "'s", "'re", "'ve", "'m"])
PUNCT_TOKENS = set(punctuation)
MODEL_NAME = "gpt2"

##############################################################################
# PARENS MODELS (Structurally-pretrained)
##############################################################################

PAREN_MODEL_PATH = "/u/scr/isabelvp//tilt-stuff/tilt-finetuning/pretrained_checkpoints/"
PAREN_MODELS = {
    "CROSS": "flat-parens_vocab500-uniform_deplength-nesting-nolimit",
    "NEST": "nested-parens0.49_vocab500-uniform",
    "RAND": "random_vocab500-uniform",
}
##############################################################################
# HELPER FUNCTIONS
##############################################################################

def write_file(directory, filename, lines):
    with open(directory + filename, "w") as f:
        f.writelines(lines)


def get_qwen_tokenizer_with_markers(marker_list):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # If no new markers to add, return normal tokenizer
    if len(marker_list) == 0:
        return tokenizer

    # Create tokens and return modified tokenizer
    new_tokens = []
    for marker in marker_list:
        new_tokens.append(AddedToken(marker, lstrip=True, rstrip=False))
    tokenizer.add_tokens(new_tokens)
    return tokenizer


qwen_original_tokenizer = get_qwen_tokenizer_with_markers([])

# Qwen hop tokenization
qwen_hop_tokenizer = get_qwen_tokenizer_with_markers(
    [MARKER_HOP_SING, MARKER_HOP_PLUR])

# Get ids of marker tokens
marker_sg_token = qwen_hop_tokenizer.get_added_vocab()[
    MARKER_HOP_SING]
# Yj: get_added_vocab() maps every custom added token to its token ID
marker_pl_token = qwen_hop_tokenizer.get_added_vocab()[
    MARKER_HOP_PLUR]

# Qwen reverse tokenization
qwen_rev_tokenizer = get_qwen_tokenizer_with_markers(
    [MARKER_REV])

# Get id of marker token
marker_rev_token = qwen_rev_tokenizer.get_added_vocab()[
    MARKER_REV]

# Qwen determiner tokenization
qwen_det_tokenizer = get_qwen_tokenizer_with_markers(
    [BOS_TOKEN])

# Get id of BOS token
bos_token_id = qwen_det_tokenizer.get_added_vocab()[BOS_TOKEN]

MARKER_TOKEN_IDS = [marker_sg_token, marker_pl_token, marker_rev_token]
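# Usage note (sketch, not part of the original pipeline): if one of the
# tokenizers with added markers is paired with a pretrained causal LM, the
# model's embedding matrix must be grown to cover the new token ids, e.g.:
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
#   model.resize_token_embeddings(len(qwen_hop_tokenizer))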
def compute_surprisals(model, input_ids):
    # Get the logits from the model
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits[:, :-1]
    shifted_input_ids = input_ids[:, 1:]

    # Get the log probabilities (base 2) for the actual next tokens
    log_probs = torch.log2(torch.nn.functional.softmax(logits, dim=-1))
    true_log_probs = log_probs.gather(
        2, shifted_input_ids.unsqueeze(-1)).squeeze(-1)

    # Get the negative log probabilities (surprisals)
    neg_log_probs = (-true_log_probs).tolist()
    surprisals = [[None] + probs for probs in neg_log_probs]
    return surprisals
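
# Example (sketch): surprisal values align position-by-position with the input
# tokens, with None in first position since the model makes no prediction for
# the first token. The names below (`model`, `tok`) are placeholders, assuming
# a causal LM and one of the tokenizers defined above:
#
#   input_ids = torch.tensor([tok.encode("the cat sits")])
#   surprisals = compute_surprisals(model, input_ids)
#   # -> [[None, s_1, s_2, ...]], surprisal in bits (log base 2)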

def compute_token_probabilities(model, input_ids, token_id, pad_token_id):
    # Get the next-token probabilities from the model
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits[:, :-1]
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # Get the probabilities for the specified token at each position
    token_probs = probs[:, :, token_id]

    # Convert to list and add None at the beginning to align with input tokens;
    # put a null probability at positions whose target is the pad token
    token_probs_list = []
    for batch_i, probs in enumerate(token_probs):
        input_ids_seq = input_ids[batch_i].tolist() + [pad_token_id]
        filtered = [p if input_ids_seq[pos_i+1] !=
                    pad_token_id else None for pos_i, p in enumerate(probs.tolist())]
        token_probs_list.append([None] + filtered)
    return token_probs_list
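
# Example (sketch): to track how much probability mass the model assigns to the
# singular hop marker at every position of a (possibly padded) batch, assuming
# `model` is a causal LM with resized embeddings and `pad_id` is the pad token
# id used when batching (both hypothetical names):
#
#   probs = compute_token_probabilities(model, input_ids, marker_sg_token, pad_id)
#   # -> one list per sequence: [None, p_1, p_2, ...], with None at pad positions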

def merge_part_tokens(words):
    result = []
    for s in words:
        if result and s in PART_TOKENS:
            result[-1] += s
        else:
            result.append(s)
    return result
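
# Example (illustrative): merge_part_tokens(["She", "does", "n't", "care"])
# -> ["She", "doesn't", "care"]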

def __affect_hop_word(word):
    return word["feats"] and "Person=3" in word["feats"] \
        and "Tense=Pres" in word["feats"] \
        and "VerbForm=Fin" in word["feats"] \
        and "Number" in word["feats"]

def __perturb_hop_words(sent, num_hops, marker_sg, marker_pl):
    perturbed_tokens, _ = __perturb_hop_words_complete_hops(
        sent, num_hops, marker_sg, marker_pl)
    return perturbed_tokens


def check_word_hops_completed(sent, num_hops=4, marker=MARKER_HOP_SING):
    _, hops_completed = __perturb_hop_words_complete_hops(
        sent, num_hops, marker, marker)
    return hops_completed

def __perturb_hop_words_complete_hops(sent, num_hops, marker_sg, marker_pl):
    word_annotations = sent["word_annotations"].copy()
    word_annotations.reverse()

    hop_completed = []
    new_sent = []
    for word in word_annotations:
        # Identify 3.pres verbs
        if __affect_hop_word(word):
            # Lemmatize verb if possible
            new_sent.append(
                word["lemma"] if word["lemma"] is not None else word["text"])

            # Marker hopping logic
            insert_index = len(new_sent)-1
            skipped_words = 0
            while skipped_words < num_hops and insert_index > 0:
                # Handle edge case when punctuation (or a sequence of
                # punctuation) begins the sentence
                if (not any([c.isalnum() for c in
                             "".join(new_sent[:insert_index])])):
                    break
                # Yj: i.e. the prefix contains no alphanumeric characters
                # (only punctuation, whitespace, etc.)

                # Count word as skipped if it is not a special token
                if (new_sent[insert_index] not in PART_TOKENS) and \
                        (not set(new_sent[insert_index]).issubset(PUNCT_TOKENS)):
                    skipped_words += 1
                insert_index -= 1

            # Handle edge case when insert index is punctuation (and this is not
            # sentence-initial punctuation)
            if any([c.isalnum() for c in
                    "".join(new_sent[:insert_index])]):
                while insert_index != 0 and (new_sent[insert_index] in PART_TOKENS
                                             or set(new_sent[insert_index]).issubset(PUNCT_TOKENS)):
                    insert_index -= 1

            # Handle edge case when token before insert index is part/aux token
            if insert_index != 0 and new_sent[insert_index-1] in PART_TOKENS:
                insert_index -= 1

            # Log if this sentence had all full hops
            hop_completed.append(skipped_words == num_hops)

            # Use correct marker for singular vs. plural
            if "Number=Sing" in word["feats"]:
                new_sent.insert(insert_index, marker_sg)
            elif "Number=Plur" in word["feats"]:
                new_sent.insert(insert_index, marker_pl)
            else:
                raise Exception(
                    "Number not in verb features\n" + sent["sent_text"])
        else:
            new_sent.append(word["text"])

    new_sent.reverse()
    sent_string = " ".join(merge_part_tokens(new_sent))
    tokens = qwen_hop_tokenizer.encode(sent_string)
    return tokens, all(hop_completed) and len(hop_completed) > 0
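
# Worked example (illustrative, assuming suitable word annotations): with
# num_hops=4, the verb is lemmatized and its number marker is inserted four
# words later, skipping punctuation and contraction pieces, e.g.
#   "The boy sings loudly near the tall trees ."
#   -> "The boy sing loudly near the tall 🅂 trees ."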

def __perturb_hop_tokens(sent, num_hops):
    word_annotations = sent["word_annotations"].copy()
    word_annotations.reverse()

    new_sent = deque()
    tokens = []
    for word in word_annotations:
        # Identify 3.pres verbs
        if __affect_hop_word(word):
            # Lemmatize verb if possible
            lemma = word["lemma"] if word["lemma"] is not None else word["text"]
            if len(new_sent) > 0 and new_sent[0] in PART_TOKENS:
                lemma = lemma + new_sent[0]
                new_sent.popleft()

            if len(new_sent) > 0:
                sent_string = " ".join(merge_part_tokens(new_sent))
                tokens = qwen_hop_tokenizer.encode(
                    " " + sent_string) + tokens

            # Use correct marker for singular vs. plural
            if "Number=Sing" in word["feats"]:
                tokens.insert(num_hops, marker_sg_token)
            elif "Number=Plur" in word["feats"]:
                tokens.insert(num_hops, marker_pl_token)
            else:
                raise Exception(
                    "Number not in verb features\n" + sent["sent_text"])

            new_sent = deque()
            new_sent.append(lemma)
        else:
            new_sent.appendleft(word["text"])

    if len(new_sent) > 0:
        sent_string = " ".join(merge_part_tokens(new_sent))
        tokens = qwen_hop_tokenizer.encode(sent_string) + tokens

    return tokens
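
# Note: unlike __perturb_hop_words_complete_hops, which counts words (skipping
# punctuation and contraction pieces), this variant counts subword tokens: the
# marker token id is inserted num_hops tokens after the lemmatized verb, so
# num_hops=0 (hop_control) places it immediately after the verb.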

def __perturb_reverse(sent, rng, reverse, full):
    # Get sentence text and tokenize it
    tokens = qwen_rev_tokenizer.encode(sent["sent_text"])

    # Pick random index to insert REV token
    i = rng.choice(len(tokens)+1)
    tokens.insert(i, marker_rev_token)

    # Extract tokens before/after the marker, and reverse tokens after
    tokens_before = tokens[:i+1]
    tokens_after = tokens[i+1:]
    if reverse:
        tokens_after.reverse()
    new_tokens = tokens_before + tokens_after
    if full:
        assert not reverse
        new_tokens.reverse()
    return new_tokens
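
# Example (illustrative): with tokens [t1, t2, t3, t4] and the marker R
# inserted after t2, the partial reverse yields [t1, t2, R, t4, t3] (suffix
# reversed), while the full reverse yields [t4, t3, R, t2, t1] (whole
# sequence reversed, marker position mirrored).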

def __perturb_shuffle_deterministic(sent, seed, shuffle):
    # Get sentence text and tokenize it
    tokens = qwen_original_tokenizer.encode(sent["sent_text"])
    if shuffle:
        default_rng(seed).shuffle(tokens)
    return tokens


def __perturb_shuffle_nondeterministic(sent, rng):
    # Get sentence text and tokenize it
    tokens = qwen_original_tokenizer.encode(sent["sent_text"])
    rng.shuffle(tokens)
    return tokens


def __perturb_shuffle_local(sent, seed, window=5):
    # Get sentence text and tokenize it
    tokens = qwen_original_tokenizer.encode(sent["sent_text"])

    # Shuffle tokens within non-overlapping windows of the given size
    shuffled_tokens = []
    for i in range(0, len(tokens), window):
        batch = tokens[i:i+window].copy()
        default_rng(seed).shuffle(batch)
        shuffled_tokens += batch
    return shuffled_tokens


def __perturb_shuffle_even_odd(sent):
    # Get sentence text and tokenize it
    tokens = qwen_original_tokenizer.encode(sent["sent_text"])
    even = [tok for i, tok in enumerate(tokens) if i % 2 == 0]
    odd = [tok for i, tok in enumerate(tokens) if i % 2 != 0]
    return even + odd
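
# Examples (illustrative), for a token list [t1, t2, t3, t4, t5]:
#   - the local shuffle with window=3 shuffles [t1, t2, t3] and [t4, t5]
#     independently, re-seeding the generator with the same seed for every
#     window;
#   - the even/odd shuffle keeps even-indexed tokens first, then odd-indexed
#     ones: [t1, t3, t5, t2, t4].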

##############################################################################
# AFFECT FUNCTIONS
# These functions define whether a perturbation has been applied to a given
# sentence. They are used to identify which test sentences have been altered,
# separating affected from unaffected sentences. An affect function takes the
# input sentence object and returns a boolean.
##############################################################################

def affect_hop(sent):
    return any([__affect_hop_word(word) for word in sent['word_annotations']]) \
        and sent["constituency_parse"] is not None


def affect_reverse(sent):
    return True


def affect_shuffle(sent):
    return True


def affect_none(sent):
    return False

##############################################################################
# FILTER FUNCTIONS
# These functions define when an affected sentence should be included in the
# final dataset. For instance, hop perturbations where the marker is placed
# at the end of the sentence should be excluded. A filter function returns
# True if an affected sentence should be included in the dataset.
##############################################################################

def filter_hop(sent):
    # Assertion needed since filter function is only defined for affected
    # sentences
    assert (affect_hop(sent))
    return check_word_hops_completed(sent, 4)


def filter_reverse(sent):
    return True


def filter_shuffle(sent):
    tokens = qwen_original_tokenizer.encode(sent["sent_text"])
    return len(tokens) > 1 and len(tokens) <= 350


def filter_none(sent):
    return False

##############################################################################
# PERTURBATION FUNCTIONS
# These functions define how a perturbation transforms a sentence. They take
# in a sentence object (and, for some perturbations, a seed/rng or markers
# for verb transformations) and return the list of token ids for the
# transformed sentence.
##############################################################################

def perturb_hop_words4(sent):
    return __perturb_hop_words(sent, 4, MARKER_HOP_SING, MARKER_HOP_PLUR)


def perturb_hop_tokens4(sent):
    return __perturb_hop_tokens(sent, 4)


def perturb_hop_control(sent):
    return __perturb_hop_tokens(sent, 0)


def perturb_reverse(sent, rng, reverse=True, full=False):
    return __perturb_reverse(sent, rng, reverse, full)


def perturb_shuffle_deterministic(sent, seed=None, shuffle=True):
    return __perturb_shuffle_deterministic(sent, seed, shuffle)


def perturb_shuffle_nondeterministic(sent, rng):
    return __perturb_shuffle_nondeterministic(sent, rng)


def perturb_shuffle_local(sent, seed, window):
    return __perturb_shuffle_local(sent, seed, window)


def perturb_shuffle_even_odd(sent):
    return __perturb_shuffle_even_odd(sent)

##############################################################################
# PERTURBATIONS
# This dict maps the name of a perturbation to its perturbation and filter
# functions. The names and functions in this dict are used throughout the
# repo.
##############################################################################

PERTURBATIONS = {
    "shuffle_control": {
        "perturbation_function": partial(perturb_shuffle_deterministic, seed=None, shuffle=False),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#606060",
    },
    "shuffle_nondeterministic": {
        "perturbation_function": partial(perturb_shuffle_nondeterministic, rng=default_rng(0)),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#E8384F",
    },
    "shuffle_deterministic21": {
        "perturbation_function": partial(perturb_shuffle_deterministic, seed=21, shuffle=True),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#FFB000",
    },
    "shuffle_deterministic57": {
        "perturbation_function": partial(perturb_shuffle_deterministic, seed=57, shuffle=True),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#8db000",
    },
    "shuffle_deterministic84": {
        "perturbation_function": partial(perturb_shuffle_deterministic, seed=84, shuffle=True),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#62BB35",
    },
    "shuffle_local3": {
        "perturbation_function": partial(perturb_shuffle_local, seed=0, window=3),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#208EA3",
    },
    "shuffle_local5": {
        "perturbation_function": partial(perturb_shuffle_local, seed=0, window=5),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#4178BC",
    },
    "shuffle_local10": {
        "perturbation_function": partial(perturb_shuffle_local, seed=0, window=10),
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#AA71FF",
    },
    "shuffle_even_odd": {
        "perturbation_function": perturb_shuffle_even_odd,
        "affect_function": affect_shuffle,
        "filter_function": filter_shuffle,
        "qwen_tokenizer": qwen_original_tokenizer,
        "color": "#E37CFF",
    },
    "reverse_control": {
        "perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=False, full=False),
        "affect_function": affect_reverse,
        "filter_function": filter_reverse,
        "qwen_tokenizer": qwen_rev_tokenizer,
        "color": "#606060",
    },
    "reverse_partial": {
        "perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=True, full=False),
        "affect_function": affect_reverse,
        "filter_function": filter_reverse,
        "qwen_tokenizer": qwen_rev_tokenizer,
        "color": "#E5A836",
    },
    "reverse_full": {
        "perturbation_function": partial(perturb_reverse, rng=default_rng(21), reverse=False, full=True),
        "affect_function": affect_reverse,
        "filter_function": filter_reverse,
        "qwen_tokenizer": qwen_rev_tokenizer,
        "color": "#A348A6",
    },
    "hop_control": {
        "perturbation_function": perturb_hop_control,
        "affect_function": affect_hop,
        "filter_function": filter_hop,
        "qwen_tokenizer": qwen_hop_tokenizer,
        "color": "#606060",
    },
    "hop_tokens4": {
        "perturbation_function": perturb_hop_tokens4,
        "affect_function": affect_hop,
        "filter_function": filter_hop,
        "qwen_tokenizer": qwen_hop_tokenizer,
        "color": "#fa8128",
    },
    "hop_words4": {
        "perturbation_function": perturb_hop_words4,
        "affect_function": affect_hop,
        "filter_function": filter_hop,
        "qwen_tokenizer": qwen_hop_tokenizer,
        "color": "#03a0ff",
    },
}
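

# Minimal usage sketch (assumption: a sentence object is a dict with at least
# "sent_text", "word_annotations", and "constituency_parse"; the shuffle and
# reverse perturbations only need "sent_text"). This mirrors how the
# PERTURBATIONS entries are meant to be combined, but it is illustrative and
# not part of the original pipeline.
if __name__ == "__main__":
    example_sent = {
        "sent_text": "The boy near the trees sings loudly .",
        "word_annotations": [],
        "constituency_parse": None,
    }
    for name in ["shuffle_control", "shuffle_even_odd", "reverse_partial"]:
        pert = PERTURBATIONS[name]
        # Only perturb sentences that are affected and pass the filter
        if pert["affect_function"](example_sent) and \
                pert["filter_function"](example_sent):
            token_ids = pert["perturbation_function"](example_sent)
            text = pert["qwen_tokenizer"].decode(token_ids)
            print(f"{name}: {text}")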