import json
import os
import pickle
import re

import yaml


try:
    import nltk

    def tokenize(
        text,
    ):
        # nltk.word_tokenize turns double quotes into `` / ''; map them back to ".
        tokens = ' '.join(nltk.word_tokenize(text)).replace('``', '"').replace("''", '"').split()
        return tokens

    from nltk.tokenize.treebank import TreebankWordDetokenizer

    def detokenize(text: str):
        tokens = text.split()
        text = TreebankWordDetokenizer().detokenize(tokens)
        # The Treebank detokenizer can leave a space before sentence-final periods; tidy it up.
        text = text.replace(' . ', '. ')
        return text

except Exception:
    print('-> Cannot import nltk')


def is_disjoint(
    span1,
    span2,
):
    """
    Check whether two half-open spans [start, end) do not overlap.

    :param span1: [start, end) of the first span
    :param span2: [start, end) of the second span
    :return: True if the spans are disjoint, False if they overlap
    """
    return (span1[0] - span2[1] + 1) * (span2[0] - span1[1] + 1) < 0


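# A minimal sanity check for is_disjoint (illustrative, not part of the original
# module); spans are treated as half-open [start, end) intervals:
#
#   >>> is_disjoint([0, 2], [5, 7])
#   True
#   >>> is_disjoint([0, 5], [3, 7])
#   False

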
def remove_accent(text):
    """Strip Vietnamese diacritics entirely (both tone marks and vowel modifiers)."""
    text = re.sub(r'[àáạảãâầấậẩẫăằắặẳẵ]', 'a', text)
    text = re.sub(r'[ÀÁẠẢÃĂẰẮẶẲẴÂẦẤẬẨẪ]', 'A', text)
    text = re.sub(r'[èéẹẻẽêềếệểễ]', 'e', text)
    text = re.sub(r'[ÈÉẸẺẼÊỀẾỆỂỄ]', 'E', text)
    text = re.sub(r'[ìíịỉĩ]', 'i', text)
    text = re.sub(r'[ÌÍỊỈĨ]', 'I', text)
    text = re.sub(r'[òóọỏõôồốộổỗơờớợởỡ]', 'o', text)
    text = re.sub(r'[ÒÓỌỎÕÔỒỐỘỔỖƠỜỚỢỞỠ]', 'O', text)
    text = re.sub(r'[ùúụủũưừứựửữ]', 'u', text)
    text = re.sub(r'[ƯỪỨỰỬỮÙÚỤỦŨ]', 'U', text)
    text = re.sub(r'[ỳýỵỷỹ]', 'y', text)
    text = re.sub(r'[ỲÝỴỶỸ]', 'Y', text)
    text = re.sub(r'đ', 'd', text)
    text = re.sub(r'Đ', 'D', text)
    return text


def remove_timbre(text):
    """Strip tone marks only, keeping the modified vowels (ă, â, ê, ô, ơ, ư); đ/Đ still become d/D."""
    text = re.sub(r'[àáạảã]', 'a', text)
    text = re.sub(r'[ÀÁẠẢÃ]', 'A', text)

    text = re.sub(r'[âầấậẩẫ]', 'â', text)
    text = re.sub(r'[ÂẦẤẬẨẪ]', 'Â', text)

    text = re.sub(r'[ăằắặẳẵ]', 'ă', text)
    text = re.sub(r'[ĂẰẮẶẲẴ]', 'Ă', text)

    text = re.sub(r'[èéẹẻẽ]', 'e', text)
    text = re.sub(r'[ÈÉẸẺẼ]', 'E', text)

    text = re.sub(r'[êềếệểễ]', 'ê', text)
    text = re.sub(r'[ÊỀẾỆỂỄ]', 'Ê', text)

    text = re.sub(r'[ìíịỉĩ]', 'i', text)
    text = re.sub(r'[ÌÍỊỈĨ]', 'I', text)

    text = re.sub(r'[òóọỏõ]', 'o', text)
    text = re.sub(r'[ÒÓỌỎÕ]', 'O', text)

    text = re.sub(r'[ôồốộổỗ]', 'ô', text)
    text = re.sub(r'[ÔỒỐỘỔỖ]', 'Ô', text)

    text = re.sub(r'[ơờớợởỡ]', 'ơ', text)
    text = re.sub(r'[ƠỜỚỢỞỠ]', 'Ơ', text)

    text = re.sub(r'[ùúụủũ]', 'u', text)
    text = re.sub(r'[ÙÚỤỦŨ]', 'U', text)

    text = re.sub(r'[ưừứựửữ]', 'ư', text)
    text = re.sub(r'[ƯỪỨỰỬỮ]', 'Ư', text)

    text = re.sub(r'[ỳýỵỷỹ]', 'y', text)
    text = re.sub(r'[ỲÝỴỶỸ]', 'Y', text)

    text = re.sub(r'đ', 'd', text)
    text = re.sub(r'Đ', 'D', text)

    return text


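# Illustrative comparison of the two normalizers (behaviour derived from the
# regexes above, not a definitive spec):
#
#   >>> remove_accent('Đà Nẵng')    # drop every diacritic
#   'Da Nang'
#   >>> remove_timbre('Đà Nẵng')    # drop tone marks only, keep ă/â/ê/ô/ơ/ư
#   'Da Năng'

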
def get_mentions(ner_tags, tokens):
    """
    Convert BIO tags into mention dicts: each mention has a half-open token
    'span' [start, end), its 'tag' (without the B-/I- prefix) and its 'text'.
    """
    spans = []
    prev_tag = None
    s_pos = None
    for i, tag in enumerate(ner_tags):
        if tag == 'O':
            # An open mention ends right before this token.
            if prev_tag is not None and s_pos is not None:
                spans.append({
                    'span': [s_pos, i],
                    'tag': prev_tag,
                    'text': ' '.join(tokens[s_pos:i]),
                })
                prev_tag = None
                s_pos = None
        elif tag.startswith('B'):
            # Close any open mention, then start a new one here.
            if prev_tag is not None and s_pos is not None:
                spans.append({
                    'span': [s_pos, i],
                    'tag': prev_tag,
                    'text': ' '.join(tokens[s_pos:i]),
                })
            s_pos = i
            prev_tag = tag[2:]
        else:
            # I- tag: only a tag mismatch closes the open mention.
            cur_tag = tag[2:]
            if prev_tag is not None and prev_tag != cur_tag and s_pos is not None:
                spans.append({
                    'span': [s_pos, i],
                    'tag': prev_tag,
                    'text': ' '.join(tokens[s_pos:i]),
                })
                prev_tag = None
                s_pos = None

        # Flush a mention that runs to the end of the sentence.
        if i == len(ner_tags) - 1 and prev_tag is not None and s_pos is not None:
            spans.append({
                'span': [s_pos, i + 1],
                'tag': prev_tag,
                'text': ' '.join(tokens[s_pos: i + 1])
            })

    return spans


def mentions2tags(
    tokens,
    mentions,
):
    """Inverse of get_mentions: write BIO tags over `tokens` for the given mentions."""
    tags = ['O'] * len(tokens)
    for mention in mentions:
        mention_tag = mention['tag']
        span = mention['span']
        if mention_tag != 'O':
            mention_tags = [f'B-{mention_tag}'] + [f'I-{mention_tag}'] * (span[1] - span[0] - 1)
        else:
            mention_tags = ['O'] * (span[1] - span[0])
        tags[span[0]: span[1]] = mention_tags
    return tags


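# Illustrative round trip between the two helpers above (values follow directly
# from their definitions):
#
#   >>> tokens = ['Nguyen', 'Van', 'visited', 'Hanoi']
#   >>> tags = ['B-PER', 'I-PER', 'O', 'B-LOC']
#   >>> get_mentions(ner_tags=tags, tokens=tokens)
#   [{'span': [0, 2], 'tag': 'PER', 'text': 'Nguyen Van'},
#    {'span': [3, 4], 'tag': 'LOC', 'text': 'Hanoi'}]
#   >>> mentions2tags(tokens, get_mentions(ner_tags=tags, tokens=tokens)) == tags
#   True

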
def load_yaml(fn):
    with open(fn, mode='r', encoding='utf8') as f:
        config = yaml.safe_load(f)
    return config


def load_json(fn):
    with open(fn, mode='r', encoding='utf8') as f:
        data = json.load(f)
    return data


def load_jsonl(fn, num_max_lines=None,):
    data = []
    i = 0
    with open(fn, mode='r', encoding='utf8') as f:
        for line in f:
            data.append(json.loads(line))
            i += 1
            if num_max_lines is not None and i == num_max_lines:
                break
    return data


def load_jsonl_generator(fn, num_max_lines=None,):
    i = 0
    with open(fn, mode='r', encoding='utf8') as f:
        for line in f:
            try:
                item = json.loads(line)
                yield item
                i += 1
                if num_max_lines is not None and i == num_max_lines:
                    break
            except Exception as e:
                print(f'-> error {e} in line content:`\n{line}\n`')


def load_jsonl_by_batch(
    fn,
    bs,
):
    batch = []
    with open(fn, mode='r', encoding='utf8', errors='surrogateescape') as f:
        for line in f:
            try:
                item = json.loads(line)
                batch.append(item)
            except Exception as e:
                print(e)
                print('#' * 10)

            if len(batch) == bs:
                yield batch
                batch = []

    if len(batch) > 0:
        yield batch


def load_text_by_batch(
    fn,
    bs,
):
    batch = []
    with open(fn, mode='r', encoding='utf8', errors='surrogateescape') as f:
        for line in f:
            batch.append(line.strip())

            if len(batch) == bs:
                yield batch
                batch = []

    if len(batch) > 0:
        yield batch


def dump_json(data, fn, indent=None,):
    with open(fn, mode='w', encoding='utf8') as f:
        json.dump(data, f, ensure_ascii=False, indent=indent,)


def dump_jsonl(data: list, fn,):
    with open(fn, mode='w', encoding='utf8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False))
            f.write('\n')


def convert_jsonl2jsonl_gz(
    input_path,
    output_path,
):
    import gzip
    with gzip.open(output_path, mode='wb',) as f:
        for batch in load_jsonl_by_batch(
            fn=input_path,
            bs=1000,
        ):
            for item in batch:
                f.write((json.dumps(item, ensure_ascii=False,) + '\n').encode('utf8'))


def convert_jsonl2jsonl_gz_in_dir(
    input_dir,
    output_dir=None,
):
    if output_dir is None:
        output_dir = input_dir
    fns = [
        fn for fn in os.listdir(input_dir)
        if fn.endswith('.jsonl')
    ]
    for fn in fns:
        print(f'-> compressing {fn}')
        convert_jsonl2jsonl_gz(
            input_path=os.path.join(input_dir, fn),
            output_path=os.path.join(output_dir, fn + '.gz'),
        )


def dump_jsonl_gz(
    data,
    output_path,
):
    import gzip
    with gzip.open(output_path, mode='wb') as f:
        for item in data:
            f.write((json.dumps(item, ensure_ascii=False,) + '\n').encode('utf8'))


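# Illustrative usage (the directory path is hypothetical): compress every *.jsonl
# file in a folder to *.jsonl.gz next to it, streaming 1000 records at a time.
#
#   >>> convert_jsonl2jsonl_gz_in_dir('data/raw')

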
def load_jsonl_gz(
    fn,
    num_max_lines=None,
    num_workers=1,
):
    # mgzip is used as an optional drop-in for gzip when multi-threaded
    # decompression is requested (num_workers > 1).
    if num_workers == 1:
        import gzip
    else:
        import mgzip as gzip
    data = []
    i = 0
    with (gzip.open(fn, mode='rb') if num_workers == 1 else gzip.open(fn, mode='rb', thread=num_workers)) as f:
        for line in f:
            try:
                item = json.loads(line)
                data.append(item)
                i += 1
                if num_max_lines is not None and i == num_max_lines:
                    break
            except Exception as e:
                print(e)

    return data


def load_jsonl_gz_generator(
    fn,
    num_max_lines=None,
    num_workers=1,
):
    if num_workers == 1:
        import gzip
    else:
        import mgzip as gzip
    i = 0
    with (gzip.open(fn, mode='rb') if num_workers == 1 else gzip.open(fn, mode='rb', thread=num_workers)) as f:
        try:
            for line in f:
                try:
                    item = json.loads(line)
                    yield item
                    i += 1
                    if num_max_lines is not None and i == num_max_lines:
                        break
                except Exception as e:
                    print(e)
        except Exception as e_file:
            print(f'-> error: {e_file}')


def load_jsonl_or_jsonl_gz(
    fn,
    num_max_lines=None,
    num_workers=1,
):
    from functools import partial
    if fn.endswith('.jsonl'):
        load_func = load_jsonl
    elif fn.endswith('.jsonl.gz'):
        load_func = partial(load_jsonl_gz, num_workers=num_workers,)
    else:
        raise NotImplementedError()
    return load_func(fn, num_max_lines=num_max_lines,)


def load_jsonl_or_jsonl_gz_generator(
    fn,
    num_max_lines=None,
    num_workers=1,
):
    from functools import partial
    if fn.endswith('.jsonl'):
        load_func = load_jsonl_generator
    elif fn.endswith('.gz'):
        load_func = partial(load_jsonl_gz_generator, num_workers=num_workers,)
    else:
        raise NotImplementedError()
    for item in load_func(fn, num_max_lines=num_max_lines):
        yield item


def load_jsonl_gz_by_batch(
    fn,
    bs,
    num_workers=1,
):
    if num_workers == 1:
        import gzip
    else:
        import mgzip as gzip
    batch = []

    with (gzip.open(fn, mode='rb') if num_workers == 1 else gzip.open(fn, mode='rb', thread=num_workers)) as f:
        try:
            for line in f:
                try:
                    item = json.loads(line)
                    batch.append(item)
                    if len(batch) == bs:
                        yield batch
                        batch = []

                except Exception as e:
                    print(e)
                    print(line)
                    print('#' * 10)
        except Exception as e:
            print(e)

    if len(batch) > 0:
        yield batch


def load_jsonl_or_jsonl_gz_by_batch(fn, bs, num_workers=1,):
    from functools import partial
    if fn.endswith('.jsonl'):
        load_func = load_jsonl_by_batch
    elif fn.endswith('.jsonl.gz'):
        load_func = partial(load_jsonl_gz_by_batch, num_workers=num_workers,)
    else:
        raise NotImplementedError()

    for batch in load_func(fn, bs):
        yield batch


def get_load_func(
    fn,
):
    if fn.endswith('.jsonl'):
        load_func = load_jsonl_by_batch
    elif fn.endswith('.jsonl.gz'):
        load_func = load_jsonl_gz_by_batch
    else:
        raise NotImplementedError()
    return load_func


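# Illustrative usage (the file name and `process` are placeholders): the batch
# loaders stream lists of parsed JSON objects, so large corpora never have to
# fit in memory at once.
#
#   >>> for batch in load_jsonl_or_jsonl_gz_by_batch('corpus.jsonl.gz', bs=1000):
#   ...     process(batch)

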
def load_text_gz(
    fn,
    max_lines=None,
    num_workers=1,
):
    if num_workers == 1:
        import gzip
    else:
        import mgzip as gzip
    data = []
    n = 0

    with (gzip.open(fn, mode='rb') if num_workers == 1 else gzip.open(fn, mode='rb', thread=num_workers)) as f:
        for line in f:
            n += 1
            data.append(line.decode('utf8'))

            if max_lines is not None and n >= max_lines:
                break
    return data


def load_text_gz_generator(fn, num_max_lines=None, num_workers=1,):
    if num_workers == 1:
        import gzip
    else:
        import mgzip as gzip
    n = 0

    with (gzip.open(fn, mode='rb') if num_workers == 1 else gzip.open(fn, mode='rb', thread=num_workers)) as f:
        for line in f:
            n += 1
            yield line.decode('utf8')

            if num_max_lines is not None and n >= num_max_lines:
                break


def dump_pickle(data, fn):
    with open(fn, mode='wb') as f:
        pickle.dump(data, f)


def load_pickle(fn):
    with open(fn, mode='rb',) as f:
        data = pickle.load(f)
    return data


def load_text(fn):
    data = []
    with open(fn, mode='r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(line)
    return data


def load_text_generator(fn, num_max_lines=None,):
    n = 0
    with open(fn, mode='r', encoding='utf8') as f:
        for line in f:
            n += 1
            line = line.strip()
            yield line
            if num_max_lines is not None and n == num_max_lines:
                break


def load_text_or_text_gz_generator(fn, num_max_lines=None, num_workers=1,):
    from functools import partial
    if fn.endswith('.gz'):
        load_func = partial(load_text_gz_generator, num_workers=num_workers,)
    else:
        load_func = load_text_generator

    for item in load_func(fn, num_max_lines=num_max_lines):
        yield item


def load_zst_generator(fn, num_max_lines=None):
    import zstandard as zstd
    import io
    n = 0

    DCTX = zstd.ZstdDecompressor(max_window_size=2 ** 31)
    with zstd.open(fn, mode='rb', dctx=DCTX) as zfh, \
            io.TextIOWrapper(zfh) as iofh:
        for line in iofh:
            line = line.strip()
            n += 1
            yield line

            if num_max_lines is not None and n == num_max_lines:
                break


def load_zst_jsonl_generator(fn, num_max_lines=None):
    import zstandard as zstd
    import io
    n = 0

    DCTX = zstd.ZstdDecompressor(max_window_size=2 ** 31)
    with zstd.open(fn, mode='rb', dctx=DCTX) as zfh, \
            io.TextIOWrapper(zfh) as iofh:
        for line in iofh:
            line = line.strip()
            line = json.loads(line)
            n += 1
            yield line

            if num_max_lines is not None and n == num_max_lines:
                break


def write_text(data, fn):
    with open(fn, mode='w', encoding='utf8') as f:
        for item in data:
            f.write(item)
            f.write('\n')


def split_jsonl(fn, n_parts=10,):
    data = load_jsonl(fn)
    n = len(data)
    # Ceiling division so every record lands in exactly one of the n_parts files.
    bs = (n - 1) // n_parts + 1
    for i in range(n_parts):
        dump_jsonl(data[i * bs: (i + 1) * bs], f'{fn}.part{i}')


def count_file_lines(file_path: str) -> int:
    import subprocess
    if file_path.endswith('.gz'):
        # Pipe `zcat file | wc -l` so the archive is counted without decompressing it in Python.
        ps = subprocess.Popen(('zcat', file_path), stdout=subprocess.PIPE,)
        output = subprocess.check_output(["wc", "-l"], stdin=ps.stdout)
    else:
        output = subprocess.check_output(["wc", "-l", file_path])
    num_lines = int(output.split()[0])
    return num_lines


def mask_number(
    text,
    tag,
):
    if tag == 'num.phone' and text.isdigit():
        return '<num.phone>'

    text = re.sub(r'\d+([\.,]?\d+)*', '<number>', text)
    return text


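# Illustrative behaviour (tag names follow the checks above; the second tag value
# is an arbitrary non-phone example):
#
#   >>> mask_number('0912345678', tag='num.phone')
#   '<num.phone>'
#   >>> mask_number('giá 1.200.000 đồng', tag='other')
#   'giá <number> đồng'

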
def load_conll(fn, sep='\t'):
    """
    Load a CoNLL-style file: sentences are separated by blank lines and each
    token line holds `sep`-separated columns. Returns a list of sentences,
    each a list of column lists.
    """
    with open(fn, mode='r', encoding='utf8') as f:
        text = f.read()

    text = text.strip()
    if text == '':
        return []

    data = []

    for sent_text in re.split(r'\n{2,}', text.strip()):
        sent_lines = re.split(r'\n', sent_text)
        sent = []
        for line in sent_lines:
            line = line.strip('\n')
            if line == '':
                continue
            parts = line.split(sep)
            if len(parts) > 0:
                sent.append(parts)

        if len(sent) > 0:
            data.append(sent)
    return data


def unique_data(
    data,
):
    # Deduplicate NER items by their (tokens, tags) pair, keeping the first occurrence.
    results = []
    unique_info = set()
    for item in data:
        info = (
            ' '.join(item['tokens']),
            ' '.join(item['tags'])
        )
        if info not in unique_info:
            unique_info.add(info)
            results.append(item)
    print(f'-> Deduplicate: from {len(data)} -> {len(results)}')
    return results


def identify(
    data,
    prefix,
):
    # Assign sequential ids such as 'train_000001' in place.
    for i, item in enumerate(data):
        item['id'] = f'{prefix}_{i:06d}'


def split_ner_data(
    data,
    test_size=0.1,
    seed=42,
    test=False,
):
    """
    Deduplicate NER samples and split them into train/dev/test, stratified by
    entity tag so that every tag is represented in training. When `test` is
    False, the would-be test portion is folded back into train and only
    (train, dev) is returned.
    """
    import random
    data = unique_data(data)
    random.seed(seed)
    from collections import defaultdict
    negative_samples = []
    samples = []
    entity_tag2idxs = defaultdict(set)

    entity_values = set()
    entity_tag_values = set()

    for item in data:
        entities = get_mentions(
            ner_tags=item['tags'],
            tokens=item['tokens'],
        )
        for entity in entities:
            entity_values.add(entity['text'])
            entity_tag_values.add((entity['text'], entity['tag']))
        if len(entities) == 0:
            # Entity-free sentences are flagged as negatives; they are kept aside
            # and not assigned to any split here.
            negative_samples.append(item)
            item['negative'] = True
        else:
            item['negative'] = False
            idx = len(samples)
            samples.append(item)
            for entity in entities:
                entity_tag2idxs[entity['tag']].add(idx)

    entity_tag2idxs = {tag: list(idxs) for tag, idxs in entity_tag2idxs.items()}

    train_samples = []
    dev_samples = []
    test_samples = []

    n_negative_samples = len(negative_samples)
    train_pos_idxs = set()
    dev_pos_idxs = set()
    test_pos_idxs = set()

    # Stratified assignment: for each entity tag, shuffle the sentences that
    # contain it (and were not already assigned) and carve off dev and test
    # portions of size `test_size` each; the remainder goes to train.
    selected_idxs = set()
    for entity_tag, idxs in entity_tag2idxs.items():
        idxs = [i for i in idxs if i not in selected_idxs]
        if len(idxs) == 0:
            continue

        random.shuffle(idxs)
        n_test = int(test_size * len(idxs))
        if n_test == 0:
            train_pos_idxs.update(idxs)
        else:
            dev_pos_idxs.update(idxs[:n_test])
            test_pos_idxs.update(idxs[n_test: 2 * n_test])
            train_pos_idxs.update(idxs[2 * n_test:])

        selected_idxs.update(idxs)

    assert len(train_pos_idxs.intersection(dev_pos_idxs)) == 0
    assert len(train_pos_idxs.intersection(test_pos_idxs)) == 0
    assert len(dev_pos_idxs.intersection(test_pos_idxs)) == 0

    train_samples.extend([
        samples[i] for i in train_pos_idxs
    ])
    if len(dev_pos_idxs) > 0:
        dev_samples.extend([
            samples[i] for i in dev_pos_idxs
        ])
    else:
        # Too little data for a separate dev split: fall back to reusing train.
        dev_samples.extend([
            samples[i] for i in train_pos_idxs
        ])
    test_samples.extend([
        samples[i] for i in test_pos_idxs
    ])

    if test:
        print('#train samples:', len(train_samples))
        print('#dev samples:', len(dev_samples))
        print('#test samples:', len(test_samples))

        identify(train_samples, prefix='train')
        identify(dev_samples, prefix='dev')
        identify(test_samples, prefix='test')

        return train_samples, dev_samples, test_samples
    else:
        train_samples = [
            *train_samples,
            *test_samples,
        ]
        print('#train samples:', len(train_samples))
        print('#dev samples:', len(dev_samples))

        identify(train_samples, prefix='train')
        identify(dev_samples, prefix='dev')
        identify(test_samples, prefix='test')

        return train_samples, dev_samples


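# Illustrative usage (the file names are hypothetical); items follow the format
# used throughout this module, i.e. dicts with 'tokens' and 'tags':
#
#   >>> data = load_jsonl('ner_data.jsonl')
#   >>> train, dev, test = split_ner_data(data, test_size=0.1, test=True)
#   >>> dump_jsonl(train, 'train.jsonl')
#   >>> dump_jsonl(dev, 'dev.jsonl')
#   >>> dump_jsonl(test, 'test.jsonl')

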
def load_ner_src_tgt(
    src_path,
    tgt_path,
):
    texts = load_text(src_path)
    labels = load_text(tgt_path)
    assert len(texts) == len(labels)
    data = []
    for text, label in zip(texts, labels):
        tokens = text.split()
        tags = label.split()
        assert len(tokens) == len(tags)
        data.append({
            'tokens': tokens,
            'tags': tags,
        })
    return data


def load_ner_src_tgt_inline(fn):
    data = []
    with open(fn, mode='r', encoding='utf8') as f:
        for line in f:
            text, tag = line.split('\t')
            text = text.strip()
            tag = tag.strip()
            tokens = text.split()
            tags = tag.split()
            assert len(tokens) == len(tags)

            data.append({
                'tokens': tokens,
                'tags': tags,
            })
    return data


def load_syllable(path):
    syllables = []
    with open(path, mode='r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line == '':
                continue

            parts = line.split('\t', maxsplit=1)
            syllables.append(parts[0])
    return syllables


def get_coordinates(
    text: str,
):
    # KML stores coordinates as "longitude,latitude,altitude"; return [lat, long] pairs.
    text = text.strip()
    lines = re.split(r'\n+', text)
    coordinates = []
    for line in lines:
        line = line.strip()
        if line == '':
            continue

        long, lat, _ = line.split(',')
        lat = float(lat)
        long = float(long)

        coordinates.append([lat, long])
    return coordinates


def load_polygon(
    kml_path,
):
    from lxml import etree
    with open(kml_path, mode='r',) as f:
        xml_data = etree.XML(f.read())

    place_marks = xml_data.xpath('//Placemark')
    results = {}
    for place_mark in place_marks:
        name = place_mark.xpath('.//name/text()')[0]
        coordinates = get_coordinates(place_mark.xpath('.//coordinates/text()')[0])
        results[name] = coordinates
    return results


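# Illustrative usage (the file name is hypothetical): load_polygon returns a dict
# mapping each Placemark name to its polygon as [latitude, longitude] pairs.
#
#   >>> polygons = load_polygon('districts.kml')
#   >>> for name, latlong_pairs in polygons.items():
#   ...     print(name, len(latlong_pairs))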