""" |
|
preprocess input data into feature and stores binary as python shelve DB |
|
each chunk is gzipped JSON string |
|
""" |
import argparse
import gzip
import json
import subprocess as sp
import shelve
import os
from os.path import dirname, exists, join

import torch
from lsp_model import GPT2Tokenizer
from tqdm import tqdm

from env import END_OF_TEXT_TOKEN
from gpt2_training.train_utils import InputFeatures_train as InputFeatures


def _get_file_len(corpus):
    # count the corpus lines with `wc -l` so tqdm can show total progress
    n_line = int(sp.check_output(f"wc -l {corpus}".split(),
                                 universal_newlines=True).split()[0])
    return n_line


def _norm_text(text):
    # a turn may start with a float weight; default to 1.0 when it is absent
    w, *toks = text.strip().split()
    try:
        w = float(w)
    except Exception:
        toks = [w] + toks
        w = 1.0
    return w, ' '.join(toks)
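# Behavior examples for _norm_text (illustrative inputs):
#   _norm_text("0.0 how are you ?") -> (0.0, "how are you ?")
#   _norm_text("how are you ?")     -> (1.0, "how are you ?")
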
def _get_inputs_from_text(text, tokenizer):
    # a corpus line is "<context turns joined by ' EOS '>\t<response>"
    srcs, tgt = text.strip().split('\t')
    weights = []
    inputs = []
    for src in srcs.split(' EOS '):
        src_weight, src = _norm_text(src)
        context_id = tokenizer.encode(src)
        weights.append(src_weight)
        inputs.append(context_id)
    tgt_weight, tgt = _norm_text(tgt)
    if tgt_weight != 0:
        response_id = tokenizer.encode(tgt)
        weights.append(tgt_weight)
        inputs.append(response_id)
    return weights, inputs
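# Worked example for _get_inputs_from_text (token ids depend on the tokenizer;
# the input line below is illustrative):
#   "0.0 how are you ? EOS i am fine .\tgreat , thanks ."
#   -> weights = [0.0, 1.0, 1.0]
#      inputs  = [encode("how are you ?"), encode("i am fine ."),
#                 encode("great , thanks .")]
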
def _make_features(id_, weights, inputs, tokenizer, max_len):
    end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
    features = []
    sents = []
    ws = []
    len_ = 0
    i = 0
    for ids, w in zip(inputs, weights):
        if len(ids) > max_len:
            # a single turn longer than max_len: flush what we have and reset
            if len(sents) >= 2:
                feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
                if feat is not None:
                    features.append(feat)
                    i += 1
            len_ = 0
            sents = []
            ws = []
            continue
        elif len_ > max_len:
            # accumulated length exceeded max_len: emit a feature and restart
            # the window from the most recent turn
            feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
            if feat is not None:
                features.append(feat)
                i += 1
            len_ = len(sents[-1]) + 1
            sents = sents[-1:]
            ws = ws[-1:]
        len_ += (len(ids) + 1)
        sents.append(ids)
        ws.append(w)
    if len(sents) >= 2:
        feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
        if feat is not None:
            features.append(feat)
    return features
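# Note on the chunking above: whenever the running token count exceeds
# max_len, the turns collected so far become one training feature and the
# window restarts from the most recent turn; any single turn longer than
# max_len is dropped and the window is reset.
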
def _make_feature(id_, sents, ws, eos):
    # skip examples whose non-context turns all have zero weight
    if all(w == 0 for w in ws[1:]):
        return None
    input_ids = [i for s in sents for i in s+[eos]][:-1]
    lm_labels = []
    weights = []
    token_type_ids = []
    for i, (s, w) in enumerate(zip(sents, ws)):
        if i == 0:
            # the first turn is pure context: no LM loss, weight 0, type id 0
            lm_labels += [-1] * len(s)
            weights += [0.0] * len(s)
            token_type_ids += [0] * len(s)
            continue

        token_type_ids += [i] * (len(s) + 1)
        if w == 0.0:
            lm_labels += [-1] * (len(s) + 1)
            weights += [0.0] * (len(s) + 1)
        else:
            lm_labels += (s + [eos])
            weights += [w] * (len(s) + 1)

    # trim trailing positions whose labels are all -1
    i = len(lm_labels) - 1
    while i >= 0:
        if lm_labels[i] != -1:
            break
        i -= 1
    input_ids = input_ids[:i+1]
    lm_labels = lm_labels[:i+1]
    weights = weights[:i+1]
    token_type_ids = token_type_ids[:i+1]

    # pad every list to a multiple of 8
    while len(input_ids) % 8 != 0:
        input_ids.append(0)
        token_type_ids.append(0)
        lm_labels.append(-1)
        weights.append(0.0)

    position_ids = list(range(len(input_ids)))
    assert (len(input_ids) == len(position_ids) == len(token_type_ids)
            == len(lm_labels) == len(weights))
    assert len(input_ids) % 8 == 0
    if len(input_ids) == 0:
        # should never happen; drop into the debugger if it does
        import pdb
        pdb.set_trace()
    feature = InputFeatures(id_, input_ids, position_ids, token_type_ids,
                            lm_labels, weights)
    return feature
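# Layout of the resulting feature: input_ids is every turn joined by the
# end-of-text token; token_type_ids holds the turn index; lm_labels is -1 over
# context / zero-weight positions and the token ids over weighted response
# turns; everything is right-padded (ids/types 0, labels -1, weights 0.0) to a
# multiple of 8, presumably for efficient batching.
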
def main(args):
    toker = GPT2Tokenizer.from_pretrained('gpt2')
    attrs = []
    if args.reverse:
        attrs.append('reverse')
    if args.two_turn:
        attrs.append('2turn')
    if attrs:
        db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.'
                   f'{".".join(attrs)}.db/db')
    else:
        db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db'
    if exists(dirname(db_path)):
        raise ValueError('Found existing DB, please backup')
    else:
        os.makedirs(dirname(db_path))
    with open(args.corpus, "r", encoding="utf-8") as reader, \
            shelve.open(db_path, 'n') as db:
        chunk = []
        n_chunk = 0
        n_example = 0
        for line in tqdm(reader, total=_get_file_len(args.corpus)):
            try:
                if len(chunk) >= args.chunk_size:
                    # flush a full chunk to the DB as a gzipped JSON string
                    db[f'chunk_{n_chunk}'] = gzip.compress(
                        json.dumps(chunk[:args.chunk_size]).encode('utf-8'))
                    chunk = chunk[args.chunk_size:]
                    n_chunk += 1

                weights, inputs = _get_inputs_from_text(line, toker)
                if args.reverse:
                    weights = list(reversed(weights))
                    inputs = list(reversed(inputs))
                if args.two_turn:
                    weights = weights[:2]
                    inputs = inputs[:2]
                if len(weights) < 2:
                    continue
                features = _make_features(n_example, weights, inputs,
                                          toker, args.max_seq_len)
                for feature in features:
                    chunk.append(vars(feature))
                    n_example += 1
            except Exception as e:
                print('!!! prepro exception !!!', e)
                continue

        # save the last, possibly partial, chunk
        db[f'chunk_{n_chunk}'] = gzip.compress(
            json.dumps(chunk).encode('utf-8'))

    # store preprocessing metadata and the tokenizer next to the DB
    meta = {'n_example': n_example,
            'chunk_size': args.chunk_size,
            'max_seq_len': args.max_seq_len,
            'reverse': args.reverse,
            'two_turn': args.two_turn}
    with open(join(dirname(db_path), 'meta.json'), 'w') as writer:
        json.dump(meta, writer, indent=4)
    torch.save(toker, join(dirname(db_path), 'tokenizer.pt'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--corpus', required=True,
                        help='file name of the training corpus (a .tsv file)')
    parser.add_argument('--chunk_size', type=int, default=65536,
                        help='number of examples stored per chunk')
    parser.add_argument('--max_seq_len', type=int, default=128,
                        help='discard turns/windows longer than this')
    parser.add_argument('--reverse', action='store_true',
                        help='reverse the order of the src and tgt turns')
    parser.add_argument('--two_turn', action='store_true',
                        help='take only the first 2 turns')

    args = parser.parse_args()

    main(args)
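
# Minimal sketch of reading the DB back (the DB path below is illustrative and
# follows the naming scheme produced above):
#   import gzip, json, shelve
#   with shelve.open('data/train.128len.db/db', 'r') as db:
#       chunk = json.loads(gzip.decompress(db['chunk_0']).decode('utf-8'))
#       print(len(chunk), chunk[0].keys())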