|  | """ | 
					
						
						|  | This is a self-contained and flexible beam search implementation adapted from | 
					
						
						|  | AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import copy | 
					
						
						|  | import warnings | 
					
						
						|  | from abc import abstractmethod | 
					
						
						|  | from inspect import signature | 
					
						
						|  | from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | __all__ = [ | 
					
						
						|  | "Sampler", | 
					
						
						|  | "DeterministicSampler", | 
					
						
						|  | "MultinomialSampler", | 
					
						
						|  | "TopKSampler", | 
					
						
						|  | "TopPSampler", | 
					
						
						|  | "GumbelSampler", | 
					
						
						|  | "FinalSequenceScorer", | 
					
						
						|  | "SequenceLogProbabilityScorer", | 
					
						
						|  | "LengthNormalizedSequenceLogProbabilityScorer", | 
					
						
						|  | "Constraint", | 
					
						
						|  | "RepeatedNGramBlockingConstraint", | 
					
						
						|  | "BeamSearch", | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | StateType = Dict[str, torch.Tensor] | 
					
						
						|  | StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]] | 
					
						
						|  | StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]] | 
					
						
						|  |  | 
					
						
						|  | StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep) | 
					
						
						|  | """ | 
					
						
						|  | The type of step function that can be passed to [`BeamSearch.search`](#search). | 
					
						
						|  |  | 
					
						
						|  | This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep) | 
					
						
						|  | or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep). | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | ConstraintStateType = List[List[Dict[str, Any]]] | 
					
						
						|  |  | 
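
# A minimal sketch of a step function with the explicit timestep argument.
# `logits_from` is a hypothetical stand-in for whatever model call produces
# next-token logits from the current decoding state:
#
#     def step(
#         last_predictions: torch.Tensor,  # shape: (group_size,)
#         state: StateType,
#         timestep: int,
#     ) -> Tuple[torch.Tensor, StateType]:
#         logits = logits_from(state, last_predictions)  # shape: (group_size, vocab_size)
#         return torch.nn.functional.log_softmax(logits, dim=-1), state
#
# The `StepFunctionTypeNoTimestep` variant is identical but drops `timestep`.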
					
						
						|  |  | 
					
						
						|  | class Sampler: | 
					
						
						|  | """ | 
					
						
						|  | An abstract class that can be used to sample candidates (either nodes or beams) | 
					
						
						|  | within `BeamSearch`. | 
					
						
						|  |  | 
					
						
						|  | A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`. | 
					
						
						|  |  | 
					
						
						|  | `init_state()` takes three arguments: | 
					
						
						|  |  | 
					
						
						|  | - a tensor of starting log probs with shape `(batch_size,, num_classes)`, | 
					
						
						|  | - the batch size, an int, | 
					
						
						|  | - and the number of classes, also an int. | 
					
						
						|  |  | 
					
						
						|  | It returns a state dictionary with any state tensors needed for subsequent | 
					
						
						|  | calls to `sample_nodes()` and `sample_beams()`. | 
					
						
						|  |  | 
					
						
						|  | By default this method just returns an empty dictionary. | 
					
						
						|  |  | 
					
						
						|  | Both `sample_nodes()` and `sample_beams()` should take three arguments: | 
					
						
						|  |  | 
					
						
						|  | - tensor of normalized log probabilities with shape `(batch_size, num_examples)`, | 
					
						
						|  | - an integer representing the number of samples to take for each example in the batch, | 
					
						
						|  | - and a state dictionary which could contain any tensors needed for the `Sampler` to keep | 
					
						
						|  | track of state. | 
					
						
						|  |  | 
					
						
						|  | For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`, | 
					
						
						|  | `num_examples = beam_size * per_node_beam_size`. | 
					
						
						|  |  | 
					
						
						|  | The return value should be a tuple containing: | 
					
						
						|  |  | 
					
						
						|  | - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`, | 
					
						
						|  | - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`, | 
					
						
						|  | - and the updated state dictionary. | 
					
						
						|  |  | 
					
						
						|  | A default implementation of `sample_beams` is provided, which just deterministically | 
					
						
						|  | picks the `k` examples with highest log probability. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def init_state( | 
					
						
						|  | self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int | 
					
						
						|  | ) -> StateType: | 
					
						
						|  | del start_class_log_probabilities, batch_size, num_classes | 
					
						
						|  | return {} | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def sample_nodes( | 
					
						
						|  | self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType | 
					
						
						|  | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | def sample_beams( | 
					
						
						|  | self, log_probs: torch.Tensor, beam_size: int, state: StateType | 
					
						
						|  | ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: | 
					
						
						|  | del state | 
					
						
						|  | selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1) | 
					
						
						|  | return selected_log_probs, selected_indices, {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class DeterministicSampler(Sampler):
    """
    A `Sampler` that just deterministically returns the `k` nodes or beams with highest
    log probability.
    """

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        del state
        selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return selected_log_probs, selected_indices, {}
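
# A quick sketch of the `Sampler` contract, using `DeterministicSampler`.
# Both returned tensors have shape `(batch_size, per_node_beam_size)`:
#
#     sampler = DeterministicSampler()
#     log_probs = torch.log_softmax(torch.randn(2, 5), dim=-1)  # (batch_size=2, num_classes=5)
#     state = sampler.init_state(log_probs, batch_size=2, num_classes=5)
#     top_log_probs, top_indices, state = sampler.sample_nodes(log_probs, 3, state)  # both (2, 3)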
					
						
class MultinomialSampler(Sampler):
    """
    A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
    in the default, deterministic way.

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: Whether to sample with replacement.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ) -> None:
        self.temperature = temperature
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if self.temperature != 1.0:
            _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
        else:
            _probabilities = log_probs.exp()

        # shape: (batch_size, per_node_beam_size)
        selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)

        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
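
# Usage sketch: plug `MultinomialSampler` into `BeamSearch` (defined below) to
# sample candidates instead of always taking the argmax. `eos_id` is a
# placeholder for your vocabulary's end-of-sequence index:
#
#     sampler = MultinomialSampler(temperature=0.7)
#     beam_search = BeamSearch(end_index=eos_id, beam_size=4, sampler=sampler)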
					
						
class TopKSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among the
    top `k` choices, then samples from that subset after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param k: The number of top choices to be selected from.
    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
    """

    def __init__(
        self,
        k: int = 1,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        self.k = k
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= self.k <= log_probs.size()[1]:
            raise ValueError(
                "k must be a positive integer no less than per_node_beam_size and no greater than vocabulary size"
            )

        # shape (both): (batch_size, k)
        top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)

        # Apply temperature if necessary.
        # shape: (batch_size, k)
        if self.temperature != 1.0:
            top_k_log_probs = top_k_log_probs / self.temperature

        # Re-normalize the subset.
        # shape: (batch_size, k)
        normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`; they are indices into `top_k_log_probs`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices into the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        indices = top_k_indices.gather(-1, sampled_indices)

        return log_probs.gather(1, indices), indices, state
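
# Usage sketch: restrict node sampling to the 5 most likely tokens per step.
# Note that `sample_nodes` requires `per_node_beam_size <= k <= vocabulary size`:
#
#     sampler = TopKSampler(k=5)
#     log_probs = torch.log_softmax(torch.randn(2, 100), dim=-1)
#     selected_log_probs, indices, state = sampler.sample_nodes(log_probs, 3, {})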
					
						
class TopPSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among
    the top choices with a cumulative probability of at least `p`, then samples from that subset
    after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param p:
        The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
        examples to sample from. If `with_replacement` is `False` and the number of possible samples is
        insufficient to sample from without replacement when calling `sample_nodes`, then the top
        `per_node_beam_size` examples will be chosen.
    :param temperature:
        A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    :param with_replacement:
        If set to `True`, samples will be selected with replacement from the top choices.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        if p < 0.0 or p > 1.0:
            raise ValueError("p must be a float in the range [0.0, 1.0]")
        self.p = p
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        if not per_node_beam_size <= log_probs.size()[1]:
            raise ValueError("per_node_beam_size cannot be greater than vocabulary size")

        # First apply the temperature coefficient.
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # Sort the probabilities in descending order so we can take the cumulative sum.
        log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)

        # shape (both): (batch_size, num_classes)
        probabilities_descending = log_probs_descending.exp()
        probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)

        # Create a mask for filtering out probabilities that don't make the top `p`.
        # shape: (batch_size, num_classes)
        exclusion_mask = probabilities_summed >= self.p

        # We want to include the first index where probabilities_summed >= p, so we shift
        # the mask over by one.
        exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
        exclusion_mask[..., 0] = False

        # Make sure there are at least `per_node_beam_size` options to select from.
        if not self.with_replacement:
            exclusion_mask[..., :per_node_beam_size] = False

        log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min

        # Re-normalize the included log probs.
        # shape: (batch_size, num_classes)
        filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`; they are indices into
        # `log_probs_descending`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices into the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        selected_indices = sorting_indices.gather(-1, sampled_indices)

        # shape (both): (batch_size, per_node_beam_size)
        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
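
# Usage sketch: nucleus (top-p) sampling over the smallest set of tokens whose
# cumulative probability reaches 0.9 (padded to at least `per_node_beam_size`
# options when sampling without replacement):
#
#     sampler = TopPSampler(p=0.9)
#     log_probs = torch.log_softmax(torch.randn(2, 100), dim=-1)
#     selected_log_probs, indices, state = sampler.sample_nodes(log_probs, 3, {})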
					
						
class GumbelSampler(Sampler):
    """
    A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
    [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
    Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
    (https://api.semanticscholar.org/CorpusID:76662039).

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    """

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        # shape: (batch_size, num_classes)
        zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))

        # shape: (batch_size, num_classes)
        G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)

        return {"G_phi_S": G_phi_S}

    def sample_nodes(
        self,
        log_probs: torch.Tensor,
        per_node_beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        # First apply the temperature coefficient.
        # shape: (group_size, num_classes)
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # shape: (group_size,)
        phi_S = state["phi_S"]

        # shape: (group_size, num_classes)
        phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)

        # shape: (group_size, num_classes)
        phi_S_new = phi_S + _log_probs

        # shape: (group_size, 1)
        G_phi_S = state["G_phi_S"].unsqueeze(-1)

        # Sample the perturbed log probs for the children, conditioned on the
        # parent's perturbed log prob being the maximum.
        # shape: (group_size, num_classes)
        G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)

        # shape (both): (group_size, per_node_beam_size)
        top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)

        # shape: (group_size, per_node_beam_size)
        top_log_probs = log_probs.gather(1, top_indices)

        return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}

    def sample_beams(
        self,
        log_probs: torch.Tensor,
        beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Returns the beams with the highest perturbed log probabilities.
        """
        # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
        batch_size = log_probs.size()[0]

        # shape: (batch_size * beam_size, per_node_beam_size)
        G_phi_S = state["G_phi_S"]

        # shape: (batch_size, beam_size * per_node_beam_size)
        G_phi_S = G_phi_S.reshape_as(log_probs)

        # Keep the beams with the highest perturbed log probabilities.
        # shape (both): (batch_size, beam_size)
        G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)

        # shape: (batch_size, beam_size)
        selected_log_probs = log_probs.gather(1, selected_indices)

        # Now sort the selected beams by their true log prob.
        # shape (all): (batch_size, beam_size)
        selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
        selected_indices = selected_indices.gather(1, sort_indices)
        G_phi_S_new = G_phi_S_new.gather(1, sort_indices)

        # shape: (batch_size * beam_size,)
        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)

        # shape: (batch_size * beam_size,)
        phi_S = selected_log_probs.reshape(batch_size * beam_size)

        return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}

    def gumbel(self, phi) -> torch.Tensor:
        """
        Sample `Gumbel(phi)`.

        `phi` should have shape `(batch_size, num_classes)`.
        """
        return -torch.log(-torch.log(torch.rand_like(phi))) + phi

    def gumbel_with_max(self, phi, T) -> torch.Tensor:
        """
        Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.

        `phi` should have shape `(batch_size, num_classes)` and `T` should have
        shape `(batch_size, 1)`.
        """
        # shape: (batch_size, num_classes)
        G_phi = self.gumbel(phi)

        # Find the maximum of these samples.
        # shape: (batch_size,)
        Z, _ = G_phi.max(dim=-1)

        # shape: (batch_size, num_classes)
        v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))

        # Numerically stable computation of the conditioned samples.
        # shape: (batch_size, num_classes)
        return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
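
# The identity behind this sampler: if `G_phi[i] = phi[i] + Gumbel(0)` independently
# for each class `i`, then the indices of the top `k` values of `G_phi` are
# distributed as a sample *without replacement* from `softmax(phi)`. Conditioning
# each child's maximum on its parent's perturbed score (`gumbel_with_max`) extends
# this to whole sequences, giving Stochastic Beam Search:
#
#     beam_search = BeamSearch(end_index=eos_id, beam_size=4, sampler=GumbelSampler())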
					
						
class FinalSequenceScorer:
    """
    An abstract class that can be used to score the final generated sequences found
    by beam search. Given the predicted sequences and the corresponding log probabilities of
    those sequences, the class calculates and returns the final score of the sequences.

    The default implementation scores the sequences using the sum of the log probabilities of
    the sequence, which is passed as input.
    """

    @abstractmethod
    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """
        Score the final predictions found by beam search.
        Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.

        :param predictions: A tensor containing the sequence predictions with shape `(batch_size, beam_size, max_steps)`.
        :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
            of the log probabilities per token, with shape `(batch_size, beam_size)`.
        :param end_index: The index of the end symbol.
        """
        raise NotImplementedError
					
						
class SequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
    across the sequence's tokens.
    """

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        del predictions, end_index
        # The provided log probabilities are already the sum over the sequence's
        # tokens, so they can be returned as the final scores directly.
        return log_probabilities
					
						
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
    tokens in the sequence. It optionally includes a length penalty which promotes
    or demotes sequences based on their lengths. The final score for a sequence will
    be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
    here includes the end token.

    :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
        A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
    """

    def __init__(self, length_penalty: float = 1.0):
        super().__init__()
        self.length_penalty = length_penalty

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        # shape: (batch_size, beam_size)
        lengths = (predictions != end_index).long().sum(dim=2)

        # If the sequence ended during beam search, the `log_probabilities` will include
        # the transition to the end token. In that case `lengths` is off by 1 because
        # the count above excludes every occurrence of `end_index`. This corrects for that.
        # shape: (batch_size, beam_size)
        is_end_token = predictions[:, :, -1] == end_index
        lengths += is_end_token.long()

        # shape: (batch_size, beam_size)
        average_log_probs = log_probabilities / (lengths**self.length_penalty)
        return average_log_probs
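
# Worked example: with `length_penalty=1.0`, a 10-token sequence whose summed log
# probability is -12.0 scores -12.0 / 10 = -1.2, i.e. the mean token log
# probability. A penalty above 1.0 divides by a larger factor, which raises the
# (negative) score of longer sequences relative to shorter ones:
#
#     scorer = LengthNormalizedSequenceLogProbabilityScorer(length_penalty=1.2)
#     beam_search = BeamSearch(end_index=eos_id, final_sequence_scorer=scorer)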
					
						
class Constraint:
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulating the class log probabilities during beam search.

    A `Constraint` just has three methods that need to be implemented by subclasses:
    `init_state()`, `apply()` and `_update_state()`.

    `init_state()` takes one argument:

    - the batch size, an int

    It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
    calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
    Each inner list should be of length 1.

    `apply()` takes two arguments:

    - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
      and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
    - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
      log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.

    The `apply()` method should return new `class_log_probabilities` that enforce the constraint
    for this step of beam search. For instance, it may prevent a specific class from being selected by setting
    the corresponding log probability to a negligible value such as `float("-inf")` or
    `torch.finfo(class_log_probabilities.dtype).min`.

    `_update_state()` takes two arguments:

    - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
      copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
      directly edited in-place without affecting the others.
    - `last_prediction`, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
      step of beam search.

    The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
    length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
    """

    @abstractmethod
    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _copy_state(
        state: ConstraintStateType,
        batch_size: int,
        beam_size: int,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """
        Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
        is not appropriate for your constraint, you will need to implement the copying yourself.
        """
        new_state = []
        for i in range(batch_size):
            batch_state = []
            for j in range(beam_size):
                if last_backpointer is None:
                    # This is the first prediction, so the backpointer is 0.
                    backpointer = 0
                else:
                    backpointer = last_backpointer[i, j].item()
                batch_state.append(copy.deepcopy(state[i][backpointer]))
            new_state.append(batch_state)
        return new_state

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        batch_size, beam_size = last_prediction.size()
        new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
        return self._update_state(new_state, last_prediction)

    @abstractmethod
    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        raise NotImplementedError
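
# A minimal hypothetical `Constraint` sketch that blocks a single token id at
# every step. `banned_id` is an illustrative parameter, not part of this module:
#
#     class BlockTokenConstraint(Constraint):
#         def __init__(self, banned_id: int) -> None:
#             self.banned_id = banned_id
#
#         def init_state(self, batch_size: int) -> ConstraintStateType:
#             return [[{}] for _ in range(batch_size)]
#
#         def apply(self, state, class_log_probabilities):
#             min_value = torch.finfo(class_log_probabilities.dtype).min
#             class_log_probabilities[:, :, self.banned_id] = min_value
#             return class_log_probabilities
#
#         def _update_state(self, state, last_prediction):
#             return state  # nothing to track between steps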
					
						
class RepeatedNGramBlockingConstraint(Constraint):
    """
    A `Constraint` that blocks any n-gram of size `ngram_size` from being generated
    more than once within the same sequence.
    """

    def __init__(self, ngram_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ngram_size = ngram_size

    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                current_prefix = tuple(beam["current_prefix"])
                seen_ngrams = beam["seen_ngrams"]
                try:
                    disallowed_indices = seen_ngrams[current_prefix]
                    class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
                        class_log_probabilities.dtype
                    ).min
                except KeyError:
                    # We have not seen this prefix before, so there is no
                    # index that needs to be blocked.
                    pass
        return class_log_probabilities

    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                prediction = last_prediction[i, j].item()
                prefix = beam["current_prefix"]
                seen_ngrams = beam["seen_ngrams"]

                if len(prefix) == self.ngram_size - 1:
                    # This completes a new ngram that we have to remember.
                    if tuple(prefix) not in seen_ngrams:
                        seen_ngrams[tuple(prefix)] = []
                    seen_ngrams[tuple(prefix)].append(prediction)

                # Create the new prefix, dropping the oldest token once the
                # prefix reaches the full ngram size.
                prefix.append(prediction)
                if len(prefix) == self.ngram_size:
                    prefix.pop(0)
        return state
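
# Usage sketch: forbid any trigram from appearing twice in a generated sequence:
#
#     constraint = RepeatedNGramBlockingConstraint(ngram_size=3)
#     beam_search = BeamSearch(end_index=eos_id, constraints=[constraint])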
					
						
class BeamSearch:
    """
    Implements the beam search algorithm for decoding the most likely sequences.

    :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.

    :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
        of the predicted sequences.

    :param beam_size: The width of the beam used.

    :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
        If not given, this just defaults to `beam_size`. Setting this parameter
        to a number smaller than `beam_size` may give better results, as it can introduce
        more diversity into the search. See
        [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
        (https://api.semanticscholar.org/CorpusID:2229477).

    :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
        If not specified, `DeterministicSampler` will be used, which just takes the
        `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.

        Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
        [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

    :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
        the predicted sequences. This does not include the start or end tokens. If `None`,
        no minimum is enforced.

    :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
        The output from this module is what is returned by the `search` method. If not
        specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
        by the sum of the token log probabilities.

    :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
        provided, no constraints will be enforced.
    """
					
						
    def __init__(
        self,
        end_index: int,
        *,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
        sampler: Optional[Sampler] = None,
        min_steps: Optional[int] = None,
        final_sequence_scorer: Optional[FinalSequenceScorer] = None,
        constraints: Optional[List[Constraint]] = None,
    ) -> None:
        if not max_steps > 0:
            raise ValueError("max_steps must be positive")
        if not beam_size > 0:
            raise ValueError("beam_size must be positive")
        if per_node_beam_size is not None and not per_node_beam_size > 0:
            raise ValueError("per_node_beam_size must be positive")
        if min_steps is not None:
            if not min_steps >= 0:
                raise ValueError("min_steps must be non-negative")
            if not min_steps <= max_steps:
                raise ValueError("min_steps must be less than or equal to max_steps")

        self._end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size
        self.sampler = sampler or DeterministicSampler()
        self.min_steps = min_steps or 0
        self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
        self.constraints = constraints or []
					
						
    @staticmethod
    def _reconstruct_sequences(predictions, backpointers):
        # Reconstruct the sequences by walking backwards through the per-timestep
        # predictions and backpointers.
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]

        if not backpointers:
            return reconstructed_predictions

        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]

        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)

            reconstructed_predictions.append(cur_preds)

            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)

        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)

        reconstructed_predictions.append(final_preds)

        return reconstructed_predictions
					
						
    def search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionType,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a starting state and a step function, apply beam search to find the
        most likely target sequences.

        Returns a tuple of `(predictions, final_scores)`, where `predictions`
        has shape `(batch_size, beam_size, max_steps)` and `final_scores`
        has shape `(batch_size, beam_size)`.

        .. note::
            If your step function returns `-inf` for some log probabilities
            (like if you're using a masked log-softmax) then some of the "best"
            sequences returned may also have `-inf` log probability. Specifically
            this happens when the beam size is smaller than the number of actions
            with finite log probability (non-zero probability) returned by the step function.
            Therefore if you're using a mask you may want to check the results from `search`
            and potentially discard sequences with non-finite log probability.

        :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
            Usually the initial predictions are just the index of the "start" token
            in the target vocabulary.

        :param start_state: The initial state passed to the `step` function. Each value of the state dict
            should be a tensor of shape `(batch_size, *)`, where `*` means any other
            number of dimensions.

        :param step: A function that is responsible for computing the next most likely tokens,
            given the current state and the predictions from the last time step.
            The function should accept two or three arguments:

            - a tensor of shape `(group_size,)` representing the indices of the predicted
              tokens from the last time step,
            - the current state, a `StateType`, and
            - optionally, the timestep, an `int`.

            The `group_size` will be `batch_size * beam_size`, except in the initial
            step, for which it will just be `batch_size`.

            The function is expected to return a tuple, where the first element
            is a tensor of shape `(group_size, vocab_size)` containing
            the log probabilities of the tokens for the next step, and the second
            element is the updated state. The tensor in the state should have shape
            `(group_size, *)`, where `*` means any other number of dimensions.
        """
        step_signature = signature(step)
        if len(step_signature.parameters) < 3:
            # The step function we were given does not take the timestep argument,
            # so wrap it in one that does.
            old_step = cast(StepFunctionTypeNoTimestep, step)

            def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
                del time_step
                return old_step(last_predictions, state)

            return self._search(start_predictions, start_state, new_step)
        else:
            return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
					
						
    def _search(
        self,
        start_predictions: torch.Tensor,
        start_state: StateType,
        step: StepFunctionTypeWithTimestep,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = start_predictions.size()[0]

        # List of (batch_size, beam_size) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions: List[torch.Tensor] = []

        # List of (batch_size, beam_size) tensors. One for each time step. None for
        # the first. Stores the index n of the parent prediction, i.e.
        # predictions[t-1][i][n], that each prediction came from.
        backpointers: List[torch.Tensor] = []

        constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)

        num_classes = start_class_log_probabilities.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ValueError(
                f"Vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size."
            )

        sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)

        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, 1, num_classes)
            expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                expanded_start_class_log_probabilities = constraint.apply(
                    constraint_state, expanded_start_class_log_probabilities
                )
            start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)

        # Prevent the end index from being selected while the sequence is too short.
        if self.min_steps >= 1:
            start_class_log_probabilities[:, self._end_index] = torch.finfo(
                start_class_log_probabilities.dtype
            ).min

        # Get the initial predicted classes and their log probabilities.
        # shape (both): (batch_size, beam_size)
        (
            start_top_log_probabilities,
            start_predicted_classes,
            sampler_state,
        ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)

        if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # The log probabilities for the last time step.
        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities

        # shape: [(batch_size, beam_size)]
        predictions.append(start_predicted_classes)

        # Log probability tensor that mandates that the end token is selected
        # once a sequence has finished.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes),
            torch.finfo(start_class_log_probabilities.dtype).min,
        )
        log_probs_after_end[:, self._end_index] = 0.0

        # Set the same state for each element in the beam.
        self._update_initial_state(state, batch_size)

        for i, constraint in enumerate(self.constraints):
            constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)

            # If every predicted token from the last step is `self._end_index`,
            # then we can stop early.
            if (last_predictions == self._end_index).all():
                break

            # Take a step. This gets the predicted log probs of the next classes
            # and updates the state.
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Apply all constraints.
            if self.constraints:
                # shape: (batch_size, beam_size, num_classes)
                reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
                for constraint, constraint_state in zip(self.constraints, constraint_states):
                    reshaped_class_log_probabilities = constraint.apply(
                        constraint_state, reshaped_class_log_probabilities
                    )
                # shape: (batch_size * beam_size, num_classes)
                class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)

            # Iteration `timestep` of this loop selects token number `timestep + 2` of the
            # sequence (the first token was selected before the loop), so block the end
            # index while the sequence is still shorter than `min_steps`.
            if timestep + 2 <= self.min_steps:
                class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end,
                class_log_probabilities,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
                cleaned_log_probabilities, self.per_node_beam_size, sampler_state
            )

            # Here we expand the last log probabilities to `(batch_size * beam_size,
            # per_node_beam_size)` so that we can add them to the current log probs
            # for this timestep. This lets us maintain the log probability of each
            # element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` beam indices.
            # shape (both): (batch_size, beam_size)
            (
                restricted_beam_log_probs,
                restricted_beam_indices,
                sampler_state,
            ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)

            predictions.append(restricted_predicted_classes)

            # shape: (batch_size, beam_size)
            last_log_probabilities = restricted_beam_log_probs

            # The beam indices come from a `beam_size * per_node_beam_size` dimension where
            # the indices with a common ancestor are grouped together. Hence dividing by
            # `per_node_beam_size` gives the ancestor. (Note that this is integer division
            # since the tensor is a LongTensor.)
            # shape: (batch_size, beam_size)
            backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
            backpointers.append(backpointer)

            # Keep only the pieces of the state tensors corresponding to the
            # ancestors created this iteration.
            self._update_state(state, backpointer)

            for i, constraint in enumerate(self.constraints):
                constraint_states[i] = constraint.update_state(
                    constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
                )

        # Warn about "-inf" log probabilities if not using any constraints (negligible
        # log probabilities are expected when using constraints).
        if not self.constraints and (
            not torch.isfinite(last_log_probabilities).all()
            or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
        ):
            warnings.warn(
                "Negligible log probabilities encountered ('-inf' or equivalent). "
                "Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)

        # shape: (batch_size, beam_size, max_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

        # Calculate the final sequence scores.
        # shape: (batch_size, beam_size)
        final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)

        # Sort the sequences based on the final scores so the best-scoring
        # sequence is at index 0.
        sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
        sorted_all_predictions = torch.gather(
            all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
        )

        return sorted_all_predictions, sorted_final_scores
					
						
    def _update_initial_state(self, state: StateType, batch_size: int):
        """
        Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
        """
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue

            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, self.beam_size, *last_dims)
                .reshape(batch_size * self.beam_size, *last_dims)
            )
					
						
    def _update_state(self, state: StateType, backpointer: torch.Tensor):
        batch_size = backpointer.size()[0]

        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # Expand the backpointer so it can be used to gather along the state dims.
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
                batch_size, self.beam_size, *last_dims
            )
            # Keep only the state rows that correspond to the surviving ancestors.
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, self.beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * self.beam_size, *last_dims)
            )
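
# End-to-end usage sketch with a toy Markov-chain step function. Everything here
# is illustrative (random transition matrix, arbitrary token ids):
#
#     vocab_size, eos_id, batch_size = 10, 0, 2
#     transitions = torch.log_softmax(torch.randn(vocab_size, vocab_size), dim=-1)
#
#     def step(last_predictions: torch.Tensor, state: StateType) -> Tuple[torch.Tensor, StateType]:
#         # Next-token log probs depend only on the last prediction here.
#         return transitions[last_predictions], state
#
#     beam_search = BeamSearch(end_index=eos_id, max_steps=20, beam_size=4)
#     start = torch.ones(batch_size, dtype=torch.long)  # "start" token id 1
#     predictions, scores = beam_search.search(start, {}, step)
#     # predictions: (2, 4, <=20); scores: (2, 4), best-scoring beam at index 0.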