import random

import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import Callback
from torchvision import transforms

from constants import TOKENIZER_PATH
from data_preprocessing import RandomizeImageTransform


class LogImageTexCallback(Callback):
    """Logs one random validation sample as an image together with its
    ground-truth and beam-search-predicted TeX, once per validation run."""

    def __init__(self, logger, top_k, max_length):
        self.logger = logger
        self.top_k = top_k
        self.max_length = max_length
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.tensor_to_PIL = transforms.ToPILImage()

    def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
        # Log only for the first batch of the first validation dataloader.
        if batch_idx != 0 or dataloader_idx != 0:
            return
        sample_id = random.randint(0, len(batch['images']) - 1)
        image = batch['images'][sample_id]
        texs_predicted, texs_ids = beam_search_decode(transformer, image, transform_image=False, top_k=self.top_k,
                                                      max_length=self.max_length)
        image = self.tensor_to_PIL(image)
        tex_true = self.tex_tokenizer.decode(batch['tex_ids'][sample_id].to('cpu', torch.int).tolist())
        self.logger.log_image(key="samples", images=[image],
                              caption=[f"True: {tex_true}\nPredicted: " + "\n".join(texs_predicted)])

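# A minimal sketch of attaching the callback to a training run. `WandbLogger`
# matches the `log_image` call above; `model`, `datamodule`, and the project
# name are assumptions standing in for objects defined elsewhere.
#
#   import pytorch_lightning as pl
#   from pytorch_lightning.loggers import WandbLogger
#
#   logger = WandbLogger(project="img2tex")  # hypothetical project name
#   trainer = pl.Trainer(logger=logger,
#                        callbacks=[LogImageTexCallback(logger, top_k=10, max_length=100)])
#   trainer.fit(model, datamodule=datamodule)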

@torch.inference_mode()
def beam_search_decode(transformer, image, transform_image=True, top_k=10, max_length=100):
    """Decodes an image into TeX with beam search, keeping the `top_k` most
    probable candidate sequences at every step.

    Returns a tuple `(texs, tex_ids)` with the decoded TeX strings and their
    token-id sequences, ordered from most to least probable.
    """
    assert torch.is_tensor(image) and image.dim() == 3, "Image must be a 3-dimensional tensor (c h w)"

    def get_tgt_padding_mask(tgt):
        # The cumulative sum is zero before the first [SEP] and positive from
        # the first [SEP] onwards, so it marks the sequence end and everything
        # generated after it as padding.
        mask = tgt == tex_tokenizer.token_to_id("[SEP]")
        mask = torch.cumsum(mask, dim=1)
        mask = mask.to(transformer.device, torch.bool)
        return mask

    src = einops.rearrange(image, "c h w -> () c h w").to(transformer.device)
    if transform_image:
        image_transform = RandomizeImageTransform(width=transformer.hparams["image_width"],
                                                  height=transformer.hparams["image_width"],
                                                  random_magnitude=0)
        src = image_transform(src)
    # Encode the image once; the resulting memory is shared by all beam candidates.
    memory = transformer.encode(src)

    tex_tokenizer = torch.load(TOKENIZER_PATH)
    # Start the beam with a single candidate: the [CLS] start token with log-probability 0.
    candidates_tex_ids = [[tex_tokenizer.token_to_id("[CLS]")]]
    candidates_log_prob = torch.tensor([0], dtype=torch.float, device=transformer.device)

    # Expand candidates until the most probable one ends with [SEP] or reaches max_length.
    while candidates_tex_ids[0][-1] != tex_tokenizer.token_to_id("[SEP]") and len(candidates_tex_ids[0]) < max_length:
        candidates_tex_ids = torch.tensor(candidates_tex_ids, dtype=torch.long, device=transformer.device)
        tgt_mask = transformer.transformer.generate_square_subsequent_mask(candidates_tex_ids.shape[1]).to(
            transformer.device, torch.bool)
        # Broadcast the single encoded image to one copy per candidate.
        shared_memories = einops.repeat(memory, "one n d_model -> (b one) n d_model",
                                        b=candidates_tex_ids.shape[0])
        outs = transformer.decode(tgt=candidates_tex_ids,
                                  memory=shared_memories,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  tgt_padding_mask=get_tgt_padding_mask(candidates_tex_ids))
        # Keep only the logits at the last position of each candidate.
        outs = outs[:, -1, :]
        vocab_size = outs.shape[1]
        # Cumulative log-probability of every (candidate, next-token) pair,
        # flattened so topk searches over all candidates and tokens at once.
        outs = F.log_softmax(outs, dim=1)
        outs += einops.rearrange(candidates_log_prob, "b -> b ()")
        outs = einops.rearrange(outs, 'b v -> (b v)')
        candidates_log_prob, indices = torch.topk(outs, k=top_k)

        # Each flat top-k index encodes (candidate, token): recover both and
        # extend the corresponding candidate with the new token.
        new_candidates = []
        for index in indices:
            candidate_id, token_id = divmod(index.item(), vocab_size)
            new_candidates.append(candidates_tex_ids[candidate_id].tolist() + [token_id])
        candidates_tex_ids = new_candidates

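    # Replace every token generated after the first [SEP] with [PAD] so that
    # decode_batch ignores the overrun; [SEP] tokens themselves are left in
    # place and dropped by skip_special_tokens.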
    candidates_tex_ids = torch.tensor(candidates_tex_ids)
    padding_mask = get_tgt_padding_mask(candidates_tex_ids).cpu()
    candidates_tex_ids = candidates_tex_ids.masked_fill(
        padding_mask & (candidates_tex_ids != tex_tokenizer.token_to_id("[SEP]")),
        tex_tokenizer.token_to_id("[PAD]")).tolist()
    texs = tex_tokenizer.decode_batch(candidates_tex_ids, skip_special_tokens=True)
    return texs, candidates_tex_ids
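

# A hedged, standalone usage sketch for the decoder. The checkpoint path, the
# image file, and the `Img2TexTransformer` class are assumptions standing in
# for this project's actual model class and artifacts.
if __name__ == "__main__":
    from PIL import Image
    from torchvision.transforms.functional import to_tensor

    from model import Img2TexTransformer  # hypothetical module and class name

    model = Img2TexTransformer.load_from_checkpoint("checkpoints/last.ckpt")  # hypothetical path
    model.eval()
    image = to_tensor(Image.open("formula.png").convert("RGB"))  # (c h w) float tensor in [0, 1]
    texs, tex_ids = beam_search_decode(model, image, top_k=5, max_length=100)
    for tex in texs:
        print(tex)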