Spaces:
Runtime error
Runtime error
| import time | |
| import re | |
| from math import floor, ceil | |
| from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils | |
| # from nltk.tokenize import sent_tokenize | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS, cross_origin | |
| import webvtt | |
| from io import StringIO | |
| from mosestokenizer import MosesSentenceSplitter | |
| from indicTrans.inference.engine import Model | |
| from punctuate import RestorePuncts | |
| from indicnlp.tokenize.sentence_tokenize import sentence_split | |
# Flask application object; CORS is enabled for every route on it.
app = Flask(__name__)
cors = CORS(app)
app.config.update(CORS_HEADERS='Content-Type')

# All models are instantiated once, at import time, and shared by every
# request handler in this module. Which translation model serves which
# language direction is decided in get_inference_params().
indic2en_model = Model(expdir='models/v3/indic-en')
en2indic_model = Model(expdir='models/v3/en-indic')
m2m_model = Model(expdir='models/m2m')

# Punctuation restorer used by the VTT endpoint before sentence splitting.
rpunct = RestorePuncts()
# Human-readable language name (as sent by clients in the request form)
# mapped to the short language code handed to the translation models.
indic_language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}
# Shared English sentence splitter; the VTT endpoint uses it to cut the
# re-punctuated paragraph into sentences before batch translation.
splitter = MosesSentenceSplitter('en')
def get_inference_params():
    """Pick the translation model and language codes for the current request.

    Reads ``source_language`` and ``target_language`` from the request form.
    Both are expected to be either ``'English'`` or one of the names in
    ``indic_language_dict``.

    Returns:
        A ``(model, source_lang, target_lang)`` tuple, where the two language
        values are short codes (``'en'`` or a value of
        ``indic_language_dict``).

    Raises:
        An HTTP 400 error (via ``flask.abort``) when the language pair is not
        supported, e.g. English -> English or an unknown language name.
        Previously this case fell through every branch and crashed with
        ``UnboundLocalError``, which surfaced to the client as a 500.
    """
    # Local import: only this function needs it, and it keeps the fix
    # self-contained without touching the module's import block.
    from flask import abort

    source_language = request.form['source_language']
    target_language = request.form['target_language']

    source_is_indic = source_language in indic_language_dict
    target_is_indic = target_language in indic_language_dict

    if source_is_indic and target_language == 'English':
        return indic2en_model, indic_language_dict[source_language], 'en'

    if source_language == 'English' and target_is_indic:
        return en2indic_model, 'en', indic_language_dict[target_language]

    if source_is_indic and target_is_indic:
        return (
            m2m_model,
            indic_language_dict[source_language],
            indic_language_dict[target_language],
        )

    abort(
        400,
        description='Unsupported language pair: {!r} -> {!r}'.format(
            source_language, target_language
        ),
    )
def main():
    """Return the plain-text banner identifying this service."""
    banner = "IndicTrans API"
    return banner
def supported_languages():
    """Return the language-name to language-code table as a JSON response."""
    languages = indic_language_dict
    return jsonify(languages)
def infer_indic_en():
    """Translate the ``text`` form field and report how long it took.

    The model and language codes come from ``get_inference_params()``, so
    despite the name this is not limited to the Indic -> English direction.

    Returns:
        A dict with the translated ``text`` and the translation ``duration``
        in seconds, rounded to two decimals. Only the translation call is
        timed, not the parameter lookup.
    """
    model, source_lang, target_lang = get_inference_params()
    paragraph = request.form['text']

    started = time.time()
    translation = model.translate_paragraph(paragraph, source_lang, target_lang)
    finished = time.time()

    return {'text': translation, 'duration': round(finished - started, 2)}
def _split_words_proportionally(words, weights):
    """Cut *words* into consecutive chunks sized in proportion to *weights*.

    Returns one list per weight. Chunk boundaries are computed from the
    cumulative weight, so the chunks always cover *words* exactly: nothing is
    dropped and nothing is duplicated. A zero weight yields an empty chunk.
    """
    total_weight = sum(weights)
    if total_weight == 0:
        return [[] for _ in weights]

    chunks = []
    start = 0
    weight_so_far = 0
    for weight in weights:
        weight_so_far += weight
        # Integer arithmetic on the running total: the final boundary is
        # exactly len(words), unlike summing independently floored sizes.
        end = len(words) * weight_so_far // total_weight
        chunks.append(words[start:end])
        start = end
    return chunks


def infer_vtt_indic_en():
    """Translate a WebVTT document while keeping its caption timings.

    The ``text`` form field must hold WebVTT content. All caption texts are
    joined into one paragraph, lower-cased, stripped of punctuation,
    re-punctuated with ``rpunct``, split into sentences and translated in one
    batch. The translated words are then dealt back to the captions in
    proportion to each caption's share of the source words. This is a
    length-based approximation, not a true word alignment.

    The model and language codes come from ``get_inference_params()``.

    Returns:
        A dict whose ``text`` is the WebVTT content with every non-empty
        caption replaced by its share of the translation. If the input has no
        caption text at all, the content is returned untranslated.
    """
    started = time.time()
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']

    vad = webvtt.read_buffer(StringIO(source_text))
    # One line of text per caption: drop carriage returns, fold newlines.
    caption_texts = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]
    # Count words on the normalised text so multi-line captions are weighted
    # by their real word count and the weights sum to the whole paragraph.
    caption_word_counts = [len(text.split()) for text in caption_texts]

    if not any(caption_word_counts):
        # Nothing to translate; avoid calling the models on an empty string.
        return {'text': vad.content}

    large_sentence = ' '.join(caption_texts).lower()
    # Remove the existing punctuation before restoring it with the model.
    large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
    punctuated = rpunct.punctuate(large_sentence, batch_size=32)
    print("Time Taken for punctuation: {} s".format(time.time() - started))

    started = time.time()
    split_sents = splitter([punctuated])
    output_sents = model.batch_translate(split_sents, source_lang, target_lang)

    # Join the translations in order. Pairing them through a dict keyed by
    # the source sentence would silently drop repeated sentences.
    nmt_words = ' '.join(output_sents).split()

    chunks = _split_words_proportionally(nmt_words, caption_word_counts)
    for caption, word_count, chunk in zip(vad, caption_word_counts, chunks):
        if not word_count:
            # Captions without text are left exactly as they were.
            continue
        caption.text = ' '.join(chunk)

    print("Time Taken for translation: {} s".format(time.time() - started))

    return {
        'text': vad.content,
    }