import os
import json
import tempfile
import sys
import datetime
import re
import string

sys.path.append('mtool')

import torch
from tqdm import tqdm

from model.model import Model
from data.dataset import Dataset
from data.batch import Batch
from config.params import Params
from utility.initialize import initialize
from mtool.main import main as mtool_main
class PredictionModel(torch.nn.Module):
    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'),
                 default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Load the checkpoint on CPU first; the model itself is moved to self.device below.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.verbose = verbose

        # Restore the training-time hyperparameters and point every data path at the
        # bundled default MRP file, since this class is only used for inference.
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')

        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        self.model.load_state_dict(self.checkpoint['model'], strict=False)
        self.model.eval()
    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert a list of MRP graphs into text-level predictions via mtool."""
        framework = 'norec'
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            # mtool expects one JSON graph per line.
            line = '\n'.join(json.dumps(entry) for entry in mrp_list)
            mrp_file.write(line)
            mrp_filename = mrp_file.name

        if graph_mode == 'labeled-edge':
            mtool_main([
                '--strings',
                '--ids',
                '--read', 'mrp',
                '--write', framework,
                mrp_filename, output_text_filename
            ])
        elif graph_mode == 'node-centric':
            mtool_main([
                '--node_centric',
                '--strings',
                '--ids',
                '--read', 'mrp',
                '--write', framework,
                mrp_filename, output_text_filename
            ])
        else:
            raise ValueError(f'Unknown graph mode: {graph_mode}')

        with open(output_text_filename) as f:
            texts = json.load(f)
        os.unlink(output_text_filename)
        os.unlink(mrp_filename)
        return texts
    def clean_texts(self, texts):
        # Separate punctuation from words with whitespace, then collapse repeated spaces.
        punctuation = ''.join(f'\\{s}' for s in string.punctuation)
        texts = [re.sub(f'([{punctuation}])', ' \\1 ', t) for t in texts]
        texts = [re.sub(r' +', ' ', t) for t in texts]
        return texts
    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)

        # Pre-fill one MRP entry per input sentence; the model's predictions are merged in below.
        res_sentences = {f'{i}': {'input': sentence} for i, sentence in enumerate(texts)}
        date_str = datetime.datetime.now().date().isoformat()
        for key, value_dict in res_sentences.items():
            value_dict['id'] = key
            value_dict['time'] = date_str
            value_dict['framework'], value_dict['language'] = framework, language
            value_dict['nodes'], value_dict['edges'], value_dict['tops'] = [], [], []

        for batch in (tqdm(data) if self.verbose else data):
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            for prediction in predictions:
                for key, value in prediction.items():
                    res_sentences[prediction['id']][key] = value
        return res_sentences
    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        predictions = self._mrp_to_text(mrp_predictions.values(), graph_mode)
        return predictions
    def forward(self, x):
        return self.predict(x)
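
# A minimal usage sketch, not part of the original module: it assumes the checkpoint,
# default MRP file, and encoder directory exist under ./models, and the example
# sentences below are made up for illustration.
if __name__ == '__main__':
    model = PredictionModel(verbose=True)
    example_texts = [
        'Maten var utmerket, men servicen var treg.',
        'Rommet var rent og personalet veldig hyggelig.',
    ]
    graphs = model.predict(example_texts, graph_mode='labeled-edge')
    print(json.dumps(graphs, ensure_ascii=False, indent=2))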