# Copyright (c) OpenMMLab. All rights reserved.
from transformers import StoppingCriteria


class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded text ends with a stop word.

    Carriage returns and newlines are removed from the decoded text before
    the comparison. The same normalization is applied to the stop word, so a
    stop word that itself contains ``'\\r'`` or ``'\\n'`` can still match.

    Args:
        tokenizer: Object exposing ``decode(ids) -> str``; used to turn the
            generated token ids back into text.
        stop_word (str): Text whose appearance at the end of the generated
            sequence ends generation.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Kept as a public attribute for backward compatibility with callers
        # that may read it; the comparison itself uses ``endswith``.
        self.length = len(self.stop_word)
        # Normalize the stop word exactly like the decoded text in
        # ``__call__``; otherwise a stop word containing a line break could
        # never match the line-break-stripped text.
        self._normalized_stop_word = self._strip_line_breaks(stop_word)

    @staticmethod
    def _strip_line_breaks(text):
        """Return ``text`` with every ``'\\r'`` and ``'\\n'`` removed."""
        return text.replace('\r', '').replace('\n', '')

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the first sequence ends with the stop word.

        Only ``input_ids[0]`` is decoded and inspected, so with batched
        generation the other sequences do not influence the decision.
        An empty stop word (after normalization) never triggers a stop.
        """
        # ``endswith('')`` is always True; treat an empty stop word as
        # "never stop" instead of halting generation immediately.
        if not self._normalized_stop_word:
            return False
        cur_text = self._strip_line_breaks(
            self.tokenizer.decode(input_ids[0]))
        return cur_text.endswith(self._normalized_stop_word)