diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b48a19824f60e01f250bb906aab558f25a1d4954 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +data/annotation.json filter=lfs diff=lfs merge=lfs -text +pycocoevalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text +pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/PromptNet.py b/PromptNet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a714bbe26cd8b539c2acf014a18d19d3916f235 --- /dev/null +++ b/PromptNet.py @@ -0,0 +1,114 @@ +import torch +import argparse +from modules.dataloader import R2DataLoader +from modules.tokenizers import Tokenizer +from modules.loss import compute_loss +from modules.metrics import compute_scores +from modules.optimizers import build_optimizer, build_lr_scheduler +from models.models import MedCapModel +from modules.trainer import Trainer +import numpy as np + +def main(): + parser = argparse.ArgumentParser() + + # Data input Settings + parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json', + help='Path to the json file') + parser.add_argument('--image_dir', default='data/mimic_cxr/images/', + help='Directory of images') + + # Dataloader Settings + parser.add_argument('--dataset', default='mimic_cxr', help='dataset for training MedCap') + parser.add_argument('--bs', type=int, default=16) + parser.add_argument('--threshold', type=int, default=10, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.') + + #Trainer Settings + parser.add_argument('--epochs', type=int, default=30) + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='./record_dir/', + help='the patch to save the results of experiments.') + parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') + parser.add_argument('--save_period', type=int, default=1) + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + + # Training related + parser.add_argument('--noise_inject', default='no', choices=['yes', 'no']) + + # Sample 
related + parser.add_argument('--sample_method', type=str, default='greedy', help='the sample methods to sample a report.') + parser.add_argument('--prompt', default='/prompt/prompt.pt') + parser.add_argument('--prompt_load', default='no',choices=['yes','no']) + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=1e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=5e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') + parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.') + parser.add_argument('--noamopt_factor', type=int, default=1, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9153, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + parser.add_argument('--train_mode', default='base', choices=['base', 'fine-tuning'], + help='Training mode: base (autoencoding) or fine-tuning (full supervised training or fine-tuned on downstream datasets)') + parser.add_argument('--F_version', default='v1', choices=['v1', 'v2'],) + parser.add_argument('--clip_update', default='no' , choices=['yes','no']) + + # Fine-tuning + parser.add_argument('--random_init', default='yes', choices=['yes', 'no'], + help='Whether to load the pre-trained weights for fine-tuning.') + parser.add_argument('--weight_path', default='path_to_default_weights', type=str, + help='Path to the pre-trained model weights.') + args = parser.parse_args() + + # fix random seeds + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(args.seed) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # create tokenizer + tokenizer = Tokenizer(args) + + # create data loader + train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) + val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) + test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) + + # get function handles of loss and metrics + criterion = compute_loss + metrics = compute_scores + model = MedCapModel(args, tokenizer) + + if args.train_mode == 'fine-tuning' and args.random_init == 'no': + # Load weights from the specified path + checkpoint = torch.load(args.weight_path) + model.load_state_dict(checkpoint) + + # build optimizer, learning rate scheduler + optimizer = build_optimizer(args, model) + lr_scheduler = build_lr_scheduler(args, optimizer) + + # build trainer and start to train + trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader) + trainer.train() + 
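+# Example launch command (a sketch only; the flags below are the ones defined above and the paths are placeholders for your local MIMIC-CXR copy): +#   python PromptNet.py --json_path data/mimic_cxr/annotation.json --image_dir data/mimic_cxr/images/ --bs 16 --epochs 30 --train_mode base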
+if __name__ == '__main__': + main() diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..3af63acaf3d6ea43ee89ab55f01041234cf22045 --- /dev/null +++ b/app.py @@ -0,0 +1,121 @@ +import gradio as gr +import torch +from PIL import Image +from models.r2gen import R2GenModel +from modules.tokenizers import Tokenizer +import argparse + +# Assuming you have a predefined configuration function for model args +def get_model_args(): + parser = argparse.ArgumentParser() + + # Model loader settings + parser.add_argument('--load', type=str, default='ckpts/few-shot.pth', help='the path to the model weights.') + parser.add_argument('--prompt', type=str, default='prompt/prompt.pth', help='the path to the prompt weights.') + + # Data input settings + parser.add_argument('--image_path', type=str, default='example_figs/example_fig1.jpg', help='the path to the test image.') + parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the images.') + parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.') + + # Data loader settings + parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.') + parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch.') + + # Model settings (for visual extractor) + parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') + parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.') + + # Model settings (for Transformer) + parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') + parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') + parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') + parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') + parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') + parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') + parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') + parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.') + parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.') + parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.') + parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') + parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') + # for Relational Memory + parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.') + parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.') + parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.') + + # Sample related + 
parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') + parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') + parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') + parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') + parser.add_argument('--group_size', type=int, default=1, help='the group size.') + parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') + parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') + parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') + + # Trainer settings + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') + parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments') + parser.add_argument('--save_period', type=int, default=1, help='the saving period.') + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9233, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + + args = parser.parse_args() + return args + +def load_model(): + args = get_model_args() + tokenizer = Tokenizer(args) + device = 'cuda' if torch.cuda.is_available() else 'cpu' # Determine the device dynamically + model = R2GenModel(args, tokenizer).to(device) + checkpoint_path = args.load + # Ensure the state dict is loaded onto the same device as the model + state_dict = torch.load(checkpoint_path, map_location=device) + model_state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict + model.load_state_dict(model_state_dict) + model.eval() + return model, tokenizer + +model, tokenizer = load_model() + +def generate_report(image): + image = Image.fromarray(image).convert('RGB') + with torch.no_grad(): + output = model([image], mode='sample') + reports = 
tokenizer.decode_batch(output.cpu().numpy()) + return reports[0] + +# Define Gradio interface +iface = gr.Interface( + fn=generate_report, + inputs=gr.inputs.Image(), # Define input shape as needed + outputs="text", + title="PromptNet", + description="Upload a medical image for thorax disease reporting." +) + +if __name__ == "__main__": + iface.launch() diff --git a/ckpts/few-shot.pth b/ckpts/few-shot.pth new file mode 100644 index 0000000000000000000000000000000000000000..5653892a560a8c56d4ddf456e444a7449332abfd --- /dev/null +++ b/ckpts/few-shot.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa4c3ef1a822fdca8895f6ad0c73b4f355b036d0d28a8523aaf51f58c7393f38 +size 1660341639 diff --git a/data/annotation.json b/data/annotation.json new file mode 100644 index 0000000000000000000000000000000000000000..845c41e7029c0dbcfc9c766fed0d085bb0337bdf --- /dev/null +++ b/data/annotation.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d9590de8db89b0c74343a7e2aecba61e8029e15801de10ec4e030be80b62adc +size 155745921 diff --git a/decoder_config/decoder_config.pkl b/decoder_config/decoder_config.pkl new file mode 100644 index 0000000000000000000000000000000000000000..31293e585081fe4de2004432722e04100ab9fa9c --- /dev/null +++ b/decoder_config/decoder_config.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb +size 1744 diff --git a/example_figs/example_fig1.jpg.png b/example_figs/example_fig1.jpg.png new file mode 100644 index 0000000000000000000000000000000000000000..3ac518805ff76d19c425c65a9b20768977f1d562 Binary files /dev/null and b/example_figs/example_fig1.jpg.png differ diff --git a/example_figs/example_fig2.jpg.jpg b/example_figs/example_fig2.jpg.jpg new file mode 100644 index 0000000000000000000000000000000000000000..862a1b95e29366b96e24b5c65db2d9586b794eb8 Binary files /dev/null and b/example_figs/example_fig2.jpg.jpg differ diff --git a/example_figs/example_fig3.jpg.png b/example_figs/example_fig3.jpg.png new file mode 100644 index 0000000000000000000000000000000000000000..492575cc553ba11c0de1175875788748e8b76f47 Binary files /dev/null and b/example_figs/example_fig3.jpg.png differ diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..215ac4eefa08520de298af6c3a48472c65a18d41 --- /dev/null +++ b/inference.py @@ -0,0 +1,110 @@ +import torch +from models.r2gen import R2GenModel +from PIL import Image +from modules.tokenizers import Tokenizer +import main +import argparse +import json +import re +from collections import Counter + +def parse_agrs(): + parser = argparse.ArgumentParser() + + # Model loader settings + parser.add_argument('--load', type=str, default='ckpt/checkpoint.pth', help='the path to the model weights.') + parser.add_argument('--prompt', type=str, default='ckpt/prompt.pth', help='the path to the prompt weights.') + + # Data input settings + parser.add_argument('--image_path', type=str, default='example_figs/fig1.jpg', help='the path to the test image.') + parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the data.') + parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the directory containing the data.') + + # Data loader settings + parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.') + parser.add_argument('--max_seq_length', type=int, 
default=60, help='the maximum sequence length of the reports.') + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch.') + + # Model settings (for visual extractor) + parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') + parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.') + + # Model settings (for Transformer) + parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') + parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') + parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') + parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') + parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') + parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') + parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') + parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.') + parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.') + parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.') + parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') + parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') + # for Relational Memory + parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.') + parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.') + parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.') + + # Sample related + parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') + parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') + parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') + parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') + parser.add_argument('--group_size', type=int, default=1, help='the group size.') + parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') + parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.') + parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') + + # Trainer settings + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') + parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.') + parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments.') + parser.add_argument('--save_period', type=int, default=1, help='the saving period.') + 
parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9233, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + + args = parser.parse_args() + return args + + +args = parse_agrs() +tokenizer = Tokenizer(args) +image_path = args.image_path +checkpoint_path = args.load +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +image = [Image.open(image_path).convert('RGB')] +model = R2GenModel(args, tokenizer).to(device) + +# load_state_dict returns a key-matching report rather than the model, so it must not be chained with .to() +state_dict = torch.load(checkpoint_path, map_location=device) +model_state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict +model.load_state_dict(model_state_dict) + +model.eval() +with torch.no_grad(): + output = model(image, mode='sample') + reports = model.tokenizer.decode_batch(output.cpu().numpy()) + print(reports) diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..138f3d3a2061d0fad2a404a4fce050a64a0ddd59 --- /dev/null +++ b/models/models.py @@ -0,0 +1,125 @@ +import numpy as np +import torch +import torch.nn as nn +import pickle +from typing import Tuple +from transformers import GPT2LMHeadModel +from modules.decoder import DeCap +from medclip import MedCLIPModel, MedCLIPVisionModelViT +import math +import pdb + + +class MedCapModel(nn.Module): + def __init__(self, args, tokenizer): + super(MedCapModel, self).__init__() + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.args = args + self.tokenizer = tokenizer + self.model = DeCap(args, tokenizer) + + self.align_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT) + self.align_model.from_pretrained() + self.prompt = torch.load(args.prompt) + if args.dataset == 'iu_xray': + self.forward = self.forward_iu_xray + else: + self.forward = self.forward_mimic_cxr + + def noise_injection(self, x, variance=0.001, modality_offset=None, dont_norm=False): + if variance == 0.0: + return x + std = math.sqrt(variance) + if not dont_norm: + x = torch.nn.functional.normalize(x, dim=1) + # inject Gaussian noise into the (optionally normalized) embedding; by some conventions multivariate noise should be divided by sqrt of dim + x = x + (torch.randn(x.shape, device=x.device) * std) + if modality_offset is not None: + x = x + modality_offset + return torch.nn.functional.normalize(x, dim=1) + + def align_encode_images_iu_xray(self, images): + # Split the 
images + image1, image2 = images.unbind(dim=1) + # Encode each image + feature1 = self.align_model.encode_image(image1) + feature2 = self.align_model.encode_image(image2) + if self.args.prompt_load == 'yes': + sim_1 = feature1 @ self.prompt.T.float() + sim_1 = (sim_1 * 100).softmax(dim=-1) + prefix_embedding_1 = sim_1 @ self.prompt.float() + prefix_embedding_1 /= prefix_embedding_1.norm(dim=-1, keepdim=True) + + sim_2 = feature2 @ self.prompt.T.float() + sim_2 = (sim_2 * 100).softmax(dim=-1) + prefix_embedding_2 = sim_2 @ self.prompt.float() + prefix_embedding_2 /= prefix_embedding_2.norm(dim=-1, keepdim=True) + averaged_prompt_features = torch.mean(torch.stack([prefix_embedding_1, prefix_embedding_2]), dim=0) + return averaged_prompt_features + else: + # Concatenate the features + averaged_features = torch.mean(torch.stack([feature1, feature2]), dim=0) + return averaged_features + + def align_encode_images_mimic_cxr(self, images): + feature = self.align_model.encode_image(images) + if self.args.prompt_load == 'yes': + sim = feature @ self.prompt.T.float() + sim = (sim * 100).softmax(dim=-1) + prefix_embedding = sim @ self.prompt.float() + prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True) + return prefix_embedding + else: + return feature + + def forward_iu_xray(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}): + self.align_model.to(self.device) + self.align_model.eval() + align_ids = align_ids.long() + + align_image_feature = None + if self.args.train_mode == 'fine-tuning': + align_image_feature = self.align_encode_images_iu_xray(images) + if mode == 'train': + align_text_feature = self.align_model.encode_text(align_ids, align_masks) + if self.args.noise_inject == 'yes': + align_text_feature = self.noise_injection(align_text_feature) + + if self.args.train_mode == 'fine-tuning': + if self.args.F_version == 'v1': + combined_feature = torch.cat([align_text_feature, align_image_feature], dim=-1) + align_text_feature = self.fc_reduce_dim(combined_feature) + if self.args.F_version == 'v2': + align_text_feature = align_image_feature + + outputs = self.model(align_text_feature, reports_ids, mode='forward') + logits = outputs.logits + logits = logits[:, :-1] + return logits + elif mode == 'sample': + align_image_feature = self.align_encode_images_iu_xray(images) + outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts) + return outputs + else: + raise ValueError + + def forward_mimic_cxr(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}): + self.align_model.to(self.device) + self.align_model.eval() + align_ids = align_ids.long() + if mode == 'train': + if self.args.noise_inject == 'yes': + align_text_feature = self.align_model.encode_text(align_ids, align_masks) + align_text_feature = self.noise_injection(align_text_feature) + else: + align_text_feature = self.align_model.encode_text(align_ids, align_masks) + outputs = self.model(align_text_feature, reports_ids, mode='forward') + logits = outputs.logits + logits = logits[:, :-1] + return logits + elif mode == 'sample': + align_image_feature = self.align_encode_images_mimic_cxr(images) + outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts) + return outputs + else: + raise ValueError diff --git a/models/r2gen.py b/models/r2gen.py new file mode 100644 index 0000000000000000000000000000000000000000..a192ad1e132853e467148dfd669870282eff501d --- /dev/null +++ b/models/r2gen.py @@ -0,0 +1,63 @@ 
+import torch +import torch.nn as nn +import numpy as np + +from modules.visual_extractor import VisualExtractor +from modules.encoder_decoder import EncoderDecoder +import torch.nn.functional as F + +class R2GenModel(nn.Module): + def __init__(self, args, tokenizer): + super(R2GenModel, self).__init__() + self.args = args + self.tokenizer = tokenizer + self.visual_extractor = VisualExtractor(args) + self.encoder_decoder = EncoderDecoder(args, tokenizer) + if args.dataset_name == 'iu_xray': + self.forward = self.forward_iu_xray + else: + self.forward = self.forward_mimic_cxr + self.affine_a = nn.Linear(1024, 2048) + self.affine_b = nn.Linear(1024, 2048) + self.affine_c = nn.Linear(1024, 2048) + self.affine_d = nn.Linear(1024, 2048) + self.affine_aa = nn.Linear(1024, 2048) + self.affine_bb = nn.Linear(1024, 2048) + + def __str__(self): + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return super().__str__() + '\nTrainable parameters: {}'.format(params) + + def forward_iu_xray(self, images, targets=None, mode='train'): + att_feats_0, fc_feats_0 = self.visual_extractor(images[:, 0]) + att_feats_1, fc_feats_1 = self.visual_extractor(images[:, 1]) + #new add + att_feats_0=F.relu(self.affine_a(att_feats_0)) + fc_feats_0=F.relu(self.affine_b(fc_feats_0)) + att_feats_1=F.relu(self.affine_c(att_feats_1)) + fc_feats_1=F.relu(self.affine_d(fc_feats_1)) + + fc_feats = torch.cat((fc_feats_0, fc_feats_1), dim=1) + att_feats = torch.cat((att_feats_0, att_feats_1), dim=1) + if mode == 'train': + output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') + elif mode == 'sample': + output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample') + else: + raise ValueError + return output + + def forward_mimic_cxr(self, images, targets=None, mode='train'): + att_feats1, fc_feats1 = self.visual_extractor(images) + att_feats=F.relu(self.affine_aa(att_feats1)) + fc_feats=F.relu(self.affine_bb(fc_feats1)) + + if mode == 'train': + output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') + elif mode == 'sample': + output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample') + else: + raise ValueError + return output + diff --git a/modules/att_model.py b/modules/att_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1e03575ea9dedf5778f7b13568fc0f260b7522e9 --- /dev/null +++ b/modules/att_model.py @@ -0,0 +1,319 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence + +import modules.utils as utils +from modules.caption_model import CaptionModel + + +def sort_pack_padded_sequence(input, lengths): + sorted_lengths, indices = torch.sort(lengths, descending=True) + tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) + inv_ix = indices.clone() + inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix) + return tmp, inv_ix + + +def pad_unsort_packed_sequence(input, inv_ix): + tmp, _ = pad_packed_sequence(input, batch_first=True) + tmp = tmp[inv_ix] + return tmp + + +def pack_wrapper(module, att_feats, att_masks): + if att_masks is not None: + packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) + return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), 
inv_ix) + else: + return module(att_feats) + + +class AttModel(CaptionModel): + def __init__(self, args, tokenizer): + super(AttModel, self).__init__() + self.args = args + self.tokenizer = tokenizer + self.vocab_size = len(tokenizer.idx2token) + self.input_encoding_size = args.d_model + self.rnn_size = args.d_ff + self.num_layers = args.num_layers + self.drop_prob_lm = args.drop_prob_lm + self.max_seq_length = args.max_seq_length + self.att_feat_size = args.d_vf + self.att_hid_size = args.d_model + + self.bos_idx = args.bos_idx + self.eos_idx = args.eos_idx + self.pad_idx = args.pad_idx + + self.use_bn = args.use_bn + + self.embed = lambda x: x + self.fc_embed = lambda x: x + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) + + (nn.Linear(self.att_feat_size, self.input_encoding_size), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm)) + + ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ()))) + + def clip_att(self, att_feats, att_masks): + # Clip the length of att_masks and att_feats to the maximum length + if att_masks is not None: + max_len = att_masks.data.long().sum(1).max() + att_feats = att_feats[:, :max_len].contiguous() + att_masks = att_masks[:, :max_len].contiguous() + return att_feats, att_masks + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + # embed fc and att feats + fc_feats = self.fc_embed(fc_feats) + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + # Project the attention feats first to reduce memory and computation comsumptions. + p_att_feats = self.ctx2att(att_feats) + + return fc_feats, att_feats, p_att_feats, att_masks + + def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): + # 'it' contains a word index + xt = self.embed(it) + + output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) + if output_logsoftmax: + logprobs = F.log_softmax(self.logit(output), dim=1) + else: + logprobs = self.logit(output) + + return logprobs, state + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + sample_n = opt.get('sample_n', 10) + # when sample_n == beam_size then each beam is a sample. + assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' + batch_size = fc_feats.size(0) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' + seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + + state = self.init_hidden(batch_size) + + # first step, feed bos + it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) + logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size, + [p_fc_feats, p_att_feats, + pp_att_feats, p_att_masks] + ) + self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) + for k in range(batch_size): + if sample_n == beam_size: + for _n in range(sample_n): + seq_len = self.done_beams[k][_n]['seq'].shape[0] + seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq'] + seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps'] + else: + seq_len = self.done_beams[k][0]['seq'].shape[0] + seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq, seqLogprobs + + def _sample(self, fc_feats, att_feats, att_masks=None): + opt = self.args.__dict__ + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + sample_n = int(opt.get('sample_n', 1)) + group_size = opt.get('group_size', 1) + output_logsoftmax = opt.get('output_logsoftmax', 1) + decoding_constraint = opt.get('decoding_constraint', 0) + block_trigrams = opt.get('block_trigrams', 0) + if beam_size > 1 and sample_method in ['greedy', 'beam_search']: + return self._sample_beam(fc_feats, att_feats, att_masks, opt) + if group_size > 1: + return self._diverse_sample(fc_feats, att_feats, att_masks, opt) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size * sample_n) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + if sample_n > 1: + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n, + [p_fc_feats, p_att_feats, + pp_att_feats, p_att_masks] + ) + + trigrams = [] # will be a list of batch_size dictionaries + + seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) + for t in range(self.max_seq_length + 1): + if t == 0: # input + it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long) + + logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, + output_logsoftmax=output_logsoftmax) + + if decoding_constraint and t > 0: + tmp = logprobs.new_zeros(logprobs.size()) + tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + # Mess with trigrams + # Copy from https://github.com/lukemelas/image-paragraph-captioning + if block_trigrams and t >= 3: + # Store trigram generated at last step + prev_two_batch = seq[:, t - 3:t - 1] + for i in range(batch_size): # = seq.size(0) + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) 
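+ # prev_two is the pair of tokens at positions t-3 and t-2; together with the token generated at step t-1 (current, below) it forms the trigram that is recorded here and blocked from repeating at the next step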
+ current = seq[i][t - 1] + if t == 3: # initialize + trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} + elif t > 3: + if prev_two in trigrams[i]: # add to list + trigrams[i][prev_two].append(current) + else: # create list + trigrams[i][prev_two] = [current] + # Block used trigrams at next step + prev_two_batch = seq[:, t - 2:t] + mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size + for i in range(batch_size): + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + if prev_two in trigrams[i]: + for j in trigrams[i][prev_two]: + mask[i, j] += 1 + # Apply mask to log probs + # logprobs = logprobs - (mask * 1e9) + alpha = 2.0 # = 4 + logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) + + # sample the next word + if t == self.max_seq_length: # skip if we achieve maximum length + break + it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) + + # stop when all finished + if t == 0: + unfinished = it != self.eos_idx + else: + it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 + logprobs = logprobs * unfinished.unsqueeze(1).float() + unfinished = unfinished * (it != self.eos_idx) + seq[:, t] = it + seqLogprobs[:, t] = logprobs + # quit loop if all sequences have finished + if unfinished.sum() == 0: + break + + return seq, seqLogprobs + + def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): + + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + block_trigrams = opt.get('block_trigrams', 0) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries + + seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in + range(group_size)] + seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)] + state_table = [self.init_hidden(batch_size) for _ in range(group_size)] + + for tt in range(self.max_seq_length + group_size): + for divm in range(group_size): + t = tt - divm + seq = seq_table[divm] + seqLogprobs = seqLogprobs_table[divm] + trigrams = trigrams_table[divm] + if t >= 0 and t <= self.max_seq_length - 1: + if t == 0: # input + it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) + else: + it = seq[:, t - 1] # changed + + logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, + p_att_masks, state_table[divm]) # changed + logprobs = F.log_softmax(logprobs / temperature, dim=-1) + + # Add diversity + if divm > 0: + unaug_logprobs = logprobs.clone() + for prev_choice in range(divm): + prev_decisions = seq_table[prev_choice][:, t] + logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda + + if decoding_constraint and t > 0: + tmp = logprobs.new_zeros(logprobs.size()) + tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + # Mess with trigrams + if block_trigrams and t >= 3: + # Store trigram generated at last step + prev_two_batch = seq[:, t - 3:t 
- 1] + for i in range(batch_size): # = seq.size(0) + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + current = seq[i][t - 1] + if t == 3: # initialize + trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} + elif t > 3: + if prev_two in trigrams[i]: # add to list + trigrams[i][prev_two].append(current) + else: # create list + trigrams[i][prev_two] = [current] + # Block used trigrams at next step + prev_two_batch = seq[:, t - 2:t] + mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size + for i in range(batch_size): + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + if prev_two in trigrams[i]: + for j in trigrams[i][prev_two]: + mask[i, j] += 1 + # Apply mask to log probs + # logprobs = logprobs - (mask * 1e9) + alpha = 2.0 # = 4 + logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) + + it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) + + # stop when all finished + if t == 0: + unfinished = it != self.eos_idx + else: + unfinished = (seq[:, t - 1] != self.pad_idx) & (seq[:, t - 1] != self.eos_idx) + it[~unfinished] = self.pad_idx + unfinished = unfinished & (it != self.eos_idx) # changed + seq[:, t] = it + seqLogprobs[:, t] = sampleLogprobs.view(-1) + + return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1) diff --git a/modules/att_models.py b/modules/att_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e96981685933d82f052581202e407a71ea086a76 --- /dev/null +++ b/modules/att_models.py @@ -0,0 +1,120 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pdb + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence + +import modules.utils as utils +from modules.caption_model import CaptionModel + + +class AttModel(CaptionModel): + def __init__(self, args, tokenizer): + super(AttModel, self).__init__() + self.args = args + self.tokenizer = tokenizer + self.vocab_size = len(tokenizer.idx2token) + self.max_seq_length = 60 + + def _sample(self, clip_features, gpt_tokens, update_opts={}): + opt = self.args.__dict__ + opt.update(**update_opts) + sample_method = opt.get('sample_method', 'greedy') + + if sample_method == 'greedy': + return self._greedy_sample(clip_features, gpt_tokens) + elif sample_method == 'beam_search': + return self._beam_search_sample(clip_features, gpt_tokens) + else: + raise ValueError("Unknown sample_method: " + sample_method) + + def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0): + #input_ids = torch.full((clip_features.size(0), 1), self.tokenizer.bos_token_id).type_as(clip_features).long() + clip_features = self.clip_project(clip_features).reshape(clip_features.size(0), 1, -1) + tokens = [None for _ in range(clip_features.size(0))] + finished = [False for _ in range(clip_features.size(0))] + max_length = 200 + for _ in range(max_length): + outputs = self.decoder(inputs_embeds=clip_features) + logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + next_tokens = torch.argmax(logits, -1).unsqueeze(1) + next_token_embeds = self.decoder.transformer.wte(next_tokens) + for j in range(clip_features.size(0)): + if finished[j]: + continue + if tokens[j] is None: + tokens[j] = 
next_tokens[j] + else: + tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0) + if next_tokens[j].item() == self.tokenizer.eos_token_id: + finished[j] = True + clip_features = torch.cat((clip_features, next_token_embeds), dim=1) + outputs = [] + for token in tokens: + try: + output_list = token.squeeze().cpu().numpy().tolist() + # Pad or truncate output_list to max_length + output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length] + except Exception as e: + print(f"Error during decoding: {type(e).__name__}: {e}") + output_list = [self.tokenizer.pad_token_id] * max_length + outputs.append(output_list) + + # Convert list of lists to tensor + outputs = torch.tensor(outputs, device=clip_features.device) + return outputs + + + def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5): + batch_size = clip_features.size(0) + # Prepare the first input for every beam + input_ids = torch.full((batch_size*beam_size, 1), self.tokenizer.bos_token_id).type_as(clip_features).long() + beam_scores = torch.zeros((batch_size, beam_size)).type_as(clip_features) + done = [False]*batch_size + + for _ in range(self.max_seq_length): + outputs = self._forward(clip_features.repeat_interleave(beam_size, 0), input_ids) + next_token_logits = outputs.logits[:, -1, :] + next_token_probs = F.softmax(next_token_logits, dim=-1) + + # Apply a mask for already finished beams + next_token_probs[done] = 0 + next_token_probs[:, self.tokenizer.eos_token_id] = -float('Inf') + + # Multiply old scores with new probabilities + scores = beam_scores.unsqueeze(2) * next_token_probs + scores = scores.view(batch_size, -1) + + # Get the top beam_size scores and their respective indices + top_scores, top_indices = scores.topk(beam_size, dim=1) + + # Update beam scores + beam_scores = top_scores.log() + + # Reshape input_ids + input_ids = input_ids.view(batch_size, beam_size, -1) + + # Compute next inputs + next_token_ids = top_indices % self.vocab_size + beam_indices = top_indices // self.vocab_size + next_input_ids = torch.cat([input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2))), next_token_ids.unsqueeze(2)], dim=2) + + # Flatten input_ids + input_ids = next_input_ids.view(batch_size*beam_size, -1) + + # Check which beams are done + done = (next_token_ids == self.tokenizer.eos_token_id).all(dim=1).tolist() + + if all(done): + break + + return input_ids.view(batch_size, beam_size, -1) + + diff --git a/modules/caption_model.py b/modules/caption_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b02bcb0666e8882ae277e9c33ca291ca843bc2a3 --- /dev/null +++ b/modules/caption_model.py @@ -0,0 +1,401 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modules.utils as utils + + +class CaptionModel(nn.Module): + def __init__(self): + super(CaptionModel, self).__init__() + + # implements beam search + # calls beam_step and returns the final set of beams + # augments log-probabilities with diversity terms when number of groups > 1 + + def forward(self, *args, **kwargs): + mode = kwargs.get('mode', 'forward') + if 'mode' in kwargs: + del kwargs['mode'] + return getattr(self, '_' + mode)(*args, **kwargs) + + def beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, 
bdash): + local_time = t - divm + unaug_logprobs = logprobs.clone() + batch_size = beam_seq_table[0].shape[0] + + if divm > 0: + change = logprobs.new_zeros(batch_size, logprobs.shape[-1]) + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb + for prev_labels in range(bdash): + change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), + change.new_ones(batch_size, 1)) + + if local_time == 0: + logprobs = logprobs - change * diversity_lambda + else: + logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda + + return logprobs, unaug_logprobs + + # does one step of classical beam search + + def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): + # INPUTS: + # logprobs: probabilities augmented after diversity N*bxV + # beam_size: obvious + # t : time instant + # beam_seq : tensor contanining the beams + # beam_seq_logprobs: tensor contanining the beam logprobs + # beam_logprobs_sum: tensor contanining joint logprobs + # OUPUTS: + # beam_seq : tensor containing the word indices of the decoded captions Nxbxl + # beam_seq_logprobs : log-probability of each decision made, NxbxlxV + # beam_logprobs_sum : joint log-probability of each beam Nxb + + batch_size = beam_logprobs_sum.shape[0] + vocab_size = logprobs.shape[-1] + logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV + if t == 0: + assert logprobs.shape[1] == 1 + beam_logprobs_sum = beam_logprobs_sum[:, :1] + candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV + ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) + ys, ix = ys[:, :beam_size], ix[:, :beam_size] + beam_ix = ix // vocab_size # Nxb which beam + selected_ix = ix % vocab_size # Nxb # which world + state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape( + -1) # N*b which in Nxb beams + + if t > 0: + # gather according to beam_ix + assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == + beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() + beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) + + beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as( + beam_seq_logprobs)) + + beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl + beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ + logprobs.reshape(batch_size, -1).gather(1, ix) + assert (beam_logprobs_sum == ys).all() + _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) + beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, + beam_ix.unsqueeze(-1).expand(-1, + -1, + vocab_size)) # NxbxV + assert (_tmp_beam_logprobs == beam_logprobs).all() + beam_seq_logprobs = torch.cat([ + beam_seq_logprobs, + beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) + + new_state = [None for _ in state] + for _ix in range(len(new_state)): + # copy over state in previous beam q to new beam at vix + new_state[_ix] = state[_ix][:, state_ix] + state = new_state + return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state + + # Start diverse_beam_search + opt = kwargs['opt'] + temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = 
opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + suppress_UNK = opt.get('suppress_UNK', 0) + length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) + bdash = beam_size // group_size # beam per group + + batch_size = init_logprobs.shape[0] + device = init_logprobs.device + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in + range(group_size)] + beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)] + + # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) + done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)] + state_table = [[_.clone() for _ in init_state] for _ in range(group_size)] + logprobs_table = [init_logprobs.clone() for _ in range(group_size)] + # END INIT + + # Chunk elements in the args + args = list(args) + args = utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x... + if self.__class__.__name__ == 'AttEnsemble': + args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in + range(group_size)] # group_name, arg_name, model_name + else: + args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] + + for t in range(self.max_seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.max_seq_length + divm - 1: + # add diversity + logprobs = logprobs_table[divm] + # suppress previous word + if decoding_constraint and t - divm > 0: + logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device), + float('-inf')) + # suppress UNK tokens in the decoding + if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK': + logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000 + # diversity is added here + # the function directly modifies the logprobs values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. # for historical + # reasons :-) + logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash) + + # infer new beams + beam_seq_table[divm], \ + beam_seq_logprobs_table[divm], \ + beam_logprobs_sum_table[divm], \ + state_table[divm] = beam_step(logprobs, + unaug_logprobs, + bdash, + t - divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... 
or if end token is reached then copy beams + for b in range(batch_size): + is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx + assert beam_seq_table[divm].shape[-1] == t - divm + 1 + if t == self.max_seq_length + divm - 1: + is_end.fill_(1) + for vix in range(bdash): + if is_end[vix]: + final_beam = { + 'seq': beam_seq_table[divm][b, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][b, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(), + 'p': beam_logprobs_sum_table[divm][b, vix].item() + } + final_beam['p'] = length_penalty(t - divm + 1, final_beam['p']) + done_beams_table[b][divm].append(final_beam) + beam_logprobs_sum_table[divm][b, is_end] -= 1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][:, :, t - divm].reshape(-1) + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *( + args[divm] + [state_table[divm]])) + logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) + + # all beams are sorted by their log-probabilities + done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + for b in range(batch_size)] + done_beams = [sum(_, []) for _ in done_beams_table] + return done_beams + + def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprobsf = logprobsf.clone() + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][local_time] + for sub_beam in range(bdash): + for prev_labels in range(bdash): + logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[ + prev_labels]] - diversity_lambda + return unaug_logprobsf + + # does one step of classical beam search + + def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): + # INPUTS: + # logprobsf: probabilities augmented after diversity + # beam_size: obvious + # t : time instant + # beam_seq : tensor contanining the beams + # beam_seq_logprobs: tensor contanining the beam logprobs + # beam_logprobs_sum: tensor contanining joint logprobs + # OUPUTS: + # beam_seq : tensor containing the word indices of the decoded captions + # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq + # beam_logprobs_sum : joint log-probability of each beam + + ys, ix = torch.sort(logprobsf, 1, True) + candidates = [] + cols = min(beam_size, ys.size(1)) + rows = beam_size + if t == 0: + rows = 1 + for c in range(cols): # for each column (word, essentially) + for q in range(rows): # for each beam expansion + # compute logprob of expanding beam q with word in (sorted) position c + local_logprob = ys[q, c].item() + candidate_logprob = beam_logprobs_sum[q] + local_logprob + # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] + candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]}) + candidates = sorted(candidates, key=lambda x: -x['p']) + + new_state = [_.clone() for _ in state] + # beam_seq_prev, beam_seq_logprobs_prev + if t >= 1: + # we''ll need these as reference when we fork beams around + beam_seq_prev = beam_seq[:t].clone() + beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() + for vix in range(beam_size): + v = candidates[vix] + # fork beam index q into index vix + if t >= 1: + beam_seq[:t, vix] = 
beam_seq_prev[:, v['q']] + beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] + # rearrange recurrent states + for state_ix in range(len(new_state)): + # copy over state in previous beam q to new beam at vix + new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step + # append new end terminal at the end of this beam + beam_seq[t, vix] = v['c'] # c'th word is the continuation + beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here + beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam + state = new_state + return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates + + # Start diverse_beam_search + opt = kwargs['opt'] + temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + suppress_UNK = opt.get('suppress_UNK', 0) + length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) + bdash = beam_size // group_size # beam per group + + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_() for _ in + range(group_size)] + beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] + + # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) + done_beams_table = [[] for _ in range(group_size)] + # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] + state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) + logprobs_table = list(init_logprobs.chunk(group_size, 0)) + # END INIT + + # Chunk elements in the args + args = list(args) + if self.__class__.__name__ == 'AttEnsemble': + args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_] for args_ in + args] # arg_name, model_name, group_name + args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in + range(group_size)] # group_name, arg_name, model_name + else: + args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args] + args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] + + for t in range(self.max_seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.max_seq_length + divm - 1: + # add diversity + logprobsf = logprobs_table[divm].float() + # suppress previous word + if decoding_constraint and t - divm > 0: + logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), float('-inf')) + # suppress UNK tokens in the decoding + if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK': + logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000 + # diversity is added here + # the function directly modifies the logprobsf values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. 
# for historical + # reasons :-) + unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash) + + # infer new beams + beam_seq_table[divm], \ + beam_seq_logprobs_table[divm], \ + beam_logprobs_sum_table[divm], \ + state_table[divm], \ + candidates_divm = beam_step(logprobsf, + unaug_logprobsf, + bdash, + t - divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... or if end token is reached then copy beams + for vix in range(bdash): + if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1: + final_beam = { + 'seq': beam_seq_table[divm][:, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), + 'p': beam_logprobs_sum_table[divm][vix].item() + } + final_beam['p'] = length_penalty(t - divm + 1, final_beam['p']) + done_beams_table[divm].append(final_beam) + # don't continue beams from finished sequences + beam_logprobs_sum_table[divm][vix] = -1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][t - divm] + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *( + args[divm] + [state_table[divm]])) + logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) + + # all beams are sorted by their log-probabilities + done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + done_beams = sum(done_beams_table, []) + return done_beams + + def sample_next_word(self, logprobs, sample_method, temperature): + if sample_method == 'greedy': + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + elif sample_method == 'gumbel': # gumbel softmax + def sample_gumbel(shape, eps=1e-20): + U = torch.rand(shape).cuda() + return -torch.log(-torch.log(U + eps) + eps) + + def gumbel_softmax_sample(logits, temperature): + y = logits + sample_gumbel(logits.size()) + return F.log_softmax(y / temperature, dim=-1) + + _logprobs = gumbel_softmax_sample(logprobs, temperature) + _, it = torch.max(_logprobs.data, 1) + sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions + else: + logprobs = logprobs / temperature + if sample_method.startswith('top'): # topk sampling + top_num = float(sample_method[3:]) + if 0 < top_num < 1: + # nucleus sampling from # The Curious Case of Neural Text Degeneration + probs = F.softmax(logprobs, dim=1) + sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) + _cumsum = sorted_probs.cumsum(1) + mask = _cumsum < top_num + mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1) + sorted_probs = sorted_probs * mask.float() + sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) + logprobs.scatter_(1, sorted_indices, sorted_probs.log()) + else: + the_k = int(top_num) + tmp = torch.empty_like(logprobs).fill_(float('-inf')) + topk, indices = torch.topk(logprobs, the_k, dim=1) + tmp = tmp.scatter(1, indices, topk) + logprobs = tmp + it = torch.distributions.Categorical(logits=logprobs.detach()).sample() + sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions + return it, sampleLogprobs \ No newline at end of file diff --git a/modules/config.pkl b/modules/config.pkl new file mode 100644 index 0000000000000000000000000000000000000000..31293e585081fe4de2004432722e04100ab9fa9c --- /dev/null +++ 
b/modules/config.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb +size 1744 diff --git a/modules/dataloader.py b/modules/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5808efd33b8e31dbc4924ea5ba43cdd35472aea2 --- /dev/null +++ b/modules/dataloader.py @@ -0,0 +1,59 @@ +import pdb + +import torch +from torch.utils.data import DataLoader +from .dataset import IuxrayMultiImageDataset, MimiccxrSingleImageDataset +from medclip import MedCLIPProcessor +import numpy as np + +class R2DataLoader(DataLoader): + def __init__(self, args, tokenizer, split, shuffle): + self.args = args + self.dataset_name = args.dataset + self.batch_size = args.bs + self.shuffle = shuffle + self.num_workers = args.num_workers + self.tokenizer = tokenizer + self.split = split + self.processor = MedCLIPProcessor() + + if self.dataset_name == 'iu_xray': + self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, self.processor) + else: + self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, self.processor) + + self.init_kwargs = { + 'dataset': self.dataset, + 'batch_size': self.batch_size, + 'shuffle': self.shuffle, + 'collate_fn': self.collate_fn, + 'num_workers': self.num_workers + } + super().__init__(**self.init_kwargs) + + @staticmethod + def collate_fn(data): + image_id_batch, image_batch, report_ids_batch, report_masks_batch, processor_ids_batch, processor_mask_batch, seq_lengths_batch, processor_lenghts_batch = zip(*data) + image_batch = torch.stack(image_batch, 0) + + max_seq_length = max(seq_lengths_batch) + target_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int) + target_masks_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int) + + max_processor_length = max(processor_lenghts_batch) + target_processor_batch = np.zeros((len(processor_ids_batch), max_processor_length), dtype=int) + target_processor_mask_batch = np.zeros((len(processor_mask_batch), max_processor_length), dtype=int) + + for i, report_ids in enumerate(report_ids_batch): + target_batch[i, :len(report_ids)] = report_ids + + for i, report_masks in enumerate(report_masks_batch): + target_masks_batch[i, :len(report_masks)] = report_masks + + for i, report_ids in enumerate(processor_ids_batch): + target_processor_batch[i, :len(report_ids)] = report_ids + + for i, report_masks in enumerate(processor_mask_batch): + target_processor_mask_batch[i, :len(report_masks)] = report_masks + + return image_id_batch, image_batch, torch.LongTensor(target_batch), torch.FloatTensor(target_masks_batch), torch.FloatTensor(target_processor_batch), torch.FloatTensor(target_processor_mask_batch) diff --git a/modules/dataloaders.py b/modules/dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..92982b553528a51d416674c4ae93738c986a8a4b --- /dev/null +++ b/modules/dataloaders.py @@ -0,0 +1,62 @@ +import torch +import numpy as np +from torchvision import transforms +from torch.utils.data import DataLoader +from .datasets import IuxrayMultiImageDataset, MimiccxrSingleImageDataset + + +class R2DataLoader(DataLoader): + def __init__(self, args, tokenizer, split, shuffle): + self.args = args + self.dataset_name = args.dataset_name + self.batch_size = args.batch_size + self.shuffle = shuffle + self.num_workers = args.num_workers + self.tokenizer = tokenizer + self.split = split + + if split == 'train': + self.transform = transforms.Compose([ + 
transforms.Resize(256), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + else: + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + + if self.dataset_name == 'iu_xray': + self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, transform=self.transform) + else: + self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform) + + self.init_kwargs = { + 'dataset': self.dataset, + 'batch_size': self.batch_size, + 'shuffle': self.shuffle, + 'collate_fn': self.collate_fn, + 'num_workers': self.num_workers + } + super().__init__(**self.init_kwargs) + + @staticmethod + def collate_fn(data): + images_id, images, reports_ids, reports_masks, seq_lengths = zip(*data) + images = torch.stack(images, 0) + max_seq_length = max(seq_lengths) + + targets = np.zeros((len(reports_ids), max_seq_length), dtype=int) + targets_masks = np.zeros((len(reports_ids), max_seq_length), dtype=int) + + for i, report_ids in enumerate(reports_ids): + targets[i, :len(report_ids)] = report_ids + + for i, report_masks in enumerate(reports_masks): + targets_masks[i, :len(report_masks)] = report_masks + + return images_id, images, torch.LongTensor(targets), torch.FloatTensor(targets_masks) diff --git a/modules/dataset.py b/modules/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0f15630ee417964863c6be31ed43f9fffdc53a --- /dev/null +++ b/modules/dataset.py @@ -0,0 +1,68 @@ +import os +from PIL import Image +import json +from torch.utils.data import Dataset +import numpy as np +import torch + + +class BaseDataset(Dataset): + def __init__(self, args, tokenizer, split, processor): + self.image_dir = args.image_dir + self.ann_path = args.json_path + self.max_seq_length = args.max_seq_length + self.split = split + self.tokenizer = tokenizer + self.ann = json.loads(open(self.ann_path, 'r').read()) + self.examples = self.ann[self.split] + self.processor = processor + + def preprocess_text(self, text): + ids = self.tokenizer(text)[:self.max_seq_length] + mask = [1] * len(ids) + text_inputs = self.processor(text=text, return_tensors="pt",truncation=True, padding=False, max_length=self.max_seq_length) + processor_ids = text_inputs['input_ids'].squeeze(0).tolist() + processor_mask = text_inputs['attention_mask'].squeeze(0).tolist() + return ids, mask, processor_ids, processor_mask + + def __len__(self): + return len(self.examples) + + +class IuxrayMultiImageDataset(BaseDataset): + def __getitem__(self, idx): + example = self.examples[idx] + report = example['report'] + report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report) + + image_id = example['id'] + image_path = example['image_path'] + image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB') + # MedCLIP processing + image_inputs_1 = self.processor(images=image_1, return_tensors="pt") + image_inputs_2 = self.processor(images=image_2, return_tensors="pt") + image = torch.stack((image_inputs_1.pixel_values[0], image_inputs_2.pixel_values[0]), 0) + + seq_length = len(report_ids) + processor_length = len(processor_ids) + sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, 
processor_length) + return sample + + +class MimiccxrSingleImageDataset(BaseDataset): + def __getitem__(self, idx): + example = self.examples[idx] + report = example['report'] + report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report) + + image_id = example['id'] + image_path = example['image_path'] + image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + image_inputs = self.processor(images=image, return_tensors="pt") + image = image_inputs.pixel_values[0] + + seq_length = len(report_ids) + processor_length = len(processor_ids) + sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length) + return sample diff --git a/modules/datasets.py b/modules/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..98c53cde7c006776ad04925ad1510cec597c324d --- /dev/null +++ b/modules/datasets.py @@ -0,0 +1,57 @@ +import os +import json +import torch +from PIL import Image +from torch.utils.data import Dataset + + +class BaseDataset(Dataset): + def __init__(self, args, tokenizer, split, transform=None): + self.image_dir = args.image_dir + self.ann_path = args.ann_path + self.max_seq_length = args.max_seq_length + self.split = split + self.tokenizer = tokenizer + self.transform = transform + self.ann = json.loads(open(self.ann_path, 'r').read()) + + self.examples = self.ann[self.split] + for i in range(len(self.examples)): + self.examples[i]['ids'] = tokenizer(self.examples[i]['report'])[:self.max_seq_length] + self.examples[i]['mask'] = [1] * len(self.examples[i]['ids']) + + def __len__(self): + return len(self.examples) + + +class IuxrayMultiImageDataset(BaseDataset): + def __getitem__(self, idx): + example = self.examples[idx] + image_id = example['id'] + image_path = example['image_path'] + image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB') + if self.transform is not None: + image_1 = self.transform(image_1) + image_2 = self.transform(image_2) + image = torch.stack((image_1, image_2), 0) + report_ids = example['ids'] + report_masks = example['mask'] + seq_length = len(report_ids) + sample = (image_id, image, report_ids, report_masks, seq_length) + return sample + + +class MimiccxrSingleImageDataset(BaseDataset): + def __getitem__(self, idx): + example = self.examples[idx] + image_id = example['id'] + image_path = example['image_path'] + image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + if self.transform is not None: + image = self.transform(image) + report_ids = example['ids'] + report_masks = example['mask'] + seq_length = len(report_ids) + sample = (image_id, image, report_ids, report_masks, seq_length) + return sample diff --git a/modules/decoder.py b/modules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..09c3ee7093043af735265810f0bc32a5b8ecc5a2 --- /dev/null +++ b/modules/decoder.py @@ -0,0 +1,50 @@ +import numpy as np +import torch +import torch.nn as nn +import pickle +from typing import Tuple +from transformers import GPT2LMHeadModel +from .att_models import AttModel +import pdb + +class MLP(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): + super(MLP, self).__init__() + layers = [] + for i in range(len(sizes) - 1): + layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) 
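+ # Note: act() (Tanh by default) is inserted between Linear layers but omitted after the
+ # last one, so the projection output stays unbounded. With a two-element sizes tuple, as
+ # used for clip_project in DeCap below ((prefix_size, embedding_size)), this reduces to a
+ # single Linear layer mapping the 512-d CLIP/MedCLIP feature to the GPT-2 embedding width.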
+ if i < len(sizes) - 2: + layers.append(act()) + self.model = nn.Sequential(*layers) + +class DeCap(AttModel): + + def __init__(self, args, tokenizer): + super(DeCap, self).__init__(args, tokenizer) + + # decoder: 4 layers transformer with 4 attention heads + # the decoder is not pretrained + with open('./decoder_config/decoder_config.pkl', 'rb') as f: + config = pickle.load(f) + # Change the parameters you need + config.vocab_size = tokenizer.get_vocab_size() + config.bos_token_id = tokenizer.bos_token_id + config.eos_token_id = tokenizer.eos_token_id + self.decoder = GPT2LMHeadModel(config) + self.embedding_size = self.decoder.transformer.wte.weight.shape[1] + self.prefix_size = 512 + self.clip_project = MLP((self.prefix_size, self.embedding_size)) + + def _forward(self, clip_features, gpt_tokens): + + embedding_text = self.decoder.transformer.wte(gpt_tokens) + embedding_clip = self.clip_project(clip_features) + embedding_clip = embedding_clip.reshape(-1, 1, self.embedding_size) + embedding_cat = torch.cat([embedding_clip, embedding_text], dim=1) + out = self.decoder(inputs_embeds=embedding_cat) + return out + diff --git a/modules/encoder_decoder.py b/modules/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..64824c9c02803a54556d3a714b49794f16a6ae85 --- /dev/null +++ b/modules/encoder_decoder.py @@ -0,0 +1,391 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .att_model import pack_wrapper, AttModel + + +def clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def attention(query, key, value, mask=None, dropout=None): + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +def subsequent_mask(size): + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + + +class Transformer(nn.Module): + def __init__(self, encoder, decoder, src_embed, tgt_embed, rm): + super(Transformer, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.rm = rm + + def forward(self, src, tgt, src_mask, tgt_mask): + return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask) + + def encode(self, src, src_mask): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, hidden_states, src_mask, tgt, tgt_mask): + memory = self.rm.init_memory(hidden_states.size(0)).to(hidden_states) + memory = self.rm(self.tgt_embed(tgt), memory) + return self.decoder(self.tgt_embed(tgt), hidden_states, src_mask, tgt_mask, memory) + + +class Encoder(nn.Module): + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.d_model) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = 
clones(SublayerConnection(d_model, dropout), 2) + self.d_model = d_model + + def forward(self, x, mask): + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +class SublayerConnection(nn.Module): + def __init__(self, d_model, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + return x + self.dropout(sublayer(self.norm(x))) + + +class LayerNorm(nn.Module): + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(features)) + self.beta = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.gamma * (x - mean) / (std + self.eps) + self.beta + + +class Decoder(nn.Module): + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.d_model) + + def forward(self, x, hidden_states, src_mask, tgt_mask, memory): + for layer in self.layers: + x = layer(x, hidden_states, src_mask, tgt_mask, memory) + return self.norm(x) + + +class DecoderLayer(nn.Module): + def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout, rm_num_slots, rm_d_model): + super(DecoderLayer, self).__init__() + self.d_model = d_model + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(ConditionalSublayerConnection(d_model, dropout, rm_num_slots, rm_d_model), 3) + + def forward(self, x, hidden_states, src_mask, tgt_mask, memory): + m = hidden_states + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask), memory) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask), memory) + return self.sublayer[2](x, self.feed_forward, memory) + + +class ConditionalSublayerConnection(nn.Module): + def __init__(self, d_model, dropout, rm_num_slots, rm_d_model): + super(ConditionalSublayerConnection, self).__init__() + self.norm = ConditionalLayerNorm(d_model, rm_num_slots, rm_d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer, memory): + return x + self.dropout(sublayer(self.norm(x, memory))) + + +class ConditionalLayerNorm(nn.Module): + def __init__(self, d_model, rm_num_slots, rm_d_model, eps=1e-6): + super(ConditionalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(d_model)) + self.beta = nn.Parameter(torch.zeros(d_model)) + self.rm_d_model = rm_d_model + self.rm_num_slots = rm_num_slots + self.eps = eps + + self.mlp_gamma = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model), + nn.ReLU(inplace=True), + nn.Linear(rm_d_model, rm_d_model)) + + self.mlp_beta = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model), + nn.ReLU(inplace=True), + nn.Linear(d_model, d_model)) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0.1) + + def forward(self, x, memory): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + delta_gamma = self.mlp_gamma(memory) + delta_beta = self.mlp_beta(memory) + gamma_hat = self.gamma.clone() + beta_hat = self.beta.clone() + gamma_hat = torch.stack([gamma_hat] * x.size(0), dim=0) + gamma_hat = torch.stack([gamma_hat] * x.size(1), dim=1) + beta_hat = torch.stack([beta_hat] * x.size(0), dim=0) + beta_hat = torch.stack([beta_hat] * x.size(1), dim=1) + gamma_hat += 
delta_gamma + beta_hat += delta_beta + return gamma_hat * (x - mean) / (std + self.eps) + beta_hat + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + if mask is not None: + mask = mask.unsqueeze(1) + nbatches = query.size(0) + query, key, value = \ + [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for l, x in zip(self.linears, (query, key, value))] + + x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) + + x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class Embeddings(nn.Module): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + return self.lut(x) * math.sqrt(self.d_model) + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp(torch.arange(0, d_model, 2).float() * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class RelationalMemory(nn.Module): + + def __init__(self, num_slots, d_model, num_heads=1): + super(RelationalMemory, self).__init__() + self.num_slots = num_slots + self.num_heads = num_heads + self.d_model = d_model + + self.attn = MultiHeadedAttention(num_heads, d_model) + self.mlp = nn.Sequential(nn.Linear(self.d_model, self.d_model), + nn.ReLU(), + nn.Linear(self.d_model, self.d_model), + nn.ReLU()) + + self.W = nn.Linear(self.d_model, self.d_model * 2) + self.U = nn.Linear(self.d_model, self.d_model * 2) + + def init_memory(self, batch_size): + memory = torch.stack([torch.eye(self.num_slots)] * batch_size) + if self.d_model > self.num_slots: + diff = self.d_model - self.num_slots + pad = torch.zeros((batch_size, self.num_slots, diff)) + memory = torch.cat([memory, pad], -1) + elif self.d_model < self.num_slots: + memory = memory[:, :, :self.d_model] + + return memory + + def forward_step(self, input, memory): +# print('inputinputinputinputinput',input.size()) +# print('memorymemorymemorymemorymemorymemory',memory.size()) + + memory = memory.reshape(-1, self.num_slots, self.d_model) +# if input.shape[0]!=memory.shape[0]: +# input=input.repeat(round(memory.shape[0]/input.shape[0]),1) + q = memory + k = torch.cat([memory, input.unsqueeze(1)], 1) + v = torch.cat([memory, input.unsqueeze(1)], 1) + next_memory = memory + self.attn(q, k, v) + next_memory = next_memory + self.mlp(next_memory) + + gates = self.W(input.unsqueeze(1)) + self.U(torch.tanh(memory)) 
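+ # LSTM-style gating (Relational Memory Core): W and U both project to 2 * d_model, so the
+ # split along the last dimension below yields an input gate and a forget gate; the memory
+ # update is sigmoid(input_gate) * tanh(candidate) + sigmoid(forget_gate) * old_memory.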
+ gates = torch.split(gates, split_size_or_sections=self.d_model, dim=2) + input_gate, forget_gate = gates + input_gate = torch.sigmoid(input_gate) + forget_gate = torch.sigmoid(forget_gate) + + next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory + next_memory = next_memory.reshape(-1, self.num_slots * self.d_model) + + return next_memory + + def forward(self, inputs, memory): + outputs = [] + for i in range(inputs.shape[1]): + memory = self.forward_step(inputs[:, i], memory) + outputs.append(memory) + outputs = torch.stack(outputs, dim=1) + + return outputs + + +class EncoderDecoder(AttModel): + + def make_model(self, tgt_vocab): + c = copy.deepcopy + attn = MultiHeadedAttention(self.num_heads, self.d_model) + ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) + position = PositionalEncoding(self.d_model, self.dropout) + rm = RelationalMemory(num_slots=self.rm_num_slots, d_model=self.rm_d_model, num_heads=self.rm_num_heads) + model = Transformer( + Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers), + Decoder( + DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout, self.rm_num_slots, self.rm_d_model), + self.num_layers), + lambda x: x, + nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), + rm) + for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, args, tokenizer): + super(EncoderDecoder, self).__init__(args, tokenizer) + self.args = args + self.num_layers = args.num_layers + self.d_model = args.d_model + self.d_ff = args.d_ff + self.num_heads = args.num_heads + self.dropout = args.dropout + self.rm_num_slots = args.rm_num_slots + self.rm_num_heads = args.rm_num_heads + self.rm_d_model = args.rm_d_model + + tgt_vocab = self.vocab_size + 1 + + self.model = self.make_model(tgt_vocab) + self.logit = nn.Linear(args.d_model, tgt_vocab) + + def init_hidden(self, bsz): + return [] + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) + + return fc_feats[..., :1], att_feats[..., :1], memory, att_masks + + def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + att_masks = att_masks.unsqueeze(-2) + + if seq is not None: + # crop the last one + seq = seq[:, :-1] + seq_mask = (seq.data > 0) + seq_mask[:, 0] += True + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) + out = self.model(att_feats, seq, att_masks, seq_mask) + outputs = F.log_softmax(self.logit(out), dim=-1) + return outputs + + def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): + + if len(state) == 0: + ys = it.unsqueeze(1) + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + out = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device)) + return out[:, -1], [ys.unsqueeze(0)] diff --git a/modules/loss.py b/modules/loss.py new file mode 100644 index 
0000000000000000000000000000000000000000..e807833fe1727e816712e06703fa3518c1115542 --- /dev/null +++ b/modules/loss.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + + +class LanguageModelCriterion(nn.Module): + def __init__(self): + super(LanguageModelCriterion, self).__init__() + + def forward(self, input, target, mask): + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)] + output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask + output = torch.sum(output) / torch.sum(mask) + + return output + + +def compute_loss(output, reports_ids, reports_masks): + criterion = LanguageModelCriterion() + loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean() + return loss diff --git a/modules/metrics.py b/modules/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8e88c97a0787eb7ae87f0fd255e43965f444da --- /dev/null +++ b/modules/metrics.py @@ -0,0 +1,33 @@ +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.meteor import Meteor +from pycocoevalcap.rouge import Rouge + + +def compute_scores(gts, res): + """ + Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap) + + :param gts: Dictionary with the image ids and their gold captions, + :param res: Dictionary with the image ids ant their generated captions + :print: Evaluation score (the mean of the scores of all the instances) for each measure + """ + + # Set up scorers + scorers = [ + (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L") + ] + eval_res = {} + # Compute score for each metric + for scorer, method in scorers: + try: + score, scores = scorer.compute_score(gts, res, verbose=0) + except TypeError: + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, m in zip(score, method): + eval_res[m] = sc + else: + eval_res[method] = score + return eval_res \ No newline at end of file diff --git a/modules/optimizers.py b/modules/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..510566ed87547324e6a98d145c6957f7a8b629d9 --- /dev/null +++ b/modules/optimizers.py @@ -0,0 +1,18 @@ +import torch + + +def build_optimizer(args, model): + ve_params = list(map(id, model.visual_extractor.parameters())) + ed_params = filter(lambda x: id(x) not in ve_params, model.parameters()) + optimizer = getattr(torch.optim, args.optim)( + [{'params': model.visual_extractor.parameters(), 'lr': args.lr_ve}, + {'params': ed_params, 'lr': args.lr_ed}], + weight_decay=args.weight_decay, + amsgrad=args.amsgrad + ) + return optimizer + + +def build_lr_scheduler(args, optimizer): + lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma) + return lr_scheduler diff --git a/modules/tester.py b/modules/tester.py new file mode 100644 index 0000000000000000000000000000000000000000..097416800779ec26bb1c7c0689d2d4b6289b1f05 --- /dev/null +++ b/modules/tester.py @@ -0,0 +1,144 @@ +import logging +import os +from abc import abstractmethod + +import cv2 +import numpy as np +import pandas as pd +import spacy +import torch +from tqdm import tqdm + +from modules.utils import generate_heatmap + + +class BaseTester(object): + def __init__(self, model, criterion, metric_ftns, args): + self.args = args + + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) + self.logger = 
logging.getLogger(__name__) + + # setup GPU device if available, move model into configured device + self.device, device_ids = self._prepare_device(args.n_gpu) + self.model = model.to(self.device) + if len(device_ids) > 1: + self.model = torch.nn.DataParallel(model, device_ids=device_ids) + + self.criterion = criterion + self.metric_ftns = metric_ftns + + self.epochs = self.args.epochs + self.save_dir = self.args.save_dir + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + + self._load_checkpoint(args.load) + + @abstractmethod + def test(self): + raise NotImplementedError + + @abstractmethod + def plot(self): + raise NotImplementedError + + def _prepare_device(self, n_gpu_use): + n_gpu = torch.cuda.device_count() + if n_gpu_use > 0 and n_gpu == 0: + self.logger.warning( + "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") + n_gpu_use = 0 + if n_gpu_use > n_gpu: + self.logger.warning( + "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( + n_gpu_use, n_gpu)) + n_gpu_use = n_gpu + device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') + list_ids = list(range(n_gpu_use)) + return device, list_ids + + def _load_checkpoint(self, load_path): + load_path = str(load_path) + self.logger.info("Loading checkpoint: {} ...".format(load_path)) + checkpoint = torch.load(load_path) + self.model.load_state_dict(checkpoint) + + +class Tester(BaseTester): + def __init__(self, model, criterion, metric_ftns, args, test_dataloader): + super(Tester, self).__init__(model, criterion, metric_ftns, args) + self.test_dataloader = test_dataloader + + def test(self): + self.logger.info('Start to evaluate in the test set.') + self.model.eval() + log = dict() + with torch.no_grad(): + test_gts, test_res = [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks, align_ids, align_masks) in enumerate(self.test_dataloader): + images, reports_ids, reports_masks, align_ids, align_masks = images.to(self.device), reports_ids.to(self.device), \ + reports_masks.to(self.device), align_ids.to(self.device), align_masks.to(self.device) + output = self.model(reports_ids, align_ids, align_masks, images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + test_res.extend(reports) + test_gts.extend(ground_truths) + + test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)}, + {i: [re] for i, re in enumerate(test_res)}) + log.update(**{'test_' + k: v for k, v in test_met.items()}) + print(log) + + test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts) + test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False) + test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False) + + return log + + def plot(self): + assert self.args.batch_size == 1 and self.args.beam_size == 1 + self.logger.info('Start to plot attention weights in the test set.') + os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True) + os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True) + ner = spacy.load("en_core_sci_sm") + mean = torch.tensor((0.485, 0.456, 0.406)) + std = torch.tensor((0.229, 0.224, 0.225)) + mean = mean[:, None, None] + std = std[:, None, None] + + self.model.eval() + with torch.no_grad(): + for batch_idx, (images_id, images, reports_ids, reports_masks) in 
tqdm(enumerate(self.test_dataloader)): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output, _ = self.model(images, mode='sample') + image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy() + report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split() + + char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1] + + attention_weights = self.model.encoder_decoder.attention_weights[:-1] + assert len(attention_weights) == len(report) + for word_idx, (attns, word) in enumerate(zip(attention_weights, report)): + for layer_idx, attn in enumerate(attns): + os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx)), exist_ok=True) + + heatmap = generate_heatmap(image, attn.mean(1).squeeze()) + cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)), + heatmap) + + for ne_idx, ne in enumerate(ner(" ".join(report)).ents): + for layer_idx in range(len(attention_weights[0])): + os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx)), exist_ok=True) + attn = [attns[layer_idx] for attns in + attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]] + attn = np.concatenate(attn, axis=2) + heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze()) + cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)), + heatmap) diff --git a/modules/tokenizers.py b/modules/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..dddb6591f8a6ef47cee9fe533cd8239501d68dc0 --- /dev/null +++ b/modules/tokenizers.py @@ -0,0 +1,95 @@ +import json +import re +from collections import Counter + + +class Tokenizer(object): + def __init__(self, args): + self.ann_path = args.ann_path + self.threshold = args.threshold + self.dataset_name = args.dataset_name + if self.dataset_name == 'iu_xray': + self.clean_report = self.clean_report_iu_xray + else: + self.clean_report = self.clean_report_mimic_cxr + self.ann = json.loads(open(self.ann_path, 'r').read()) + self.token2idx, self.idx2token = self.create_vocabulary() + + def create_vocabulary(self): + total_tokens = [] + + for example in self.ann['train']: + tokens = self.clean_report(example['report']).split() + for token in tokens: + total_tokens.append(token) + + counter = Counter(total_tokens) + vocab = [k for k, v in counter.items() if v >= self.threshold] + [''] + vocab.sort() + token2idx, idx2token = {}, {} + for idx, token in enumerate(vocab): + token2idx[token] = idx + 1 + idx2token[idx + 1] = token + return token2idx, idx2token + + def clean_report_iu_xray(self, report): + report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \ + .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \ + .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', ''). 
+ replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' + return report + + def clean_report_mimic_cxr(self, report): + report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ + .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ + .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ + .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ + .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') + .replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' + return report + + def get_token_by_id(self, id): + return self.idx2token[id] + + def get_id_by_token(self, token): + if token not in self.token2idx: + return self.token2idx[''] + return self.token2idx[token] + + def get_vocab_size(self): + return len(self.token2idx) + + def __call__(self, report): + tokens = self.clean_report(report).split() + ids = [] + for token in tokens: + ids.append(self.get_id_by_token(token)) + ids = [0] + ids + [0] + return ids + + def decode(self, ids): + txt = '' + for i, idx in enumerate(ids): + if idx > 0: + if i >= 1: + txt += ' ' + txt += self.idx2token[idx] + else: + break + return txt + + def decode_batch(self, ids_batch): + out = [] + for ids in ids_batch: + out.append(self.decode(ids)) + return out diff --git a/modules/trainer.py b/modules/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7edbdfb0c5fac1503c6fe4afa5e8b154daa14e10 --- /dev/null +++ b/modules/trainer.py @@ -0,0 +1,255 @@ +import os +from abc import abstractmethod +import json +import time +import torch +import pandas as pd +from numpy import inf + + +class BaseTrainer(object): + def __init__(self, model, criterion, metric_ftns, optimizer, args): + self.args = args + + # setup GPU device if available, move model into configured device + self.device, device_ids = self._prepare_device(args.n_gpu) + self.model = model.to(self.device) + if len(device_ids) > 1: + self.model = torch.nn.DataParallel(model, device_ids=device_ids) + + self.criterion = criterion + self.metric_ftns = metric_ftns + self.optimizer = optimizer + + self.epochs = self.args.epochs + self.save_period = self.args.save_period + + self.mnt_mode = args.monitor_mode + self.mnt_metric = 'val_' + args.monitor_metric + self.mnt_metric_test = 'test_' + args.monitor_metric + assert self.mnt_mode in ['min', 'max'] + + self.mnt_best = inf if self.mnt_mode == 'min' else -inf + self.early_stop = getattr(self.args, 'early_stop', inf) + + self.start_epoch = 1 + self.checkpoint_dir = args.save_dir + + if not os.path.exists(self.checkpoint_dir): + os.makedirs(self.checkpoint_dir) + + if args.resume is not None: + self._resume_checkpoint(args.resume) + + self.best_recorder = {'val': {self.mnt_metric: self.mnt_best}, + 'test': {self.mnt_metric_test: self.mnt_best}} + + @abstractmethod + def 
_train_epoch(self, epoch): + raise NotImplementedError + + def train(self): + not_improved_count = 0 + for epoch in range(self.start_epoch, self.epochs + 1): + result = self._train_epoch(epoch) + + # save logged informations into log dict + log = {'epoch': epoch} + log.update(result) + self._record_best(log) + + # print logged informations to the screen + for key, value in log.items(): + print('\t{:15s}: {}'.format(str(key), value)) + + # evaluate model performance according to configured metric, save best checkpoint as model_best + best = False + if self.mnt_mode != 'off': + try: + # check whether model performance improved or not, according to specified metric(mnt_metric) + improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) + except KeyError: + print("Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( + self.mnt_metric)) + self.mnt_mode = 'off' + improved = False + + if improved: + self.mnt_best = log[self.mnt_metric] + not_improved_count = 0 + best = True + else: + not_improved_count += 1 + + if not_improved_count > self.early_stop: + print("Validation performance didn\'t improve for {} epochs. " "Training stops.".format( + self.early_stop)) + break + + if epoch % self.save_period == 0: + self._save_checkpoint(epoch, save_best=best) + self._print_best() + self._print_best_to_file() + + def _print_best_to_file(self): + crt_time = time.asctime(time.localtime(time.time())) + self.best_recorder['val']['time'] = crt_time + self.best_recorder['test']['time'] = crt_time + self.best_recorder['val']['seed'] = self.args.seed + self.best_recorder['test']['seed'] = self.args.seed + self.best_recorder['val']['best_model_from'] = 'val' + self.best_recorder['test']['best_model_from'] = 'test' + + if not os.path.exists(self.args.record_dir): + os.makedirs(self.args.record_dir) + record_path = os.path.join(self.args.record_dir, self.args.dataset_name+'.csv') + if not os.path.exists(record_path): + record_table = pd.DataFrame() + else: + record_table = pd.read_csv(record_path) + record_table = record_table.append(self.best_recorder['val'], ignore_index=True) + record_table = record_table.append(self.best_recorder['test'], ignore_index=True) + record_table.to_csv(record_path, index=False) + + def _prepare_device(self, n_gpu_use): + n_gpu = torch.cuda.device_count() + if n_gpu_use > 0 and n_gpu == 0: + print("Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") + n_gpu_use = 0 + if n_gpu_use > n_gpu: + print( + "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( + n_gpu_use, n_gpu)) + n_gpu_use = n_gpu + device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') + list_ids = list(range(n_gpu_use)) + return device, list_ids + + def _save_checkpoint(self, epoch, save_best=False): + state = { + 'epoch': epoch, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best + } + filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth') + torch.save(state, filename) + print("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = os.path.join(self.checkpoint_dir, 'model_best.pth') + torch.save(state, best_path) + print("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, resume_path): + resume_path = str(resume_path) + print("Loading checkpoint: {} 
...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + self.model.load_state_dict(checkpoint['state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer']) + + print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) + + def _record_best(self, log): + improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][ + self.mnt_metric]) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric]) + if improved_val: + self.best_recorder['val'].update(log) + + improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][ + self.mnt_metric_test]) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][ + self.mnt_metric_test]) + if improved_test: + self.best_recorder['test'].update(log) + + def _print_best(self): + print('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric)) + for key, value in self.best_recorder['val'].items(): + print('\t{:15s}: {}'.format(str(key), value)) + + print('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric)) + for key, value in self.best_recorder['test'].items(): + print('\t{:15s}: {}'.format(str(key), value)) + + +if not os.path.exists('valreports/'): + os.makedirs('valreports/') +if not os.path.exists('testreports/'): + os.makedirs('testreports/') + +class Trainer(BaseTrainer): + def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, + test_dataloader): + super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args) + self.lr_scheduler = lr_scheduler + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.test_dataloader = test_dataloader + + def _train_epoch(self, epoch): + + train_loss = 0 + self.model.train() + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), reports_masks.to( + self.device) + output = self.model(images, reports_ids, mode='train') + loss = self.criterion(output, reports_ids, reports_masks) + train_loss += loss.item() + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1) + self.optimizer.step() + log = {'train_loss': train_loss / len(self.train_dataloader)} + + + self.model.eval() + with torch.no_grad(): + result_report_val = [] + val_gts, val_res = [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + for i in range(reports_ids.shape[0]): + temp1 = {'reports_ids': images_id[i], 'reports': reports[i]} + result_report_val.append(temp1) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + val_res.extend(reports) + val_gts.extend(ground_truths) + val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)}, + {i: [re] for i, re in enumerate(val_res)}) + log.update(**{'val_' + k: v for k, v in val_met.items()}) + resFileval = 'valreports/mixed-' + str(epoch) + '.json' + 
json.dump(result_report_val, open(resFileval, 'w')) + + + self.model.eval() + with torch.no_grad(): + result_report_test = [] + test_gts, test_res = [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + # print('reportsreportsreportsreports',images_id,reports) + for i in range(reports_ids.shape[0]): + temp = {'reports_ids': images_id[i], 'reports': reports[i]} + result_report_test.append(temp) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + test_res.extend(reports) + test_gts.extend(ground_truths) + test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)}, + {i: [re] for i, re in enumerate(test_res)}) + log.update(**{'test_' + k: v for k, v in test_met.items()}) + resFiletest = 'testreports/mixed-' + str(epoch) + '.json' + json.dump(result_report_test, open(resFiletest, 'w')) + self.lr_scheduler.step() + + return log \ No newline at end of file diff --git a/modules/utils.py b/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c47335be626e64d0451d8212e8c6a966a62aa4 --- /dev/null +++ b/modules/utils.py @@ -0,0 +1,55 @@ +import torch + + +def penalty_builder(penalty_config): + if penalty_config == '': + return lambda x, y: y + pen_type, alpha = penalty_config.split('_') + alpha = float(alpha) + if pen_type == 'wu': + return lambda x, y: length_wu(x, y, alpha) + if pen_type == 'avg': + return lambda x, y: length_average(x, y, alpha) + + +def length_wu(length, logprobs, alpha=0.): + """ + NMT length re-ranking score from + "Google's Neural Machine Translation System" :cite:`wu2016google`. + """ + + modifier = (((5 + length) ** alpha) / + ((5 + 1) ** alpha)) + return logprobs / modifier + + +def length_average(length, logprobs, alpha=0.): + """ + Returns the average probability of tokens in a sequence. + """ + return logprobs / length + + +def split_tensors(n, x): + if torch.is_tensor(x): + assert x.shape[0] % n == 0 + x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) + elif type(x) is list or type(x) is tuple: + x = [split_tensors(n, _) for _ in x] + elif x is None: + x = [None] * n + return x + + +def repeat_tensors(n, x): + """ + For a tensor of size Bx..., we repeat it n times, and make it Bnx... + For collections, do nested repeat + """ + if torch.is_tensor(x): + x = x.unsqueeze(1) # Bx1x... + x = x.expand(-1, n, *([-1] * len(x.shape[2:]))) # Bxnx... + x = x.reshape(x.shape[0] * n, *x.shape[2:]) # Bnx... 
+ elif type(x) is list or type(x) is tuple: + x = [repeat_tensors(n, _) for _ in x] + return x diff --git a/modules/visual_extractor.py b/modules/visual_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..48ed3c6d2ddcfdb5166008502a60756dc8771403 --- /dev/null +++ b/modules/visual_extractor.py @@ -0,0 +1,53 @@ +import os + +os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + +from medclip import MedCLIPModel, MedCLIPVisionModelViT +from medclip import MedCLIPProcessor +from PIL import Image +import torch +import torch.nn as nn +import torchvision.models as models +import torch.nn.functional as F + +class VisualExtractor(nn.Module): + # prepare for the demo image and text + def __init__(self, args): + super(VisualExtractor, self).__init__() + self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT) + self.model.from_pretrained() + self.model.cuda() + self.processor = MedCLIPProcessor() + with torch.no_grad(): + self.prompt = torch.load('prompt/prompt.pth') + + + def forward(self, images): + a=[] + for i in images: + inputs = self.processor( text="lungs",images=i,return_tensors="pt",padding=True) + outputs = self.model(**inputs) + feats = outputs['img_embeds'] + a.append(feats) + batch_feats = torch.stack(a, dim=0) + + ha = [] + for i in range(batch_feats.shape[0]): + b = batch_feats[i].unsqueeze(1) + b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1) + c_t = torch.bmm(self.prompt, b) + c_t = c_t.float() + alpha = F.softmax(c_t) + aa = alpha * self.prompt + sum_a = aa.sum(axis=0) + ha.append(sum_a) + featsem = torch.stack(ha, dim=0) + + feats = torch.cat((featsem, batch_feats), dim=2) + + patch_feats = feats.repeat(1, 49, 1) + batch_feats1 = feats.squeeze(1) + avg_feats = batch_feats1 + + + return patch_feats, avg_feats \ No newline at end of file diff --git a/prompt/prompt.pth b/prompt/prompt.pth new file mode 100644 index 0000000000000000000000000000000000000000..c2ee4862ec44b6c3a091433db1ea6cad43383b8e --- /dev/null +++ b/prompt/prompt.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b03692f5ba61e9d50d10556cdbb724ed6249668873bad099bc6548af618a7d0 +size 20480747 diff --git a/pycocoevalcap/README.md b/pycocoevalcap/README.md new file mode 100644 index 0000000000000000000000000000000000000000..942de18171e9c5e9dad3adcaf28419de218a31e6 --- /dev/null +++ b/pycocoevalcap/README.md @@ -0,0 +1,23 @@ +Microsoft COCO Caption Evaluation Tools
+--- + +Modified the code to work with Python 3.
+ +### Requirements +* Python 3.x +* Java 1.8 +* pycocotools + +--- + +### Tested on +* Windows 10, Python 3.5. + +--- +### To fix Windows JVM memory error:
+Add the following environment variable under System Variables:
+    Variable name : _JAVA_OPTIONS
+    Variable value : -Xmx1024M
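+
+    Alternatively (a hedged example, not part of the original instructions), the
+    same option can be set persistently for new sessions from a Command Prompt:
+
+        setx _JAVA_OPTIONS "-Xmx1024M"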
+ +--- +Original code: https://github.com/tylin/coco-caption
diff --git a/pycocoevalcap/__init__.py b/pycocoevalcap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..680063e9fa3f9a30d61b4e7ff7a337b7637ee7fd --- /dev/null +++ b/pycocoevalcap/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/bleu/LICENSE b/pycocoevalcap/bleu/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9ccf677900b5238062979c7bc1e7102e501e0be4 --- /dev/null +++ b/pycocoevalcap/bleu/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pycocoevalcap/bleu/__init__.py b/pycocoevalcap/bleu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..680063e9fa3f9a30d61b4e7ff7a337b7637ee7fd --- /dev/null +++ b/pycocoevalcap/bleu/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/bleu/bleu.py b/pycocoevalcap/bleu/bleu.py new file mode 100644 index 0000000000000000000000000000000000000000..60e723e510016c73c6ddb0513eb6558e42378621 --- /dev/null +++ b/pycocoevalcap/bleu/bleu.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# +# File Name : bleu.py +# +# Description : Wrapper for BLEU scorer. +# +# Creation Date : 06-01-2015 +# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT +# Authors : Hao Fang and Tsung-Yi Lin + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +from .bleu_scorer import BleuScorer + + +class Bleu: + def __init__(self, n=4): + # default compute Blue score up to 4 + self._n = n + self._hypo_for_image = {} + self.ref_for_image = {} + + def compute_score(self, gts, res, score_option = 'closest', verbose = 1): + ''' + Inputs: + gts - ground truths + res - predictions + score_option - {shortest, closest, average} + verbose - 1 or 0 + Outputs: + Blue scores + ''' + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + bleu_scorer = BleuScorer(n=self._n) + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. 
+ assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + #assert(len(ref) >= 1) + + bleu_scorer += (hypo[0], ref) + + score, scores = bleu_scorer.compute_score(option = score_option, verbose =verbose) + + # return (bleu, bleu_info) + return score, scores + + def method(self): + return "Bleu" diff --git a/pycocoevalcap/bleu/bleu_scorer.py b/pycocoevalcap/bleu/bleu_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..d5646aae12c24c6b0e9e104c8c09e6aa3645f396 --- /dev/null +++ b/pycocoevalcap/bleu/bleu_scorer.py @@ -0,0 +1,268 @@ +# bleu_scorer.py +# David Chiang + +# Copyright (c) 2004-2006 University of Maryland. All rights +# reserved. Do not redistribute without permission from the +# author. Not for commercial use. + +# Modified by: +# Hao Fang +# Tsung-Yi Lin + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +'''Provides: +cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). +cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). +''' + +import copy +import sys, math, re +from collections import defaultdict + +def precook(s, n=4, out=False): + """Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well.""" + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return (len(words), counts) + +def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them.''' + + reflen = [] + maxcounts = {} + for ref in refs: + rl, counts = precook(ref, n) + reflen.append(rl) + for (ngram,count) in counts.items(): + maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + # Calculate effective reference sentence length. + if eff == "shortest": + reflen = min(reflen) + elif eff == "average": + reflen = float(sum(reflen))/len(reflen) + + ## lhuang: N.B.: leave reflen computaiton to the very end!! + + ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) + + return (reflen, maxcounts) + +def cook_test(test, refs , eff=None, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it.''' + + reflen = refs[0] + refmaxcounts = refs[1] + + testlen, counts = precook(test, n, True) + + result = {} + + # Calculate effective reference sentence length. + + if eff == "closest": + result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] + else: ## i.e., "average" or "shortest" or None + result["reflen"] = reflen + + result["testlen"] = testlen + + result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] + + result['correct'] = [0]*n + for (ngram, count) in counts.items(): + result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) + + return result + +class BleuScorer(object): + """Bleu scorer. + """ + + __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" + # special_reflen is used in oracle (proportional effective ref len for a node). 
+ + def copy(self): + ''' copy the refs.''' + new = BleuScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + new._score = None + return new + + def __init__(self, test=None, refs=None, n=4, special_reflen=None): + ''' singular instance ''' + + self.n = n + self.crefs = [] + self.ctest = [] + self.cook_append(test, refs) + self.special_reflen = special_reflen + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + cooked_test = cook_test(test, self.crefs[-1]) + self.ctest.append(cooked_test) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + self._score = None ## need to recompute + + def ratio(self, option=None): + self.compute_score(option=option) + return self._ratio + + def score_ratio(self, option=None): + '''return (bleu, len_ratio) pair''' + return (self.fscore(option=option), self.ratio(option=option)) + + def score_ratio_str(self, option=None): + return "%.4f (%.2f)" % self.score_ratio(option) + + def reflen(self, option=None): + self.compute_score(option=option) + return self._reflen + + def testlen(self, option=None): + self.compute_score(option=option) + return self._testlen + + def retest(self, new_test): + if type(new_test) is str: + new_test = [new_test] + assert len(new_test) == len(self.crefs), new_test + self.ctest = [] + for t, rs in zip(new_test, self.crefs): + self.ctest.append(cook_test(t, rs)) + self._score = None + + return self + + def rescore(self, new_test): + ''' replace test(s) with new test(s), and returns the new score.''' + + return self.retest(new_test).compute_score() + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new BleuScorer instances + self.cook_append(other[0], other[1]) + else: + assert self.compatible(other), "incompatible BLEUs." 
+ self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + self._score = None ## need to recompute + + return self + + def compatible(self, other): + return isinstance(other, BleuScorer) and self.n == other.n + + def single_reflen(self, option="average"): + return self._single_reflen(self.crefs[0][0], option) + + def _single_reflen(self, reflens, option=None, testlen=None): + + if option == "shortest": + reflen = min(reflens) + elif option == "average": + reflen = float(sum(reflens))/len(reflens) + elif option == "closest": + reflen = min((abs(l-testlen), l) for l in reflens)[1] + else: + assert False, "unsupported reflen option %s" % option + + return reflen + + def recompute_score(self, option=None, verbose=0): + self._score = None + return self.compute_score(option, verbose) + + def compute_score(self, option=None, verbose=0): + n = self.n + small = 1e-9 + tiny = 1e-15 ## so that if guess is 0 still return 0 + bleu_list = [[] for _ in range(n)] + + if self._score is not None: + return self._score + + if option is None: + option = "average" if len(self.crefs) == 1 else "closest" + + self._testlen = 0 + self._reflen = 0 + totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} + + # for each sentence + for comps in self.ctest: + testlen = comps['testlen'] + self._testlen += testlen + + if self.special_reflen is None: ## need computation + reflen = self._single_reflen(comps['reflen'], option, testlen) + else: + reflen = self.special_reflen + + self._reflen += reflen + + for key in ['guess','correct']: + for k in range(n): + totalcomps[key][k] += comps[key][k] + + # append per image bleu score + bleu = 1. + for k in range(n): + bleu *= (float(comps['correct'][k]) + tiny) \ + /(float(comps['guess'][k]) + small) + bleu_list[k].append(bleu ** (1./(k+1))) + ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleu_list[k][-1] *= math.exp(1 - 1/ratio) + + if verbose > 1: + print(comps, reflen) + + totalcomps['reflen'] = self._reflen + totalcomps['testlen'] = self._testlen + + bleus = [] + bleu = 1. 
+ for k in range(n): + bleu *= float(totalcomps['correct'][k] + tiny) \ + / (totalcomps['guess'][k] + small) + bleus.append(bleu ** (1./(k+1))) + ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleus[k] *= math.exp(1 - 1/ratio) + + if verbose > 0: + print(totalcomps) + print("ratio:", ratio) + + self._score = bleus + return self._score, bleu_list diff --git a/pycocoevalcap/cider/__init__.py b/pycocoevalcap/cider/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7d85bba884ea8f83fc6ab2a1e6ade80d98d4d9 --- /dev/null +++ b/pycocoevalcap/cider/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/pycocoevalcap/cider/cider.py b/pycocoevalcap/cider/cider.py new file mode 100644 index 0000000000000000000000000000000000000000..7aadb9a7881fbb5e9d3f5441203fb6e47378a706 --- /dev/null +++ b/pycocoevalcap/cider/cider.py @@ -0,0 +1,55 @@ +# Filename: cider.py +# +# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin + + +from .cider_scorer import CiderScorer +import pdb + +class Cider: + """ + Main Class to compute the CIDEr metric + + """ + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ + + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) + + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) > 0) + + cider_scorer += (hypo[0], ref) + + (score, scores) = cider_scorer.compute_score() + + return score, scores + + def method(self): + return "CIDEr" \ No newline at end of file diff --git a/pycocoevalcap/cider/cider_scorer.py b/pycocoevalcap/cider/cider_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..94752e8198476106b929f2607f82b7e8e269ef13 --- /dev/null +++ b/pycocoevalcap/cider/cider_scorer.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam + + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +import copy +from collections import defaultdict +import numpy as np +import pdb +import math + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. 
+ :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return counts + +def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + +class CiderScorer(object): + """CIDEr scorer. + """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.document_frequency = defaultdict(float) + self.cook_append(test, refs) + self.ref_len = None + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + def compute_doc_freq(self): + ''' + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + ''' + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): + self.document_frequency[ngram] += 1 + # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + def compute_cider(self): + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. 
+ :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram,term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram)-1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq)*(self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. + :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram,count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n]*norm_ref[n]) + + assert(not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) + return val + + # compute log reference length + self.ref_len = np.log(float(len(self.crefs))) + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + self.compute_doc_freq() + # assert to check document frequency + assert(len(self.ctest) >= max(self.document_frequency.values())) + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) \ No newline at end of file diff --git a/pycocoevalcap/eval.py b/pycocoevalcap/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..21f53dcd52f3e7d93f30b239ca994dabd9aa7037 --- /dev/null +++ b/pycocoevalcap/eval.py @@ -0,0 +1,74 @@ +__author__ = 'tylin' +from .tokenizer.ptbtokenizer import PTBTokenizer +from .bleu.bleu import Bleu +from .meteor.meteor import Meteor +from .rouge.rouge import Rouge +from .cider.cider import Cider + +class COCOEvalCap: + def __init__(self, coco, cocoRes): + self.evalImgs = [] + self.eval = {} + self.imgToEval = {} + self.coco = coco + self.cocoRes = cocoRes + self.params = 
{'image_id': cocoRes.getImgIds()} + + def evaluate(self): + imgIds = self.params['image_id'] + # imgIds = self.coco.getImgIds() + gts = {} + res = {} + for imgId in imgIds: + gts[imgId] = self.coco.imgToAnns[imgId] + res[imgId] = self.cocoRes.imgToAnns[imgId] + + # ================================================= + # Set up scorers + # ================================================= + print('tokenization...') + tokenizer = PTBTokenizer() + gts = tokenizer.tokenize(gts) + res = tokenizer.tokenize(res) + + # ================================================= + # Set up scorers + # ================================================= + print('setting up scorers...') + scorers = [ + (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + (Meteor(),"METEOR"), + (Rouge(), "ROUGE_L"), + (Cider(), "CIDEr") + ] + + # ================================================= + # Compute scores + # ================================================= + eval = {} + for scorer, method in scorers: + print('computing %s score...'%(scorer.method())) + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, scs, m in zip(score, scores, method): + self.setEval(sc, m) + self.setImgToEvalImgs(scs, imgIds, m) + print("%s: %0.3f"%(m, sc)) + else: + self.setEval(score, method) + self.setImgToEvalImgs(scores, imgIds, method) + print("%s: %0.3f"%(method, score)) + self.setEvalImgs() + + def setEval(self, score, method): + self.eval[method] = score + + def setImgToEvalImgs(self, scores, imgIds, method): + for imgId, score in zip(imgIds, scores): + if not imgId in self.imgToEval: + self.imgToEval[imgId] = {} + self.imgToEval[imgId]["image_id"] = imgId + self.imgToEval[imgId][method] = score + + def setEvalImgs(self): + self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] diff --git a/pycocoevalcap/license.txt b/pycocoevalcap/license.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ada56f2474312c22e71968ab82233a5430a6428 --- /dev/null +++ b/pycocoevalcap/license.txt @@ -0,0 +1,26 @@ +Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +The views and conclusions contained in the software and documentation are those +of the authors and should not be interpreted as representing official policies, +either expressed or implied, of the FreeBSD Project. \ No newline at end of file diff --git a/pycocoevalcap/meteor/__init__.py b/pycocoevalcap/meteor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..349338d26d52b73b4aaf5f575520a6d580025598 --- /dev/null +++ b/pycocoevalcap/meteor/__init__.py @@ -0,0 +1 @@ +from .meteor import * \ No newline at end of file diff --git a/pycocoevalcap/meteor/meteor-1.5.jar b/pycocoevalcap/meteor/meteor-1.5.jar new file mode 100644 index 0000000000000000000000000000000000000000..80e236efc242679f16fde17ea323f0a1cc35a126 --- /dev/null +++ b/pycocoevalcap/meteor/meteor-1.5.jar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e57b4c72c0830ebe68558f1c799a624e96cbc1b6045c9f6330e26dcff6eafc2 +size 6318693 diff --git a/pycocoevalcap/meteor/meteor.py b/pycocoevalcap/meteor/meteor.py new file mode 100644 index 0000000000000000000000000000000000000000..d27020d2658f60a55dd3c10dc36e33a38b755b9b --- /dev/null +++ b/pycocoevalcap/meteor/meteor.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python + +# Python wrapper for METEOR implementation, by Xinlei Chen +# Acknowledge Michael Denkowski for the generous discussion and help + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +import os +import sys +import subprocess +import threading + +# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. +METEOR_JAR = 'meteor-1.5.jar' + + +# print METEOR_JAR + +class Meteor: + + def __init__(self): + self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ + '-', '-', '-stdio', '-l', 'en', '-norm'] + self.meteor_p = subprocess.Popen(self.meteor_cmd, \ + cwd=os.path.dirname(os.path.abspath(__file__)), \ + stdin=subprocess.PIPE, \ + stdout=subprocess.PIPE, \ + stderr=subprocess.PIPE, + universal_newlines=True, + bufsize=1) + # Used to guarantee thread safety + self.lock = threading.Lock() + + def compute_score(self, gts, res): + assert (gts.keys() == res.keys()) + imgIds = gts.keys() + scores = [] + + eval_line = 'EVAL' + self.lock.acquire() + for i in imgIds: + assert (len(res[i]) == 1) + stat = self._stat(res[i][0], gts[i]) + eval_line += ' ||| {}'.format(stat) + + self.meteor_p.stdin.write('{}\n'.format(eval_line)) + for i in range(0, len(imgIds)): + scores.append(float(self.meteor_p.stdout.readline().strip())) + score = float(self.meteor_p.stdout.readline().strip()) + self.lock.release() + + return score, scores + + def method(self): + return "METEOR" + + def _stat(self, hypothesis_str, reference_list): + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write('{}\n'.format(score_line)) + return self.meteor_p.stdout.readline().strip() + + def _score(self, hypothesis_str, reference_list): + self.lock.acquire() + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write('{}\n'.format(score_line)) + stats = self.meteor_p.stdout.readline().strip() + eval_line = 'EVAL ||| {}'.format(stats) + # EVAL ||| stats + 
self.meteor_p.stdin.write('{}\n'.format(eval_line)) + score = float(self.meteor_p.stdout.readline().strip()) + # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice + # thanks for Andrej for pointing this out + score = float(self.meteor_p.stdout.readline().strip()) + self.lock.release() + return score + + def __del__(self): + self.lock.acquire() + self.meteor_p.stdin.close() + self.meteor_p.kill() + self.meteor_p.wait() + self.lock.release() diff --git a/pycocoevalcap/rouge/__init__.py b/pycocoevalcap/rouge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c0469a60189757cd73225f759004eef4cbeb00 --- /dev/null +++ b/pycocoevalcap/rouge/__init__.py @@ -0,0 +1 @@ +from .rouge import * \ No newline at end of file diff --git a/pycocoevalcap/rouge/rouge.py b/pycocoevalcap/rouge/rouge.py new file mode 100644 index 0000000000000000000000000000000000000000..3a10f5a50371328d397dcb53c7c9d81eac9472fa --- /dev/null +++ b/pycocoevalcap/rouge/rouge.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# +# File Name : rouge.py +# +# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) +# +# Creation Date : 2015-01-07 06:03 +# Author : Ramakrishna Vedantam + +import numpy as np +import pdb + +def my_lcs(string, sub): + """ + Calculates longest common subsequence for a pair of tokenized strings + :param string : list of str : tokens from a string split using whitespace + :param sub : list of str : shorter string, also split using whitespace + :returns: length (list of int): length of the longest common subsequence between the two strings + + Note: my_lcs only gives length of the longest common subsequence, not the actual LCS + """ + if(len(string)< len(sub)): + sub, string = string, sub + + lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] + + for j in range(1,len(sub)+1): + for i in range(1,len(string)+1): + if(string[i-1] == sub[j-1]): + lengths[i][j] = lengths[i-1][j-1] + 1 + else: + lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) + + return lengths[len(string)][len(sub)] + +class Rouge(): + ''' + Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set + + ''' + def __init__(self): + # vrama91: updated the value below based on discussion with Hovey + self.beta = 1.2 + + def calc_score(self, candidate, refs): + """ + Compute ROUGE-L score given one candidate and references for an image + :param candidate: str : candidate sentence to be evaluated + :param refs: list of str : COCO reference sentences for the particular image to be evaluated + :returns score: int (ROUGE-L score for the candidate evaluated against references) + """ + assert(len(candidate)==1) + assert(len(refs)>0) + prec = [] + rec = [] + + # split into tokens + token_c = candidate[0].split(" ") + + for reference in refs: + # split into tokens + token_r = reference.split(" ") + # compute the longest common subsequence + lcs = my_lcs(token_r, token_c) + prec.append(lcs/float(len(token_c))) + rec.append(lcs/float(len(token_r))) + + prec_max = max(prec) + rec_max = max(rec) + + if(prec_max!=0 and rec_max !=0): + score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) + else: + score = 0.0 + return score + + def compute_score(self, gts, res): + """ + Computes Rouge-L score given a set of reference and candidate sentences for the dataset + Invoked by evaluate_captions.py + :param hypo_for_image: dict : candidate / test sentences with "image name" key and 
"tokenized sentences" as values + :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values + :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) + """ + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + score = [] + for id in imgIds: + hypo = res[id] + ref = gts[id] + + score.append(self.calc_score(hypo, ref)) + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) > 0) + + average_score = np.mean(np.array(score)) + return average_score, np.array(score) + + def method(self): + return "Rouge" diff --git a/pycocoevalcap/tokenizer/__init__.py b/pycocoevalcap/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71357a4bff7219ddcf7f7321cfeb4484bd8bee08 --- /dev/null +++ b/pycocoevalcap/tokenizer/__init__.py @@ -0,0 +1 @@ +__author__ = 'hfang' diff --git a/pycocoevalcap/tokenizer/ptbtokenizer.py b/pycocoevalcap/tokenizer/ptbtokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d06e154daa3230ed88cd46c823a0dffbc0f291 --- /dev/null +++ b/pycocoevalcap/tokenizer/ptbtokenizer.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# +# File Name : ptbtokenizer.py +# +# Description : Do the PTB Tokenization and remove punctuations. +# +# Creation Date : 29-12-2014 +# Last Modified : Thu Mar 19 09:53:35 2015 +# Authors : Hao Fang and Tsung-Yi Lin + +import os +import sys +import subprocess +import tempfile +import itertools + + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +# path to the stanford corenlp jar +STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' + +# punctuations to be removed from the sentences +PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ + ".", "?", "!", ",", ":", "-", "--", "...", ";"] + +class PTBTokenizer: + """Python wrapper of Stanford PTBTokenizer""" + + def tokenize(self, captions_for_image): + cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ + 'edu.stanford.nlp.process.PTBTokenizer', \ + '-preserveLines', '-lowerCase'] + + # ====================================================== + # prepare data for PTB Tokenizer + # ====================================================== + final_tokenized_captions_for_image = {} + image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] + sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) + + # ====================================================== + # save sentences to temporary file + # ====================================================== + path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) + tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) + tmp_file.write(sentences.encode('utf-8')) + tmp_file.close() + + # ====================================================== + # tokenize sentence + # ====================================================== + cmd.append(os.path.basename(tmp_file.name)) + p_tokenizer = subprocess.Popen(cmd, + cwd=path_to_jar_dirname, + stdout=subprocess.PIPE, + universal_newlines = True, + bufsize = 1) + token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] + lines = token_lines.split('\n') + # remove temp file + os.remove(tmp_file.name) + + # ====================================================== + # create dictionary for tokenized captions + # 
====================================================== + for k, line in zip(image_id, lines): + if not k in final_tokenized_captions_for_image: + final_tokenized_captions_for_image[k] = [] + tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ + if w not in PUNCTUATIONS]) + final_tokenized_captions_for_image[k].append(tokenized_caption) + + return final_tokenized_captions_for_image diff --git a/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar new file mode 100644 index 0000000000000000000000000000000000000000..07e4e5e4f90d7060180c968bf31ca35084627c2d --- /dev/null +++ b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fcb91bb7a111f93d71e264f4ee0e3afd19ba0dde6d21b38605088df9e940399 +size 5921410 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..16ebe12f9341ff328a7ef7c62ae5b47ae26574c9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +numpy +pandas +Pillow +requests +tqdm +wget +medclip +opencv-python>=4.4.0.42 +nltk>=3.7 +scikit_learn>=1.1.2 +textaugment>=1.3.4 +timm>=0.6.11 +torch>=1.12.1 +torchvision>=0.13.1 +transformers>=4.23.1,<=4.24.0 diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..43b9c4074328961d5131da8766756247eaff222c --- /dev/null +++ b/test.py @@ -0,0 +1,99 @@ +import torch +import argparse +from modules.dataloader import R2DataLoader +from modules.tokenizers import Tokenizer +from modules.loss import compute_loss +from modules.metrics import compute_scores +from models.models import MedCapModel +from modules.tester import Tester +import numpy as np +import os +os.environ['CURL_CA_BUNDLE'] = '' +def main(): + parser = argparse.ArgumentParser() + + # Data input Settings + parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json', + help='Path to the json file') + parser.add_argument('--image_dir', default='data/mimic_cxr/images/', + help='Directory of images') + + # Dataloader Settings + parser.add_argument('--dataset', default='iu_xray', help='dataset for training MedCap') + parser.add_argument('--bs', type=int, default=16) + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.') + + #Trainer Settings + parser.add_argument('--epochs', type=int, default=30) + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='./record_dir/', + help='the patch to save the results of experiments.') + parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') + parser.add_argument('--save_period', type=int, default=1) + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + + # Training related + 
parser.add_argument('--noise_inject', default='no', choices=['yes', 'no']) + + # Sample related + parser.add_argument('--sample_method', type=str, default='greedy', help='the sample methods to sample a report.') + parser.add_argument('--prompt',default='/prompt/prompt.pt') + parser.add_argument('--prompt_load', default='yes',choices=['yes','no']) + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=7e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') + parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.') + parser.add_argument('--noamopt_factor', type=int, default=1, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9153, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + parser.add_argument('--train_mode', default='base', choices=['base', 'full'], + help='Training mode: base (text only training) or full (full supervised training)') + parser.add_argument('--full_supervised_version', default='v1', choices=['v1', 'v2' , 'v3'], + help='Full supervised version: v1 (only get image features) or v2 (feature fusion) or v3(feature fusion+image features') + parser.add_argument('--clip_update', default='no' , choices=['yes','no']) + parser.add_argument('--load', type=str, help='whether to load the pre-trained model.') + + + args = parser.parse_args() + + # fix random seeds + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(args.seed) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # create tokenizer + tokenizer = Tokenizer(args) + test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) + + # get function handles of loss and metrics + criterion = compute_loss + metrics = compute_scores + model = MedCapModel(args, tokenizer) + + # build trainer and start to train + tester = Tester(model, criterion, metrics, args, test_dataloader) + tester.test() + +if __name__ == '__main__': + main() diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..619e0c9d3a34962af1a8fa5764a21a6b4ef4b1e1 --- /dev/null +++ b/train.py @@ -0,0 +1,124 @@ +import torch +import argparse +import numpy as np +from modules.tokenizers import Tokenizer +from modules.dataloaders import R2DataLoader +from modules.metrics import compute_scores +from modules.optimizers import build_optimizer, build_lr_scheduler +from modules.trainer import Trainer +from modules.loss import compute_loss +from models.r2gen import R2GenModel + + 
+def parse_agrs(): + parser = argparse.ArgumentParser() + + # Data input settings + parser.add_argument('--image_dir', type=str, default='data/mimic_cxr/images/', help='the path to the directory containing the data.') + parser.add_argument('--ann_path', type=str, default='data/mimic_cxr/annotation.json', help='the path to the directory containing the data.') + + # Data loader settings + parser.add_argument('--dataset_name', type=str, default='mimic_cxr', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.') + parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') + + # Model settings (for visual extractor) + parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') + parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') + + # Model settings (for Transformer) + parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') + parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') + parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') + parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') + parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') + parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') + parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') + parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') + parser.add_argument('--eos_idx', type=int, default=0, help='the index of .') + parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') + parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') + parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') + # for Relational Memory + parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.') + parser.add_argument('--rm_num_heads', type=int, default=8, help='the numebr of heads in rm.') + parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.') + + # Sample related + parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') + parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') + parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') + parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') + parser.add_argument('--group_size', type=int, default=1, help='the group size.') + parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') + parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') + parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use 
block trigrams.') + + # Trainer settings + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') + parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments') + parser.add_argument('--save_period', type=int, default=1, help='the saving period.') + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9233, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + + args = parser.parse_args() + return args + + +def main(): + # parse arguments + args = parse_agrs() + + # fix random seeds + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(args.seed) + + # create tokenizer + tokenizer = Tokenizer(args) + + # create data loader + train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) + val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) + test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) + + # build model architecture + model = R2GenModel(args, tokenizer) + + # get function handles of loss and metrics + criterion = compute_loss + metrics = compute_scores + + # build optimizer, learning rate scheduler + optimizer = build_optimizer(args, model) + lr_scheduler = build_lr_scheduler(args, optimizer) + + # build trainer and start to train + trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader) + trainer.train() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/training_loss_epoch.jpg b/training_loss_epoch.jpg new file mode 100644 index 0000000000000000000000000000000000000000..78cf6b9f4112a25807f2c97207926429d4883ae5 Binary files /dev/null and b/training_loss_epoch.jpg differ diff --git a/training_loss_iteration.jpg b/training_loss_iteration.jpg new file mode 100644 index 0000000000000000000000000000000000000000..597a823665ddef5eb86aadbe05eca40ebffb95a2 Binary files /dev/null and 
b/training_loss_iteration.jpg differ