Spaces:
Build error
Build error
| import logging | |
| from typing import List | |
| import numpy as np | |
| import tensorflow as tf | |
| from transformers import BertTokenizer, TFAutoModelForMaskedLM | |
| from rhyme_with_ai.token_weighter import TokenWeighter | |
| from rhyme_with_ai.utils import pairwise | |
class RhymeGenerator:
    """Generate rhyming second lines for a seed sentence with a masked language model.

    Call :meth:`start` once with the seed sentence and the candidate rhyme words,
    then call :meth:`mutate` repeatedly. Every call masks one position per
    candidate line, lets the model predict it, and returns the decoded lines.
    """

    def __init__(
        self,
        model: TFAutoModelForMaskedLM,
        tokenizer: BertTokenizer,
        token_weighter: TokenWeighter = None,
    ):
        """Generate rhymes.

        Parameters
        ----------
        model : Model for masked language modelling
        tokenizer : Tokenizer for model
        token_weighter : Class that weighs tokens; defaults to
            ``TokenWeighter(tokenizer)`` when not given
        """
        self.model = model
        self.tokenizer = tokenizer
        if token_weighter is None:
            token_weighter = TokenWeighter(tokenizer)
        self.token_weighter = token_weighter
        self._logger = logging.getLogger(__name__)

        # Set by `start`: the padded token-id matrix (one row per rhyme word) and,
        # per row, the probability of choosing each position for re-sampling.
        self.tokenized_rhymes_ = None
        self.position_probas_ = None

        # Easy access.
        self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
        self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
        self.mask_token_id = self.tokenizer.mask_token_id

    def start(self, query: str, rhyme_words: List[str]) -> None:
        """Start the sentence generator.

        Builds one candidate line per rhyme word and stores them, padded to a
        common length, in ``tokenized_rhymes_``.

        Parameters
        ----------
        query : Seed sentence
        rhyme_words : Rhyme words for next sentence

        Raises
        ------
        ValueError
            If ``rhyme_words`` is empty or ``query`` tokenizes to nothing.
        """
        self._logger.info("Got sentence %s", query)
        if not rhyme_words:
            raise ValueError("At least one rhyme word is required.")
        tokenized_rhymes = [
            self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
        ]
        # Make same length.
        self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
            tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
        )
        # Only masked positions may be re-sampled; spread the probability uniformly
        # over them. `_initialize_rhymes` inserts at least one mask per row, so the
        # row sums below are never zero.
        is_mask = self.tokenized_rhymes_ == self.mask_token_id
        self.position_probas_ = is_mask / is_mask.sum(axis=1, keepdims=True)

    def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
        """Initialize the rhymes.

        * Tokenize input
        * Append a comma if the sentence does not end in it (might add better
          predictions as it shows the two sentence parts are related)
        * Make second line as long as the original, but always with at least one
          mask so there is something to predict
        * Add a period

        Parameters
        ----------
        query : First line
        rhyme_word : Last word for second line

        Returns
        -------
        Tokenized rhyme lines

        Raises
        ------
        ValueError
            If ``query`` tokenizes to no tokens.
        """
        query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
        if not query_token_ids:
            raise ValueError("Query must contain at least one token.")
        rhyme_word_token_ids = self.tokenizer.encode(
            rhyme_word, add_special_tokens=False
        )
        if query_token_ids[-1] != self.comma_token_id:
            query_token_ids.append(self.comma_token_id)
        # The +1 discounts the first line's trailing comma, so masks + rhyme word
        # together span as many tokens as the query text itself.
        magic_correction = len(rhyme_word_token_ids) + 1  # 1 for comma
        # A short query (or a long rhyme word) would otherwise give zero or a
        # negative number of masks, leaving a row with nothing to predict and a
        # 0/0 position probability in `start`.
        n_masks = max(len(query_token_ids) - magic_correction, 1)
        return (
            query_token_ids
            + [self.mask_token_id] * n_masks
            + rhyme_word_token_ids
            + [self.period_token_id]
        )

    def mutate(self) -> List[str]:
        """Mutate the current rhymes.

        Returns
        -------
        Mutated rhymes, decoded to strings without special tokens

        Raises
        ------
        RuntimeError
            If called before :meth:`start`.
        """
        if self.tokenized_rhymes_ is None:
            raise RuntimeError("Call `start` before `mutate`.")
        self.tokenized_rhymes_ = self._mutate(
            self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
        )
        return [
            self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(
                    token_ids, skip_special_tokens=True
                )
            )
            for token_ids in self.tokenized_rhymes_
        ]

    def _mutate(
        self,
        tokenized_rhymes: np.ndarray,
        position_probas: np.ndarray,
        token_id_probas: np.ndarray,
    ) -> np.ndarray:
        """Re-sample one position in every line; modifies ``tokenized_rhymes`` in place.

        First masks a position in each row, then runs the model once on the whole
        batch, and finally draws a replacement token for each masked index.
        """
        replacements = []
        for i in range(tokenized_rhymes.shape[0]):
            mask_idx, masked_token_ids = self._mask_token(
                tokenized_rhymes[i], position_probas[i]
            )
            tokenized_rhymes[i] = masked_token_ids
            replacements.append(mask_idx)

        predictions = self._predict_masked_tokens(tokenized_rhymes)
        for i, replace_ix in enumerate(replacements):
            tokenized_rhymes[i, replace_ix] = self._draw_replacement(
                predictions[i], token_id_probas, replace_ix
            )
        return tokenized_rhymes

    def _mask_token(self, token_ids, position_probas):
        """Mask line and return index to update.

        Returns a ``(index, token_ids)`` tuple; ``token_ids`` is modified in place.
        """
        token_ids = self._mask_repeats(token_ids, position_probas)
        ix = self._locate_mask(token_ids, position_probas)
        token_ids[ix] = self.mask_token_id
        return ix, token_ids

    def _locate_mask(self, token_ids, position_probas):
        """Update masks or a random token.

        Returns the index of the last mask already in the line if there is one,
        otherwise a position drawn at random according to ``position_probas``.
        """
        if self.mask_token_id in token_ids:
            # Already masks present, just return the last.
            # We used to return the first but this returns worse predictions.
            return np.where(token_ids == self.mask_token_id)[0][-1]
        return np.random.choice(len(position_probas), p=position_probas)

    def _mask_repeats(self, token_ids, position_probas):
        """Repeated tokens are generally of less quality.

        Masks both tokens of each adjacent pair of equal tokens, but only at
        positions with non-zero ``position_probas`` (i.e. maskable positions).
        Assumes ``pairwise`` yields consecutive ``(a, b)`` pairs; it is a project
        helper not visible here, so confirm.
        """
        # NOTE(review): the last two positions are excluded from the scan,
        # presumably to leave the line ending alone; confirm the intent.
        repeats = [
            ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
        ]
        for ii in repeats:
            if position_probas[ii] > 0:
                token_ids[ii] = self.mask_token_id
            if position_probas[ii + 1] > 0:
                token_ids[ii + 1] = self.mask_token_id
        return token_ids

    def _predict_masked_tokens(self, tokenized_rhymes):
        """Run the model on the batch and return its first output (the logits)."""
        # NOTE(review): no attention mask is passed, so rows padded with
        # `pad_token_id` in `start` attend to their padding; consider passing
        # `attention_mask` if the model accepts it.
        return self.model(tf.constant(tokenized_rhymes))[0]

    def _draw_replacement(self, predictions, token_probas, replace_ix):
        """Get probability, weigh and draw.

        Softmaxes the logits at ``replace_ix``, multiplies them by the token
        weights, renormalizes and samples a token id.

        Raises
        ------
        ValueError
            If the weighting leaves no token with positive probability.
        """
        # TODO (HG): Can't we softmax when calling the model?
        probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
        total = probas.sum()
        # `not total > 0` also catches NaN; dividing by it would only make
        # np.random.choice fail with a less helpful message.
        if not total > 0:
            raise ValueError(
                "Token weights leave no candidate with positive probability."
            )
        probas /= total
        return np.random.choice(len(probas), p=probas)