# Preprocessing and visualisation utilities for the BRUSH handwriting dataset.
import math
import os
import pickle

import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy import stats

from config.GlobalVariables import *  # provides CHARACTERS used below

def preprocess_dataset(data_dir, resample=20, pred_start=1):
    def reformat_raw_data(raw_data, pred_start):
        # Turn absolute (x, y, pen) points into successive offsets, with the
        # pen state shifted back one step; returns the (input, target) pair
        # for next-step prediction. With pred_start == 1 the first offset is
        # measured from the fixed anchor point (0, 500).
        if pred_start == 1:
            tmp = np.concatenate([[[0, 500, 0]], raw_data], 0)
            tmp = tmp[1:] - tmp[:-1]
            tmp[1:, 2] = raw_data[:-1, 2]
            tmp = np.concatenate([[[0, 0, 0]], tmp], 0)
        else:
            tmp = np.concatenate([raw_data[0:1], raw_data])
            tmp = tmp[1:] - tmp[:-1]
            tmp[0, 2] = 0
            tmp[1:, 2] = raw_data[:-1, 2]
        return tmp[:-1], tmp[1:]
    # Samples excluded from preprocessing; built from data_dir so the
    # membership test below matches the paths actually opened.
    prohibits = [f'{data_dir}/5/118_resample{resample}', f'{data_dir}/7/14_resample{resample}',
                 f'{data_dir}/7/101_resample{resample}', f'{data_dir}/7/58_resample{resample}',
                 f'{data_dir}/14/20_resample{resample}', f'{data_dir}/22/45_resample{resample}',
                 f'{data_dir}/30/45_resample{resample}', f'{data_dir}/40/85_resample{resample}',
                 f'{data_dir}/50/45_resample{resample}', f'{data_dir}/59/29_resample{resample}',
                 f'{data_dir}/96/120_resample{resample}', f'{data_dir}/99/134_resample{resample}',
                 f'{data_dir}/140/35_resample{resample}', f'{data_dir}/144/55_resample{resample}',
                 f'{data_dir}/144/91_resample{resample}', f'{data_dir}/144/28_resample{resample}',
                 f'{data_dir}/144/69_resample{resample}']
    preprocess_dir = 'preprocess' if pred_start == 1 else 'preprocess2'
    for writer_id in range(170):
        print(f'Preprocessing the BRUSH dataset - writer {writer_id + 1} / 170')
        for sentence_id in [i for i in os.listdir(f'{data_dir}/{writer_id}') if i.endswith(f'_resample{resample}')]:
            with open(f'{data_dir}/{writer_id}/{sentence_id}', 'rb') as f:
                [sentence_text, raw_points, character_labels] = pickle.load(f)
            if f'{data_dir}/{writer_id}/{sentence_id}' not in prohibits:
                # Shift the sentence so its first point sits at x = 0.
                sentence_raw_points = raw_points
                sentence_raw_points[:, 0] -= sentence_raw_points[0, 0]
                sentence_stroke_in, sentence_stroke_out = reformat_raw_data(sentence_raw_points, pred_start=pred_start)
                # Positions of the spaces that split the sentence into words.
                split_char_ids = [i for i, c in enumerate(sentence_text) if c == ' ']
                sentence_char = [CHARACTERS.find(c) for c in sentence_text]
                # sentence_term flags the stroke point that ends each character:
                # a 1 is emitted whenever the next point belongs to a later
                # character column of character_labels.
                sentence_term = []
                cid = 0
                for i in range(len(character_labels) - 1):
                    if character_labels[i + 1, cid] != 1:
                        if np.argmax(character_labels[i + 1]) >= cid:
                            cid += 1
                            sentence_term.append(1)
                        else:
                            sentence_term.append(0)
                    else:
                        sentence_term.append(0)
                sentence_term.append(1)
                sentence_term = np.asarray(sentence_term)
                assert (len(sentence_term) == len(character_labels))
                # Accumulators at word, segment (per-word character lists)
                # and character granularity.
                word_level_raw_stroke = []
                word_level_stroke_in = []
                word_level_stroke_out = []
                word_level_char = []
                word_level_term = []
                segment_level_raw_stroke = []
                segment_level_stroke_in = []
                segment_level_stroke_out = []
                segment_level_char = []
                segment_level_term = []
                character_level_raw_stroke = []
                character_level_stroke_in = []
                character_level_stroke_out = []
                character_level_char = []
                character_level_term = []
                word_start_id = 0
                for i, c in enumerate(sentence_text):
                    if c != ' ':
                        # Character-level slice of the stroke, shifted to x = 0.
                        character_raw_points = raw_points[character_labels[:, i] > 0]
                        character_raw_points[:, 0] -= character_raw_points[0, 0]
                        character_level_raw_stroke.append(character_raw_points)
                        character_stroke_in, character_stroke_out = reformat_raw_data(character_raw_points, pred_start=pred_start)
                        character_level_stroke_in.append(character_stroke_in)
                        character_level_stroke_out.append(character_stroke_out)
                        term = np.zeros([len(character_raw_points)])
                        term[-1] = 1
                        character_level_term.append(term)
                        character_level_char.append([CHARACTERS.find(c)])
                    if i in split_char_ids:
                        # A space closes the current word: collect its points,
                        # terms and characters, then flush the per-character
                        # accumulators into the segment level.
                        word = sentence_text[word_start_id:i]
                        word_labels = np.zeros(len(character_labels))
                        for j in range(word_start_id, i):
                            word_labels += character_labels[:, j]
                        word_raw_points = raw_points[word_labels > 0]
                        word_term = sentence_term[word_labels > 0]
                        word_term[0] = 0
                        assert (np.sum(word_term) == len(word))
                        word_raw_points[:, 0] -= word_raw_points[0, 0]
                        word_level_raw_stroke.append(word_raw_points)
                        word_stroke_in, word_stroke_out = reformat_raw_data(word_raw_points, pred_start=pred_start)
                        word_level_stroke_in.append(word_stroke_in)
                        word_level_stroke_out.append(word_stroke_out)
                        word_level_term.append(word_term)
                        word_level_char.append([CHARACTERS.find(c) for c in word])
                        word_start_id = i + 1
                        assert (len(character_level_raw_stroke) == len(word))
                        segment_level_raw_stroke.append(character_level_raw_stroke)
                        segment_level_stroke_in.append(character_level_stroke_in)
                        segment_level_stroke_out.append(character_level_stroke_out)
                        segment_level_char.append(character_level_char)
                        segment_level_term.append(character_level_term)
                        character_level_raw_stroke = []
                        character_level_stroke_in = []
                        character_level_stroke_out = []
                        character_level_char = []
                        character_level_term = []
                # The sentence does not end with a space, so close out the
                # final word the same way.
                word = sentence_text[word_start_id:]
                word_labels = np.zeros(len(character_labels))
                for j in range(word_start_id, len(sentence_text)):
                    word_labels += character_labels[:, j]
                word_raw_points = raw_points[word_labels > 0]
                word_raw_points[:, 0] -= word_raw_points[0, 0]
                word_term = sentence_term[word_labels > 0]
                word_term[0] = 0
                assert (np.sum(word_term) == len(word))
                word_level_raw_stroke.append(word_raw_points)
                word_stroke_in, word_stroke_out = reformat_raw_data(word_raw_points, pred_start=pred_start)
                word_level_stroke_in.append(word_stroke_in)
                word_level_stroke_out.append(word_stroke_out)
                word_level_term.append(word_term)
                word_level_char.append([CHARACTERS.find(c) for c in word])
                assert (len(character_level_raw_stroke) == len(word))
                segment_level_raw_stroke.append(character_level_raw_stroke)
                segment_level_stroke_in.append(character_level_stroke_in)
                segment_level_stroke_out.append(character_level_stroke_out)
                segment_level_char.append(character_level_char)
                segment_level_term.append(character_level_term)
                # makedirs creates the preprocess/<writer_id> directories on
                # first use (a bare os.mkdir would fail while the parent
                # preprocess directory is still missing).
                os.makedirs(f'{data_dir}/{preprocess_dir}/{writer_id}', exist_ok=True)
                with open(f'{data_dir}/{preprocess_dir}/{writer_id}/{sentence_id}', 'wb') as f:
                    pickle.dump([
                        sentence_stroke_in, sentence_stroke_out, sentence_term, sentence_char,
                        word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
                        segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char], f)
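
# Minimal usage sketch (an assumption, not part of the original module): the
# resampled BRUSH pickles are expected under '<data_dir>/<writer_id>/' with
# filenames ending in '_resample{resample}', and outputs land in
# '<data_dir>/preprocess/<writer_id>/'. The file name below is hypothetical.
def _demo_preprocess_dataset():
    preprocess_dataset('./BRUSH', resample=20, pred_start=1)
    with open('./BRUSH/preprocess/0/0_resample20', 'rb') as f:
        (sentence_stroke_in, sentence_stroke_out, sentence_term, sentence_char,
         word_stroke_in, word_stroke_out, word_term, word_char,
         segment_stroke_in, segment_stroke_out, segment_term, segment_char) = pickle.load(f)
    print(sentence_stroke_in.shape)  # (T, 3): per-step (dx, dy, pen) offsets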

def gaussian_2d(x1, x2, mu1, mu2, s1, s2, rho):
    # Bivariate normal density with means (mu1, mu2), standard deviations
    # (s1, s2) and correlation rho, evaluated element-wise on tensors.
    norm1 = x1 - mu1
    norm2 = x2 - mu2
    s1s2 = s1 * s2
    z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 - 2 * rho * norm1 * norm2 / s1s2
    numerator = torch.exp(-z / (2 * (1 - rho ** 2)))
    denominator = 2 * math.pi * s1s2 * torch.sqrt(1 - rho ** 2)
    return numerator / denominator
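
# Sanity-check sketch for gaussian_2d: with rho = 0 the density factorises
# into two 1-D normals, so its value at the mean is 1 / (2 * pi * s1 * s2).
def _check_gaussian_2d():
    zero, one = torch.zeros(1), torch.ones(1)
    density = gaussian_2d(zero, zero, zero, zero, one, one, zero)
    assert torch.allclose(density, torch.tensor([1 / (2 * math.pi)]))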

def get_minimax(stroke_results):
    # Collect the local minima and maxima of y along each stroke; endpoints
    # are classified using their single neighbour.
    minimas = []
    maximas = []
    for stroke in stroke_results:
        for i, [x, y] in enumerate(stroke):
            if i == 0:
                prev_x, prev_y = x, y
            if i == len(stroke) - 1:
                if prev_y <= y:
                    maximas.append([x, y])
                if prev_y >= y:
                    minimas.append([x, y])
                break
            else:
                next_x, next_y = stroke[i + 1]
                if prev_y <= y and y >= next_y:
                    maximas.append([x, y])
                if prev_y >= y and y <= next_y:
                    minimas.append([x, y])
                prev_x, prev_y = x, y  # track the previous point for the next test
    return np.asarray(minimas), np.asarray(maximas)
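
# Usage sketch for get_minimax on a toy zig-zag stroke (hypothetical data):
# (1, 2) is the interior peak and (2, 0) the interior valley; the endpoints
# are classified on their open side as well.
def _demo_get_minimax():
    stroke = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 0.0], [3.0, 1.0]])
    minimas, maximas = get_minimax([stroke])
    print(minimas)  # minima at (0, 1) and (2, 0)
    print(maximas)  # maxima at (1, 2) and (3, 1)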

def get_slope(minimas, maximas):
    # Fit a line through the y-minima and one through the y-maxima, then
    # refit each on the inlier points whose absolute residual is below the
    # mean residual (only when enough inliers remain).
    minima_slope, minima_intercept, _, _, _ = stats.linregress(minimas[:, 0], minimas[:, 1])
    maxima_slope, maxima_intercept, _, _, _ = stats.linregress(maximas[:, 0], maximas[:, 1])
    min_se = []
    max_se = []
    for [x, y] in minimas:
        min_se.append(abs(minima_slope * x + minima_intercept - y))
    for [x, y] in maximas:
        max_se.append(abs(maxima_slope * x + maxima_intercept - y))
    min_se, max_se = np.asarray(min_se), np.asarray(max_se)
    new_minimas = minimas[min_se < np.mean(min_se)]
    new_maximas = maximas[max_se < np.mean(max_se)]
    if len(new_minimas) > 5:
        minima_slope, minima_intercept, _, _, _ = stats.linregress(new_minimas[:, 0], new_minimas[:, 1])
    if len(new_maximas) > 5:
        maxima_slope, maxima_intercept, _, _, _ = stats.linregress(new_maximas[:, 0], new_maximas[:, 1])
    return minima_slope, minima_intercept, maxima_slope, maxima_intercept
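
# Usage sketch: estimate the lower and upper envelope lines of a handwriting
# sample from its stroke extrema. stroke_results is assumed to be a list of
# (N, 2) arrays of x/y points, matching what get_minimax expects.
def _demo_get_slope(stroke_results):
    minimas, maximas = get_minimax(stroke_results)
    min_slope, min_b, max_slope, max_b = get_slope(minimas, maximas)
    # y = min_slope * x + min_b fits the y-minima; max_* fits the y-maxima.
    return min_slope, min_b, max_slope, max_b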

def draw_commands(commands):
    # Render offset commands onto a fixed 160 x 750 grayscale canvas,
    # scaling offsets by 5; a segment is drawn only while the pen flag is 0.
    im = Image.fromarray(np.zeros([160, 750], dtype=np.uint8))
    dr = ImageDraw.Draw(im)
    px, py = 50, 100
    for dx, dy, t in commands:
        x = px + dx * 5
        y = py + dy * 5
        if t == 0:
            dr.line((px, py, x, y), 255, 1)
        px, py = x, y
    return im
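
# Usage sketch for draw_commands with hypothetical [dx, dy, pen] triplets:
# the first two steps draw, the third moves without drawing.
def _demo_draw_commands():
    commands = np.array([[2.0, 1.0, 0.0], [2.0, -1.0, 0.0], [2.0, 0.0, 1.0]])
    draw_commands(commands).save('commands_preview.png')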

def draw_points(raw_points, character_labels):
    # Render absolute points, colouring each segment by the character it
    # belongs to; a segment is drawn only when the previous pen flag is 0.
    [w, h, _] = np.max(raw_points, 0)
    im = Image.new("RGB", (int(w) + 100, int(h) + 100))
    dr = ImageDraw.Draw(im)
    colors = np.random.randint(0, 255, (len(character_labels[0]), 3))
    for i, [x, y, t] in enumerate(raw_points):
        if i > 0:
            if pt == 0:
                dr.line((px, py, x, y), tuple(int(c) for c in colors[np.argmax(character_labels[i])]), 3)
        px, py, pt = x, y, t
    return im
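
# Usage sketch for draw_points: load one raw sentence (hypothetical path,
# following the layout read by preprocess_dataset) and save a coloured render.
def _demo_draw_points():
    with open('./BRUSH/0/0_resample20', 'rb') as f:
        sentence_text, raw_points, character_labels = pickle.load(f)
    draw_points(raw_points, character_labels).save('sentence_preview.png')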