# -*- coding: utf-8 -*-
# Taken from https://github.com/Felflare/rpunct/blob/master/rpunct/punctuate.py,
# modified to support batching during GPU inference.
__author__ = "Daulet N."
__email__ = "[email protected]"

import time
import logging

import webvtt
import torch
from io import StringIO
from nltk.tokenize import sent_tokenize
# from langdetect import detect
from simpletransformers.ner import NERModel


class RestorePuncts:
    def __init__(self, wrds_per_pred=250):
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 30
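        # Each label is two characters: the first is the punctuation mark to
        # append after the word ('O' meaning none), the second is the casing
        # ('U' to capitalize the word, 'O' to leave it as-is); see
        # punctuate_texts below, which applies the labels exactly this way.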
        self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
        self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels,
                              args={"silent": True, "max_seq_length": 512})
        # use_cuda isn't working, and this hack seems to load the model correctly onto the GPU
        self.model.device = torch.device("cuda:1")
        # dummy punctuate call to load the model onto the GPU
        self.punctuate("hello how are you")

    def punctuate(self, text: str, batch_size: int = 32, lang: str = ''):
        """
        Performs punctuation restoration on arbitrarily large text.
        Detects if the input is not English; if non-English text is detected, terminates predictions.
        Override by supplying `lang='en'`.
        Args:
            - text (str): Text to punctuate; anything from a few words to arbitrarily large text.
            - batch_size (int): Number of text slices to run through the model per forward pass.
            - lang (str): Explicit language of the input text.
        """
        # if not lang and len(text) > 10:
        #     lang = detect(text)
        # if lang != 'en':
        #     raise Exception(f"""Non-English text detected. Restore Punctuation works only for English.
        #     If you are certain the input is English, pass the argument lang='en' to this function.
        #     Punctuate received: {text}""")
        def chunks(L, n):
            return [L[x: x + n] for x in range(0, len(L), n)]
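        # e.g. chunks([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]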
        # split up large text into BERT-digestible chunks
        splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
        texts = [i["text"] for i in splits]
        # predict each batch of slices in a single forward pass
        batches = chunks(texts, batch_size)
        preds_lst = []
        for batch in batches:
            batch_preds, _ = self.model.predict(batch)
            preds_lst.extend(batch_preds)
        # non-batched alternative, predicting one slice at a time:
        # full_preds_lst contains tuples of labels and logits
        # full_preds_lst = [self.predict(i['text']) for i in splits]
        # extract predictions and discard logits
        # preds_lst = [i[0][0] for i in full_preds_lst]
        # join text slices
        combined_preds = self.combine_results(text, preds_lst)
        # create punctuated prediction
        punct_text = self.punctuate_texts(combined_preds)
        return punct_text
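    # Example usage (the output shown is illustrative of the model's
    # behaviour, not guaranteed):
    #   rp = RestorePuncts()
    #   rp.punctuate("my name is clara and i live in berkeley california")
    #   -> "My name is Clara and I live in Berkeley, California."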

    def predict(self, input_slice):
        """
        Passes the unpunctuated text to the model for punctuation.
        """
        predictions, raw_outputs = self.model.predict([input_slice])
        return predictions, raw_outputs

    @staticmethod
    def split_on_toks(text, length, overlap):
        """
        Splits text into predefined slices of overlapping text with indexes (offsets)
        that tie back to the original text.
        This is done to bypass the 512-token limit on transformer models by sequentially
        feeding chunks of < 512 tokens.
        Example output:
        [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
        """
        wrds = text.replace('\n', ' ').split(" ")
        resp = []
        lst_chunk_idx = 0
        i = 0
        while True:
            # words in the chunk and the overlapping portion
            wrds_len = wrds[(length * i):(length * (i + 1))]
            wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
            wrds_split = wrds_len + wrds_ovlp
            # break the loop if there are no more words
            if not wrds_split:
                break
            wrds_str = " ".join(wrds_split)
            nxt_chunk_start_idx = len(" ".join(wrds_len))
            lst_char_idx = len(" ".join(wrds_split))
            resp_obj = {
                "text": wrds_str,
                "start_idx": lst_chunk_idx,
                "end_idx": lst_char_idx + lst_chunk_idx,
            }
            resp.append(resp_obj)
            lst_chunk_idx += nxt_chunk_start_idx + 1
            i += 1
        logging.info(f"Sliced transcript into {len(resp)} slices.")
        return resp
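    # Worked example (hypothetical numbers): with length=3, overlap=2 and the
    # text "a b c d e f g", the first slice is "a b c d e" (words 0-2 plus two
    # overlap words) with start_idx=0; the next slice "d e f g" starts at
    # start_idx=6, the character offset of "d" in the original text.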

    @staticmethod
    def combine_results(full_text: str, text_slices):
        """
        Given the full text and the predictions for each slice, combines the
        predictions into a single text again.
        Performs validation of whether the text was combined correctly.
        """
        split_full_text = full_text.replace('\n', ' ').split(" ")
        split_full_text = [i for i in split_full_text if i]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]
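        # Overlap handling: `index` walks the full text exactly once and only
        # advances when a slice word matches the next unconsumed word, so each
        # word keeps the prediction from the earliest slice that covered it;
        # the re-predicted leading words of later slices fail to match and are
        # skipped.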
        for _slice in text_slices:
            slice_wrds = len(_slice)
            for ix, wrd in enumerate(_slice):
                # print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
                if index == split_full_text_len:
                    break
                if split_full_text[index] == str(list(wrd.keys())[0]) and \
                        ix <= slice_wrds - 3 and text_slices[-1] != _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
                elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
        assert [i[0] for i in output_text] == split_full_text
        return output_text

    @staticmethod
    def punctuate_texts(full_pred: list):
        """
        Given a list of predictions from the model, applies the predictions to
        the text, thus punctuating it.
        """
        punct_resp = ""
        for i in full_pred:
            word, label = i
            # second label character: 'U' means capitalize the word
            if label[-1] == "U":
                punct_wrd = word.capitalize()
            else:
                punct_wrd = word
            # first label character: the punctuation mark to append ('O' means none)
            if label[0] != "O":
                punct_wrd += label[0]
            punct_resp += punct_wrd + " "
        punct_resp = punct_resp.strip()
        # append a trailing period if one doesn't exist
        if punct_resp[-1].isalnum():
            punct_resp += "."
        return punct_resp
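    # For example, [('hello', 'OU'), ('world', '.O')] yields "Hello world.":
    # 'OU' capitalizes "hello" and adds no punctuation, '.O' appends a period.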


if __name__ == "__main__":
    start = time.time()
    punct_model = RestorePuncts()
    load_model = time.time()
    print(f'Time to load model: {load_model - start}')
    # read test file
    # with open('en_lower.txt', 'r') as fp:
    #     # test_sample = fp.read()
    #     lines = fp.readlines()
    with open('sample.vtt', 'r') as fp:
        source_text = fp.read()
    # captions = webvtt.read_buffer(StringIO(source_text))
    captions = webvtt.read('sample.vtt')
    source_sentences = [caption.text.replace('\r', '').replace('\n', ' ') for caption in captions]
    # print(source_sentences)
    sent = ' '.join(source_sentences)
    punctuated = punct_model.punctuate(sent)
    tokenised = sent_tokenize(punctuated)
    # print(tokenised)
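    # Note: writing sentences back one per caption assumes sent_tokenize
    # yields no more sentences than there are captions; any captions beyond
    # len(tokenised) keep their original text.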
    for i in range(len(tokenised)):
        captions[i].text = tokenised[i]
    # return captions.content
    captions.save('my_captions.vtt')
    end = time.time()
    print(f'Time for run: {end - load_model}')
    print(f'Total time: {end - start}')