import os
import numpy as np
import torch
import random
from PIL import Image, ImageDraw, ImageFont
import pickle
from config.GlobalVariables import *

np.random.seed(0)

class DataLoader():
    def __init__(self, num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers'):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num_writer = num_writer
        self.num_samples = num_samples
        self.divider = divider
        self.datadir = datadir
        print('self.datadir : ', self.datadir)
        self.total_writers = len([name for name in os.listdir(datadir)])
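
    # next_batch (below) builds one batch as 18 parallel lists, grouped by granularity:
    # sentence-, word-, and segment-level stroke_in / stroke_out / stroke_length / term /
    # char / char_length, in the order of the return statement at the end of the method.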
    def next_batch(self, TYPE='TRAIN', uid=-1, tids=[]):
        all_sentence_level_stroke_in = []
        all_sentence_level_stroke_out = []
        all_sentence_level_stroke_length = []
        all_sentence_level_term = []
        all_sentence_level_char = []
        all_sentence_level_char_length = []
        all_word_level_stroke_in = []
        all_word_level_stroke_out = []
        all_word_level_stroke_length = []
        all_word_level_term = []
        all_word_level_char = []
        all_word_level_char_length = []
        all_segment_level_stroke_in = []
        all_segment_level_stroke_out = []
        all_segment_level_stroke_length = []
        all_segment_level_term = []
        all_segment_level_char = []
        all_segment_level_char_length = []
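
        # Keep sampling writers until self.num_writer writers with usable data have been collected.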
        while len(all_sentence_level_stroke_in) < self.num_writer:
            if uid < 0:
                if TYPE == 'TRAIN':
                    if self.datadir == './data/NEW_writers' or self.datadir == './data/writers':
                        uid = np.random.choice([i for i in range(150)])
                    else:
                        if self.device == 'cpu':
                            uid = np.random.choice([i for i in range(20)])
                        else:
                            uid = np.random.choice([i for i in range(294)])
                else:
                    uid = np.random.choice([i for i in range(150, 170)])

            total_texts = len([name for name in os.listdir(self.datadir + '/' + str(uid))])
            if len(tids) == 0:
                tids = random.sample([i for i in range(total_texts)], self.num_samples)

            user_sentence_level_stroke_in = []
            user_sentence_level_stroke_out = []
            user_sentence_level_stroke_length = []
            user_sentence_level_term = []
            user_sentence_level_char = []
            user_sentence_level_char_length = []
            user_word_level_stroke_in = []
            user_word_level_stroke_out = []
            user_word_level_stroke_length = []
            user_word_level_term = []
            user_word_level_char = []
            user_word_level_char_length = []
            user_segment_level_stroke_in = []
            user_segment_level_stroke_out = []
            user_segment_level_stroke_length = []
            user_segment_level_term = []
            user_segment_level_char = []
            user_segment_level_char_length = []
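
            # Each <datadir>/<uid>/<tid>.npy packs sentence-, word-, and segment-level arrays;
            # the field order differs between the NEW_writers / DW_writers layouts and the
            # default layout, hence the separate unpacking branches below.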
| # print ("uid: ", uid, "\ttids:", tids) | |
| for tid in tids: | |
| if self.datadir == './data/NEW_writers': | |
| [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char, word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char, segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char] = \ | |
| np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes') | |
| elif self.datadir == './data/DW_writers': | |
| [sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out, | |
| word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out, | |
| segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \ | |
| np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes') | |
| elif self.datadir == './data/VALID_DW_writers': | |
| [sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out, | |
| word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out, | |
| segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \ | |
| np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes') | |
| else: | |
| [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char, word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char, segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char, _] = \ | |
| np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes') | |
| if self.datadir == './data/DW_writers': | |
| sentence_level_char = sentence_level_char[1:] | |
| sentence_level_term = sentence_level_term[1:] | |
| if self.datadir == './data/VALID_DW_writers': | |
| sentence_level_char = sentence_level_char[1:] | |
| sentence_level_term = sentence_level_term[1:] | |
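
                # Trim trailing points until the sentence ends on a terminal flag of 1.0,
                # keeping the stroke sequence aligned with its character labels.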
                while True:
                    if len(sentence_level_term) == 0:
                        break
                    if sentence_level_term[-1] != 1.0:
                        sentence_level_raw_stroke = sentence_level_raw_stroke[:-1]
                        sentence_level_char = sentence_level_char[:-1]
                        sentence_level_term = sentence_level_term[:-1]
                        sentence_level_stroke_in = sentence_level_stroke_in[:-1]
                        sentence_level_stroke_out = sentence_level_stroke_out[:-1]
                    else:
                        break

                tmp = []
                for i, t in enumerate(sentence_level_term):
                    if t == 1:
                        tmp.append(sentence_level_char[i])
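
                # 'a' rescales only the first two stroke columns (presumably the x/y offsets)
                # by 1/divider; any remaining columns (e.g. a pen-up flag) are left unchanged.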
                a = np.ones_like(sentence_level_stroke_in)
                a[:, :2] /= self.divider

                if len(sentence_level_stroke_in) == len(sentence_level_term) and len(tmp) > 0 and len(sentence_level_stroke_in) > 0:
                    user_sentence_level_stroke_in.append(np.asarray(sentence_level_stroke_in) * a)
                    user_sentence_level_stroke_out.append(np.asarray(sentence_level_stroke_out) * a)
                    user_sentence_level_stroke_length.append(len(sentence_level_stroke_in))
                    user_sentence_level_char.append(np.asarray(tmp))
                    user_sentence_level_term.append(np.asarray(sentence_level_term))
                    user_sentence_level_char_length.append(len(tmp))

                    for wid in range(len(word_level_stroke_in)):
                        each_word_level_stroke_in = word_level_stroke_in[wid]
                        each_word_level_stroke_out = word_level_stroke_out[wid]
                        if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                            each_word_level_term = word_level_term[wid][1:]
                            each_word_level_char = word_level_char[wid][1:]
                        else:
                            each_word_level_term = word_level_term[wid]
                            each_word_level_char = word_level_char[wid]
                        # assert (len(each_word_level_stroke_in) == len(each_word_level_char) == len(each_word_level_term))

                        while True:
                            if len(each_word_level_term) == 0:
                                break
                            if each_word_level_term[-1] != 1.0:
                                # each_word_level_raw_stroke = each_word_level_raw_stroke[:-1]
                                each_word_level_char = each_word_level_char[:-1]
                                each_word_level_term = each_word_level_term[:-1]
                                each_word_level_stroke_in = each_word_level_stroke_in[:-1]
                                each_word_level_stroke_out = each_word_level_stroke_out[:-1]
                            else:
                                break

                        tmp = []
                        for i, t in enumerate(each_word_level_term):
                            if t == 1:
                                tmp.append(each_word_level_char[i])

                        b = np.ones_like(each_word_level_stroke_in)
                        b[:, :2] /= self.divider

                        if len(each_word_level_stroke_in) == len(each_word_level_term) and len(tmp) > 0 and len(each_word_level_stroke_in) > 0:
                            user_word_level_stroke_in.append(np.asarray(each_word_level_stroke_in) * b)
                            user_word_level_stroke_out.append(np.asarray(each_word_level_stroke_out) * b)
                            user_word_level_stroke_length.append(len(each_word_level_stroke_in))
                            user_word_level_char.append(np.asarray(tmp))
                            user_word_level_term.append(np.asarray(each_word_level_term))
                            user_word_level_char_length.append(len(tmp))
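
                            # The segments of the current word are handled the same way below,
                            # then padded to a common length within the word.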
                            segment_level_stroke_in_list = []
                            segment_level_stroke_out_list = []
                            segment_level_stroke_length_list = []
                            segment_level_char_list = []
                            segment_level_term_list = []
                            segment_level_char_length_list = []

                            for sid in range(len(segment_level_stroke_in[wid])):
                                each_segment_level_stroke_in = segment_level_stroke_in[wid][sid]
                                each_segment_level_stroke_out = segment_level_stroke_out[wid][sid]
                                if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                                    each_segment_level_term = segment_level_term[wid][sid][1:]
                                    each_segment_level_char = segment_level_char[wid][sid][1:]
                                else:
                                    each_segment_level_term = segment_level_term[wid][sid]
                                    each_segment_level_char = segment_level_char[wid][sid]

                                while True:
                                    if len(each_segment_level_term) == 0:
                                        break
                                    if each_segment_level_term[-1] != 1.0:
                                        # each_segment_level_raw_stroke = each_segment_level_raw_stroke[:-1]
                                        each_segment_level_char = each_segment_level_char[:-1]
                                        each_segment_level_term = each_segment_level_term[:-1]
                                        each_segment_level_stroke_in = each_segment_level_stroke_in[:-1]
                                        each_segment_level_stroke_out = each_segment_level_stroke_out[:-1]
                                    else:
                                        break

                                tmp = []
                                for i, t in enumerate(each_segment_level_term):
                                    if t == 1:
                                        tmp.append(each_segment_level_char[i])

                                c = np.ones_like(each_segment_level_stroke_in)
                                c[:, :2] /= self.divider

                                if len(each_segment_level_stroke_in) == len(each_segment_level_term) and len(tmp) > 0 and len(each_segment_level_stroke_in) > 0:
                                    segment_level_stroke_in_list.append(np.asarray(each_segment_level_stroke_in) * c)
                                    segment_level_stroke_out_list.append(np.asarray(each_segment_level_stroke_out) * c)
                                    segment_level_stroke_length_list.append(len(each_segment_level_stroke_in))
                                    segment_level_char_list.append(np.asarray(tmp))
                                    segment_level_term_list.append(np.asarray(each_segment_level_term))
                                    segment_level_char_length_list.append(len(tmp))
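
                            # Pad every segment of this word to the longest segment (strokes and
                            # character labels separately) so they stack into dense arrays.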
                            if len(segment_level_stroke_length_list) > 0:
                                SEGMENT_MAX_STROKE_LENGTH = np.max(segment_level_stroke_length_list)
                                SEGMENT_MAX_CHARACTER_LENGTH = np.max(segment_level_char_length_list)
                                new_segment_level_stroke_in_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_in_list])
                                new_segment_level_stroke_out_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_out_list])
                                new_segment_level_term_list = np.asarray([np.pad(a, (0, SEGMENT_MAX_STROKE_LENGTH - len(a)), 'constant') for a in segment_level_term_list])
                                new_segment_level_char_list = np.asarray([np.pad(a, (0, SEGMENT_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in segment_level_char_list])
                                user_segment_level_stroke_in.append(new_segment_level_stroke_in_list)
                                user_segment_level_stroke_out.append(new_segment_level_stroke_out_list)
                                user_segment_level_stroke_length.append(segment_level_stroke_length_list)
                                user_segment_level_char.append(new_segment_level_char_list)
                                user_segment_level_term.append(new_segment_level_term_list)
                                user_segment_level_char_length.append(segment_level_char_length_list)
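
            # Back at the writer level: pad this writer's sentence- and word-level sequences to
            # the writer-wide maxima so each group stacks into a single dense array.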
            WORD_MAX_STROKE_LENGTH = np.max(user_word_level_stroke_length)
            WORD_MAX_CHARACTER_LENGTH = np.max(user_word_level_char_length)
            SENTENCE_MAX_STROKE_LENGTH = np.max(user_sentence_level_stroke_length)
            SENTENCE_MAX_CHARACTER_LENGTH = np.max(user_sentence_level_char_length)

            new_sentence_level_stroke_in = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_in])
            new_sentence_level_stroke_out = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_out])
            new_sentence_level_term = np.asarray([np.pad(a, (0, SENTENCE_MAX_STROKE_LENGTH - len(a)), 'constant') for a in user_sentence_level_term])
            new_sentence_level_char = np.asarray([np.pad(a, (0, SENTENCE_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in user_sentence_level_char])
            new_word_level_stroke_in = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_in])
            new_word_level_stroke_out = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_out])
            new_word_level_term = np.asarray([np.pad(a, (0, WORD_MAX_STROKE_LENGTH - len(a)), 'constant') for a in user_word_level_term])
            new_word_level_char = np.asarray([np.pad(a, (0, WORD_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in user_word_level_char])

            all_sentence_level_stroke_in.append(new_sentence_level_stroke_in)
            all_sentence_level_stroke_out.append(new_sentence_level_stroke_out)
            all_sentence_level_stroke_length.append(user_sentence_level_stroke_length)
            all_sentence_level_term.append(new_sentence_level_term)
            all_sentence_level_char.append(new_sentence_level_char)
            all_sentence_level_char_length.append(user_sentence_level_char_length)
            all_word_level_stroke_in.append(new_word_level_stroke_in)
            all_word_level_stroke_out.append(new_word_level_stroke_out)
            all_word_level_stroke_length.append(user_word_level_stroke_length)
            all_word_level_term.append(new_word_level_term)
            all_word_level_char.append(new_word_level_char)
            all_word_level_char_length.append(user_word_level_char_length)
            all_segment_level_stroke_in.append(user_segment_level_stroke_in)
            all_segment_level_stroke_out.append(user_segment_level_stroke_out)
            all_segment_level_stroke_length.append(user_segment_level_stroke_length)
            all_segment_level_term.append(user_segment_level_term)
            all_segment_level_char.append(user_segment_level_char)
            all_segment_level_char_length.append(user_segment_level_char_length)

        return [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length,
                all_sentence_level_term, all_sentence_level_char, all_sentence_level_char_length,
                all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
                all_word_level_term, all_word_level_char, all_word_level_char_length,
                all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
                all_segment_level_term, all_segment_level_char, all_segment_level_char_length]
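

# Minimal usage sketch (hypothetical): it assumes the preprocessed dataset already exists on disk
# as <datadir>/<writer id>/<text id>.npy, which this loader reads but does not create.
if __name__ == '__main__':
    loader = DataLoader(num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers')
    batch = loader.next_batch(TYPE='TRAIN')
    # next_batch returns 18 lists; the first holds the padded sentence-level stroke_in array per writer.
    print('writers in batch:', len(batch[0]))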