diff --git a/omg_llava/__init__.py b/omg_llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/__pycache__/__init__.cpython-310.pyc b/omg_llava/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1dfd8482764fa057058c8ec6b172ab2ccdced4 Binary files /dev/null and b/omg_llava/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/configs/__init__.py b/omg_llava/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..41cbb83cec11555506f2bed10386487af9009346 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model 
+llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..b6da12d81e9c0127601e731fcc4528b1a4a061a3 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..ef61339363f4e7bac0a347b5ee86f24989fa5a05 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py @@ -0,0 +1,927 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
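+# Ablation config for multi-layer seg-token states, "cat" variant (debug run):
+# seg-token hidden states from multiple LLM layers are merged by concatenation
+# (using_multilayer_states=True, seg_token_merge_type='cat', selected_layers=32),
+# and train_dataset below is reduced to the RefCOCOg GCG subset only.
+# Example launch (a sketch, assuming the standard xtuner entry point):
+#   NPROC_PER_NODE=${GPU_NUM} xtuner train \
+#       omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py \
+#       --deepspeed deepspeed_zero2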
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
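+# The point-prompt QA and multi-point caption annotations below reuse image
+# folders already defined above; only the annotation JSON files differ.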
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..415bc59f269dbddfb786ff3240e11e44c730d895 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # 
+####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' 
+mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='linear_cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..782946579760f1aa0ac7d91993e0d9de8b54830d --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py @@ -0,0 +1,927 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
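+# Ablation config for merging multi-layer seg-token hidden states with the
+# 'linear_cat' strategy (see `seg_token_merge_type='linear_cat'`,
+# `using_multilayer_states=True` and `selected_layers=32` in the model dict below).
+# This *_debug variant trains only on the RefCOCOg GCG dataset
+# (`datasets_cfgs=[glamm_refcocog_dataset]`), which keeps runs short for debugging;
+# the baseline ablation config in this folder uses the full dataset mix.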
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
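+# MDPV point-prompt QA annotations (the gpt4v_*_QA_point.json files below); they are
+# loaded through the same MDPVPointBriefCaptionDataset / mdpv_points_map_fn pipeline
+# as the brief-caption point data defined in PART 3.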
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='linear_cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
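+    # (DistSamplerSeedHook calls `set_epoch` on the sampler each epoch so that
+    # shuffling differs between epochs under distributed training.)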
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcf07802c3a081dfc224af7696c33a8fce37190 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # 
+####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' 
+mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='mean', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5782d906d9e7800a423a7e1b0ac5e254639cf2 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py @@ -0,0 +1,924 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
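+# Debug variant of the multi-seg-states ablation config: it keeps the same model,
+# tokenizer and dataset definitions, but `pretrained_pth` points at an intermediate
+# pretrain checkpoint (iter_4361) and the combined `train_dataset` below contains
+# only the RefCOCOg GCG dataset, so a full run finishes quickly for pipeline checks.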
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus/iter_4361.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 
'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = 
'./data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..b98374afc1d1350cbe2b4e6a2d4d82e23e6c9092 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py @@ -0,0 +1,953 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### 
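+# Projector ablation, 'cross' variant: unlike the configs above, `pretrained_pth`
+# loads the separate-cross-attention pretrain weights, and the model config below
+# sets `add_cross_attn_layer=True` with `visual_prompt_proj=False`.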
+# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
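The MDPV point-prompt dataset entries in this part of the config differ only in the dataset class and the data_path / image_folder pair; the tokenizer, image processor, map functions, max_length and padding settings are repeated verbatim in every block. If the duplication ever becomes a maintenance burden, the dicts could be produced by a small helper instead. A minimal sketch, assuming the config stays a plain Python module; make_mdpv_dataset is illustrative and not part of this patch:

def make_mdpv_dataset(dataset_cls, data_path, image_folder):
    # Shared boilerplate for the MDPV point-prompt datasets defined in this config.
    return dict(
        type=dataset_cls,
        data_path=data_path,
        image_folder=image_folder,
        tokenizer=tokenizer,
        image_processor=image_processor,
        dataset_map_fn=mdpv_points_map_fn,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        max_length=max_length,
        pad_image_to_square=True,
        debug=False,
        repeats=1)

# e.g. make_mdpv_dataset(MDPVPointBriefCaptionDataset,
#                        mdpv_brief_caption_vg_data_path,
#                        mdpv_brief_caption_vg_image_path)
# reproduces mdpv_brief_description_vg_dataset exactly as written above.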
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2690aed1cbafbb4667910d37a20ba9610e010e7b
--- /dev/null
+++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py
@@ -0,0 +1,953 @@
+# Copyright (c) OpenMMLab. All rights reserved.
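This second config, finetune_ablation_projector_seperate.py, repeats the baseline above almost line for line. The substantive differences are the projector flags in its model block further down (visual_prompt_proj=True with add_cross_attn_layer=False here, versus visual_prompt_proj=False with add_cross_attn_layer=True in the ablation_multi_seg_states baseline) and the checkpoint paths it starts from. To confirm what an ablation config actually changes, the two files can be loaded and compared with mmengine. A sketch, assuming the repository's dependencies are importable, since these configs import real classes at load time:

from mmengine.config import Config

base = Config.fromfile('omg_llava/configs/finetune/ablation_multi_seg_states/'
                       'ablation_multi_seg_states_baseline.py')
sep = Config.fromfile('omg_llava/configs/finetune/ablation_projector/'
                      'finetune_ablation_projector_seperate.py')

# Print the top-level keys of the model block whose values differ.
for key in sorted(set(base.model) | set(sep.model)):
    if base.model.get(key) != sep.model.get(key):
        print(key, ':', base.model.get(key), '->', sep.model.get(key))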
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
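Every dataset location in this settings block is a plain string, so a typo only surfaces when the corresponding dataset is first read during training. An optional pre-flight check can catch missing files or folders up front. A purely illustrative helper, not part of the config, assuming the *_path / *_file / *_folder / *_root naming used here:

import os

def check_data_paths(namespace):
    # Warn about any path-like setting in this config module that is missing on disk.
    for name, value in sorted(namespace.items()):
        if not isinstance(value, str):
            continue
        if name.endswith(('_path', '_file', '_folder', '_root')) and not os.path.exists(value):
            print(f'[warning] {name} -> {value} does not exist')

# e.g. call check_data_paths(globals()) at the end of the settings section
# before launching a training job.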
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=False, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
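As everywhere else in this patch, each dict(type=SomeClass, ...) entry above is a lazy constructor in the xtuner / mmengine config style: a builder pops type and calls it with the remaining keys as keyword arguments, while some nested entries (for example the tokenizer and image processor) are handed over as dicts for the consuming class to build itself. A stripped-down illustration of that mechanism, not the actual mmengine builder:

def build_from_cfg(cfg):
    # Conceptual sketch: instantiate a dict(type=..., ...) entry by popping 'type'
    # and passing the remaining keys as keyword arguments, recursing into nested dicts.
    if isinstance(cfg, dict) and 'type' in cfg:
        cfg = dict(cfg)  # avoid mutating the config definition
        factory = cfg.pop('type')
        return factory(**{k: build_from_cfg(v) for k, v in cfg.items()})
    return cfg

# e.g. build_from_cfg(mdpv_brief_description_cocostuff164k_dataset) would call
# MDPVPointBriefCaptionDataset(data_path=..., image_folder=..., ...); note that the
# real pipeline keeps tokenizer and image_processor as dicts for the dataset to build.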
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb575feac689040f4b98979e56fca6c1f72f3f35
--- /dev/null
+++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py
@@ -0,0 +1,953 @@
+# Copyright (c) OpenMMLab. All rights reserved.
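The param_scheduler these configs share (defined just above, and repeated verbatim in the file that starts here) warms the learning rate up linearly from lr * 1e-5 over the first warmup_ratio fraction of training, then cosine-decays it to zero; convert_to_iter_based=True makes both stages step per iteration rather than per epoch. A rough standalone sketch of the resulting curve, using the lr=2e-4 and warmup_ratio=0.03 values from these configs; the exact per-iteration values produced by mmengine may differ slightly at the stage boundary:

import math

def lr_at(step, total_steps, base_lr=2e-4, warmup_ratio=0.03, start_factor=1e-5):
    # Approximate shape of the LinearLR -> CosineAnnealingLR schedule above.
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from base_lr * start_factor up to base_lr.
        frac = step / max(warmup_steps, 1)
        return base_lr * (start_factor + (1 - start_factor) * frac)
    # Cosine decay from base_lr down to eta_min = 0.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

# e.g. lr_at(0, 10000) -> ~2e-9; lr_at(300, 10000) -> 2e-4; lr_at(10000, 10000) -> 0.0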
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file 
= glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
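As in the two configs above, the settings block that follows derives max_length as int(2048 - (1024 / 64)**2 - 100). The reading below is an assumption on my part, but the arithmetic itself is fixed: a 1024-pixel input with an effective token stride of 64 (the backbone's 32x downsampling combined with pixel_shuffle_down_ratio=2) yields a 16 x 16 = 256-token visual sequence, and roughly 100 tokens are reserved for the chat template and special tokens, leaving 1692 tokens of text budget inside the 2048-token context:

context_window = 2048
visual_tokens = int((1024 / 64) ** 2)  # 16 * 16 = 256 image tokens
reserved = 100                         # rough budget for template / special tokens
max_length = context_window - visual_tokens - reserved
assert max_length == 1692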
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
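+# All MDPV point datasets share mdpv_points_map_fn and the same prompt template
+# and differ only in data_path / image_folder: detailed captions use
+# MDPVPointDetailedCaptionDataset, while brief captions and QA use
+# MDPVPointBriefCaptionDataset. Like the other dataset entries in this config,
+# they are plain config dicts that are built lazily at runtime, roughly (sketch):
+#   from xtuner.registry import BUILDER
+#   dataset = BUILDER.build(mdpv_brief_description_vg_dataset)
+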
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..383e5cf1e0bbe5e13c1e9cd4eb9994cdb7064f49 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py @@ -0,0 +1,926 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
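+# Debug variant of the "seperate_cross" projector ablation: it loads the
+# seperate-cross pretrained checkpoint and, unlike the full finetune configs in
+# this diff, trains on a single MDPV brief-caption (LVIS) dataset with
+# dataloader_num_workers = 0, which keeps debug runs lightweight.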
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file 
= glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
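+# Note on max_length (defined below): a 1024-pixel input with an effective 64x
+# downsampling (32x backbone stride times the 2x pixel shuffle) yields
+# (1024 / 64)**2 = 256 visual tokens, and roughly 100 positions are reserved,
+# presumably for the chat template and special tokens, leaving about 1692 of
+# the 2048-token context for text.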
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
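+# In this debug config only mdpv_brief_description_lvis_dataset is actually
+# selected in train_dataset below; the remaining dataset definitions above and
+# below are unused here, presumably kept for parity with the full finetune configs.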
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[mdpv_brief_description_lvis_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/debug.py b/omg_llava/configs/finetune/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..78d2b667c30cd0fada27a2b27475f01f89b7cefb --- /dev/null +++ b/omg_llava/configs/finetune/debug.py @@ -0,0 +1,967 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # 
Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' 
+mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) 
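+# NOTE: the referring-segmentation configs (RefCOCO, RefCOCO+, RefCOCOg above and RefClef below)
+# all share one recipe: only the dataset class and the data/image paths change, while the
+# tokenizer, image processor, map functions, max_length and padding settings are identical.
+# `repeats` presumably oversamples a dataset by repeating its samples when the configs are
+# combined (e.g. the GranDf GCG dataset above uses repeats=10).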
+ +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, + ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. 
+ timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8f9f91932249edf521af375f7da95c10924c94 --- /dev/null +++ b/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from 
torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = 
'./data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/hf_app.py b/omg_llava/configs/finetune/hf_app.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ef9d1fcb16b9c3c96e5c0eb9b7f4992df7c19 --- /dev/null +++ b/omg_llava/configs/finetune/hf_app.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import 
OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a 
detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + 
num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
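+# NOTE: every MDPV point dataset in this config reuses `mdpv_points_map_fn`
+# with the same tokenizer, image_processor, prompt template, max_length,
+# pad_image_to_square, debug and repeats settings; only `type`, `data_path`
+# and `image_folder` differ. A hypothetical factory like the commented-out,
+# illustrative sketch below could build these dicts and cut the repetition:
+#
+#   def _make_mdpv_dataset(dataset_type, data_path, image_folder):
+#       return dict(
+#           type=dataset_type, data_path=data_path, image_folder=image_folder,
+#           tokenizer=tokenizer, image_processor=image_processor,
+#           dataset_map_fn=mdpv_points_map_fn,
+#           template_map_fn=dict(
+#               type=template_map_fn_factory, template=prompt_template),
+#           max_length=max_length, pad_image_to_square=True,
+#           debug=False, repeats=1)
+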
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..3c38e99ae60a4ae2da5901ef3c6a1066507a555d --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py @@ -0,0 +1,993 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
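+# Fine-tune config for the 20B variant of OMG-LLaVA (InternLM2-Chat-20B LLM,
+# 1024x1024 inputs, 8 GPUs per the filename). Relative to the 7B finetune
+# configs above, it swaps in the 20B LLM and pretrain checkpoints and uses
+# accumulative_counts = 4, i.e. an effective batch of
+# 8 (per device) * 4 (accumulation) * 8 (GPUs) = 256 samples per optimizer step.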
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_20b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + 
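+# All mdpv_* point-prompt sources follow the same pattern: an image folder
+# from an existing dataset (ADE20K, COCO, LVIS, Visual Genome, Flickr30K,
+# OpenPSG) paired with a point-annotation JSON under ./data/mdpv_point/.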
+mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + 
encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, 
+ tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..ce00b4aee77b8abf7b90bf120026394778880a79 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,952 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
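+# A minimal sketch of the token budget behind the `max_length` value used
+# below, assuming one visual token per 64x64 patch of the 1024-px input and
+# roughly 100 tokens of chat-template overhead (both readings are assumptions,
+# inferred only from the formula itself):
+_context_window = 2048               # assumed LLM context window
+_visual_tokens = (1024 // 64) ** 2   # 256 image tokens under the assumption above
+_template_margin = 100               # assumed template/special-token headroom
+# _context_window - _visual_tokens - _template_margin == 1692, matching
+# int(2048 - (1024 / 64)**2 - 100), which is set as max_length later in this file.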
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_convnextXXL.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_xxlarge', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s34b_b82k_augreg_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + clip_feat_channel=3072, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset 
= dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c9a42e5bf34fa3ffc9464a6c708e63c2bba10e --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py @@ -0,0 +1,993 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
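+# A minimal sketch of the effective optimization batch implied by the
+# scheduler settings below, assuming the 8-GPU launch suggested by this
+# config's file name (the GPU count itself is not set anywhere in the config):
+_per_device_batch = 8      # batch_size below
+_grad_accumulation = 2     # accumulative_counts below
+_assumed_num_gpus = 8      # assumption taken from the "8gpus" file name
+# One optimizer step then sees 8 * 2 * 8 == 128 samples, and the
+# LengthGroupedSampler below groups samples of similar `modality_length`
+# into windows of batch_size * accumulative_counts == 16 per device.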
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file 
= './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = 
'./data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # 
DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, 
+ tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook_withSpecialTokens,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46b0008af47d266d60e2066255a381e918fbd35
--- /dev/null
+++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py
@@ -0,0 +1,1007 @@
+# Copyright (c) OpenMMLab. All rights reserved.
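+# Stage-1 finetune recipe (1024x1024 input, 8 GPUs). Compared with the ablation
+# baseline config above, several dataset configs in this file additionally set
+# num_proc=32, which should parallelize the offline map/tokenization step
+# (assuming the usual HuggingFace-datasets `num_proc` semantics used by the
+# xtuner-style dataset classes).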
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = 
'./omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = 
'./data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # 
DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + num_proc=32 +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + num_proc=32 +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + num_proc=32 +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + 
image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + 
image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( 
+ type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] 
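+# With the settings above, the LinearLR warmup covers the first
+# warmup_ratio * max_epochs = 0.03 * 1 = 0.03 epochs before cosine decay to
+# eta_min=0.0; since both schedulers set convert_to_iter_based=True, these
+# epoch-based begin/end boundaries are converted into iteration counts at
+# runtime.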
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook_withSpecialTokens,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..510b1d2e784b189086d6e2a1b61753350aea852a
--- /dev/null
+++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py
@@ -0,0 +1,1028 @@
+# Copyright (c) OpenMMLab. All rights reserved.
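+# Decoupled-GCG variant of the stage-1 finetune recipe: in addition to the
+# standard datasets it imports the Decoupled*GCGDataset classes and defines two
+# configs per GCG source, one with mode='given_description' and one with
+# mode='given_objects', which presumably splits grounded caption generation
+# into its captioning and grounding halves as separate training samples.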
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\ + DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\ + glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\ + glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\ + glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn + +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + 
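+# glamm_data_root is the base directory for the GCG annotations below
+# (refcocog, grandf, flickr, psg) and is also reused for the COCO2014/2017
+# image folders referenced by several of the later datasets.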
+refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = 
'./data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset_given_description = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_description' +) + +glamm_refcocog_dataset_given_objects = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_objects' +) + +glamm_grandf_dataset_given_description = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + 
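+    # GranDf_HA is a comparatively small (human-annotated) GCG split, so it is
+    # oversampled via repeats=10 below, while the larger RefCOCOg/OpenPSG/Flickr
+    # GCG datasets keep repeats=1.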
repeats=10, + mode='given_description' +) + +glamm_grandf_dataset_given_objects = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_objects' +) + +glamm_psg_dataset_given_description = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_psg_dataset_given_objects = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +glamm_flickr_dataset_given_description = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_flickr_dataset_given_objects = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, + glamm_flickr_dataset_given_description, glamm_flickr_dataset_given_objects, + glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description, + glamm_psg_dataset_given_description, glamm_psg_dataset_given_objects, + glamm_grandf_dataset_given_description, glamm_grandf_dataset_given_objects, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + 
mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..38ac6d9d03f228c0189922fdfae7f4acbc7c6d04 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py @@ -0,0 +1,1000 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\ + DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\ + glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\ + glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\ + glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn + +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import 
OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a 
detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + 
num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset_given_description = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_description' +) + +glamm_refcocog_dataset_given_objects = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_objects' +) + +glamm_grandf_dataset_given_description = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + 
max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_description' +) + +glamm_grandf_dataset_given_objects = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_objects' +) + +glamm_psg_dataset_given_description = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_psg_dataset_given_objects = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +glamm_flickr_dataset_given_description = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_flickr_dataset_given_objects = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
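+    # referring_seg_gcg_format_map_fn (below) reformats referring-expression samples
+    # into the same grounded answer style (text spans paired with segmentation tokens)
+    # used by the GCG datasets above.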
dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + 
data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + 
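+    # The mdpv_qa_* point-QA sets reuse MDPVPointBriefCaptionDataset and
+    # mdpv_points_map_fn, presumably because they share the same point-annotation
+    # json layout as the brief-caption data.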
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[ + glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description, + ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: 
https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ef9d1fcb16b9c3c96e5c0eb9b7f4992df7c19 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
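+# NOTE: illustrative sketch only (assumptions, not verified against the repo):
+# * This is a plain xtuner/mmengine python config; assuming the upstream xtuner CLI,
+#   an 8-GPU run would typically be launched with something like
+#       NPROC_PER_NODE=8 xtuner train <path/to/this_config.py> --deepspeed deepspeed_zero2
+#   (see the project README for the exact entry point).
+# * max_length below evaluates to 2048 - (1024 / 64)**2 - 100 = 2048 - 256 - 100 = 1692
+#   text tokens, presumably reserving the 16 x 16 grid of visual tokens from the
+#   1024-px input plus roughly 100 tokens of chat-template overhead inside the 2048 context.
+# * The dict(type=...) blocks are built lazily by the runner, roughly:
+#       from mmengine.config import Config
+#       from xtuner.registry import BUILDER
+#       cfg = Config.fromfile('omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py')
+#       tokenizer = BUILDER.build(cfg.tokenizer)  # dict(type=...) -> AutoTokenizer instance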
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..16ed4dc066a7efca8aef6f6dc665fd6cd459767b --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py @@ -0,0 +1,994 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
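The stage-2 config that starts here differs from stage 1 mainly in that it sets require_omg_decoder=True and loads the stage-1 checkpoint through pretrained_pth; the data, tokenizer, image processor, and optimizer settings are otherwise shared. A minimal sketch of the shared sequence-length and batch arithmetic follows; the 64-pixel effective stride is an assumption (backbone stride 32 times the pixel-shuffle ratio 2), everything else is copied from the configs:

    image_size = 1024
    effective_stride = 32 * 2                                   # assumption: backbone stride x pixel_shuffle_ratio
    visual_tokens = (image_size // effective_stride) ** 2       # (1024 / 64)**2 = 256
    max_length = int(2048 - visual_tokens - 100)                # 1692 tokens left for text, as in the configs
    per_device_bs, accumulative_counts, num_gpus = 8, 2, 8      # num_gpus taken from the "8gpus" file names
    samples_per_optimizer_step = per_device_bs * accumulative_counts * num_gpus   # 128
    print(visual_tokens, max_length, samples_per_optimizer_step)

In other words, the max_length expression simply reserves room in the 2048-token context for the image tokens plus a fixed margin of 100 tokens.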
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_finetune_stage1_1024image_8gpus/iter_27600.pth' # noqa: E501 +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + 
+mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + 
encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + freeze_llm_with_lora=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + 
label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + 
type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3dbef5b5583359795e05591509e9d79660455a --- /dev/null +++ b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py @@ -0,0 +1,925 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
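# finetune_gcg.py: task-specific fine-tuning of OMG-LLaVA for Grounded Conversation
# Generation (GCG). It keeps the same layout as the other finetune configs
# (Settings / Model & Tokenizer / Dataset & Dataloader / Scheduler / Runtime) but
# combines only the four GCG datasets (Flickr, RefCOCOg, GranDf, OpenPSG) and
# initializes from the finetuned OMG-LLaVA checkpoint set in `pretrained_pth` below.
#
# A minimal sketch of how an mmengine-style config like this one is consumed; the
# nested dict(type=...) entries are only instantiated by the runner at train time,
# and the xtuner CLI shown here is an assumption -- the actual launcher may differ:
#
#   from mmengine.config import Config
#   cfg = Config.fromfile(
#       'omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py')
#   print(cfg.train_dataloader['batch_size'])       # 8 per device
#   print(len(cfg.train_dataset['datasets_cfgs']))  # 4 GCG dataset configs
#
#   # or from the shell (assumed entry point):
#   # NPROC_PER_NODE=8 xtuner train \
#   #     omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py \
#   #     --deepspeed deepspeed_zero2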
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 
'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = 
'./data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py new file mode 100644 index 0000000000000000000000000000000000000000..a0126ce51774c06d72694f61fb67b2690d2ffa9c --- /dev/null +++ b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py @@ -0,0 +1,929 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = 
'./pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
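# The mdpv_* entries in this block each pair an image folder with a point-prompt
# annotation json (GPT-4V detailed captions, brief captions, QA, and multi-point
# captions, per the file names). As in the other finetune configs, every dataset
# path is declared up front for convenience; this referring-segmentation config,
# however, only trains on the four referring_seg_* datasets (RefCOCO, RefCOCO+,
# RefCOCOg, RefCLEF), each listed three times in `train_dataset` further down.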
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb68e3552b3094466bdab4f7bbc49981251d0c1 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 
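For reference, the `max_length = int(2048 - (1024 / 64)**2)` value used throughout these configs leaves room in the 2048-token context for the visual tokens of a 1024 px input. A minimal sketch of the arithmetic, assuming an effective visual stride of 64 px; the helper name is illustrative and not from the repo:

# Illustrative only: how many text tokens remain once image tokens are reserved.
# Assumes (1024 / 64) ** 2 = 256 visual tokens per image; adjust if the model differs.
def remaining_text_tokens(context_len=2048, image_size=1024, effective_stride=64):
    num_image_tokens = (image_size // effective_stride) ** 2  # 16 * 16 = 256
    return context_len - num_image_tokens                     # 2048 - 256 = 1792

assert remaining_text_tokens() == int(2048 - (1024 / 64) ** 2)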
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + 
use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=False, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, 
max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..b49b3c19287bb6146b9940bd483c229eaf542e57 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
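These ablation configs are plain mmengine-style Python files, so they can be loaded and their components built outside of a training run. A brief sketch, assuming xtuner's `BUILDER` registry, a checkout containing the config path below, and the dependencies installed:

from mmengine.config import Config
from xtuner.registry import BUILDER

# Load the baseline projector config above and build two of its components.
cfg = Config.fromfile(
    'omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py')
tokenizer = BUILDER.build(cfg.tokenizer)              # AutoTokenizer for the LLM path
image_processor = BUILDER.build(cfg.image_processor)  # CLIPImageProcessor at 1024 px
print(type(tokenizer).__name__, type(image_processor).__name__)

Training itself would normally go through xtuner's own entry points rather than building pieces by hand; the sketch is only meant to show that the dict-style blocks below are buildable objects.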
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + 
type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py new file mode 100644 index 0000000000000000000000000000000000000000..3e709027e948ba534ba334ce087e3175a0ab1aa4 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device 
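The effective optimisation batch size follows from the per-device batch size above, the gradient-accumulation count set just below (4 in this variant), and the number of devices. A small sketch of the arithmetic, with the 8-GPU world size taken as an assumption from the other configs' names rather than from this file:

# Illustrative arithmetic only; num_gpus is an assumption, not set in this config.
per_device_batch_size = 16   # batch_size above
grad_accumulation = 4        # accumulative_counts below
num_gpus = 8

effective_batch_size = per_device_batch_size * grad_accumulation * num_gpus
print(effective_batch_size)  # 512 samples per optimiser step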
+accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + 
reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( 
+ type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py b/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..7e78c928accbb1cf0733fd978b108a437db7a1fd --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
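The `param_scheduler` used throughout these configs pairs a linear warmup over the first `warmup_ratio` fraction of training with a cosine decay to zero, both converted to iteration-based stepping. A rough stand-alone approximation of that shape, not mmengine's actual scheduler implementation:

import math

def lr_at(step, total_steps, base_lr=1e-3, warmup_ratio=0.03, start_factor=1e-5):
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # LinearLR: ramp the multiplier from start_factor up to 1.0
        factor = start_factor + (1.0 - start_factor) * step / max(warmup_steps, 1)
        return base_lr * factor
    # CosineAnnealingLR: decay from base_lr down to eta_min = 0.0
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

print(lr_at(0, 10_000), lr_at(300, 10_000), lr_at(10_000, 10_000))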
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + + + + + +omgseg_model = dict( + type=OMGSegVisualEncoder, + 
data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..65321b73fe90a8c51ad3ff102a1be01d534a5e17 --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py @@ -0,0 +1,375 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) 
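Each of these configs loads the LLM with 4-bit NF4 quantisation and fp16 compute via the `model.llm` block. For orientation, a minimal transformers-only sketch of the same quantised load; the checkpoint path is the one assumed by this file and should be adjusted locally:

# Illustrative sketch, not taken from the repo: direct 4-bit load with transformers.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

llm = AutoModelForCausalLM.from_pretrained(
    './pretrained/omg_llava/internlm2-chat-7b',  # path used in this config
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=quant_cfg)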
+weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + 
reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test 
setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..33f6780bb24f9396ec28597c91dc2dd445dd265e --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
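# Annotation, not part of the original diff: the config below is the ConvNeXt-XXL variant of
# omg_llava_7b_pretrain_1024image_8gpus.py above. Judging from this diff, the differences are the
# backbone (model_name='convnext_xxlarge' with the 'laion2b_s34b_b82k_augreg_soup' checkpoint),
# the panoptic-head input channels ([384, 768, 1536, 3072]), the XXL class-embedding and head
# weights (convnext_xxlarge_CocoPanopticOVDataset.pth, omg_seg_convxxl.pth), clip_feat_channel=3072
# on the OMG_LLaVA model, and accumulative_counts=4, i.e. an effective batch of 16 * 4 * 8 = 512
# versus 256 above, assuming 8 GPUs as in the file name.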
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + 
data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_xxlarge', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s34b_b82k_augreg_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + clip_feat_channel=3072, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/dataset/CombineDataset.py b/omg_llava/dataset/CombineDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db46456e27ffd52a66d36cb1ec0536af60b19657 --- /dev/null +++ b/omg_llava/dataset/CombineDataset.py @@ -0,0 +1,81 @@ +from torch.utils.data import Dataset +import numpy as np + +class CombineDataset(Dataset): + def __init__(self, + datasets_cfgs, + ): + super().__init__() + + self.datasets = [] + self.datasets_length = [] + + self.tokenizer = datasets_cfgs[0].tokenizer + tokenizer_type = self.tokenizer['type'] + del self.tokenizer['type'] + self.tokenizer = tokenizer_type(**self.tokenizer) + + self._add_special_tokens() + + for i in range(len(datasets_cfgs)): + datasets_cfgs[i].tokenizer = self.tokenizer + + for dataset_cfg in datasets_cfgs: + dataset = dataset_cfg['type'] + del dataset_cfg['type'] + dataset = dataset(**dataset_cfg) + self.datasets.append(dataset) + self.datasets_length.append(len(dataset)) + + self.dataset_threthold = [] + for i, length in enumerate(self.datasets_length): + if i == 0: + self.dataset_threthold.append(length) + else: + self.dataset_threthold.append(length + self.dataset_threthold[i - 1]) + + np.random.seed(42) + self.shuffled_index = np.arange(self.dataset_threthold[-1]) + np.random.shuffle(self.shuffled_index) + + @property + def modality_length(self): + length_list = [] + for dataset in self.datasets: + for data_dict in dataset.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + return length_list + + def __len__(self): + return self.dataset_threthold[-1] + + def __getitem__(self, index): + index = int(self.shuffled_index[index]) + for i, thred in enumerate(self.dataset_threthold): + if index < thred: + break + + + if i == 0: + _index = index + else: + _index = index - self.dataset_threthold[i - 1] + + return self.datasets[i][_index] + + def _add_special_tokens(self): + assert hasattr(self, "tokenizer") + # Adding special tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['
<p>', '</p>
'] + # add for visual prompt + region_tokens = [''] + point_tokens = [''] + special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens + + self.tokenizer.add_tokens(special_tokens, special_tokens=True) + return \ No newline at end of file diff --git a/omg_llava/dataset/DecoupledGCGDataset.py b/omg_llava/dataset/DecoupledGCGDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f046aef9bf14f092412e49797dc41a3c2cfdb1f2 --- /dev/null +++ b/omg_llava/dataset/DecoupledGCGDataset.py @@ -0,0 +1,381 @@ +import json +import logging +import os + +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset +from pycocotools import mask +import numpy as np +import torch.nn.functional as F +import copy + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset + +class DecoupledGCGDataset(Dataset): + + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1, + mode='given_description'): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_data = self.json_file_preprocess(data_path) + json_data = DatasetDict({'train': HFDataset.from_list(json_data)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + self.mode = mode + + def json_file_preprocess(self, data_path): + with open(data_path, 'r') as f: + json_data = json.load(f) + + # for quickly debug with mini split + if self.debug: + json_data = json_data[:100] + return json_data + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def __len__(self): + return len(self.text_data) * self.repeats + + 
def real_len(self): + return len(self.text_data) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + for seg in object_mask: + rles = mask.frPyObjects([seg], ori_height, ori_width) + m = mask.decode(rles) + m = m.astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = Image.open(os.path.join(self.image_folder, + image_file)).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width) + + assert self.mode in ['given_objects', 'given_description'] + if self.mode == 'given_objects': + data_dict['regions'] = copy.deepcopy(data_dict['masks']) + + # if data_dict['masks'] is None: + # return self.__getitem__(0) + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + +class DecoupledRefCOCOgGCGDataset(DecoupledGCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + debug=False, + repeats=1, + mode='given_description', + ): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + debug=debug, + repeats=repeats, + mode=mode, + ) + + def json_file_preprocess(self, data_path): + json_data = json.load(open(data_path)) + if self.debug: + json_data = json_data[:100] + + # convert {id: dict} to dict(..., id=xx) + for idx in range(len(json_data)): + id = list(json_data[idx].keys())[0] + json_data[idx] = json_data[idx][id] + json_data[idx].update({'id': id}) + return json_data + +class DecoupledGranDfGCGDataset(DecoupledGCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1, + mode='given_description'): + super().__init__( + 
image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + mode=mode + ) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + + for rle in object_mask: + m = mask.decode(rle).astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + +class DecoupledOpenPsgGCGDataset(DecoupledGranDfGCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1, + mode='given_description'): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + mode=mode + ) + +class DecoupledFlickrGCGDataset(DecoupledGCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1, + mode='given_description' + ): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + mode=mode + ) + + def json_file_preprocess(self, data_path): + def filter_images(data_infos, min_size): + return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size] + + # convert {id: dict} to dict(..., id=xx) + from pycocotools.coco import COCO + self.coco = COCO(data_path) + self.image_ids = self.coco.getImgIds() + data_infos = [] + total_ann_ids = [] + removed_img_count = 0 + for img_id in self.image_ids: + info = self.coco.loadImgs([img_id])[0] + if len(info['caption'].split(' ')) < 3: + removed_img_count += 1 + continue + info['filename'] = info['file_name'].split('_')[-1] + info['height'] = int(info['height']) + info['width'] = int(info['width']) + data_infos.append(info) + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + total_ann_ids.extend(ann_ids) + assert 
len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!" + print(f'Removed {removed_img_count} images.') + data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)] + + # obtain_annotations + for data_info in data_infos: + ann_ids = self.coco.getAnnIds(imgIds=data_info['id']) + ann_info = self.coco.loadAnns(ann_ids) + data_info.update({'ann_info': ann_info}) + if self.debug: + data_infos = data_infos[:32] + return data_infos + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = mask.decode(object_mask).astype(np.uint8) + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks \ No newline at end of file diff --git a/omg_llava/dataset/GCGDataset.py b/omg_llava/dataset/GCGDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f04ead5dcbb672a831e4b5f985ab03611f81ec33 --- /dev/null +++ b/omg_llava/dataset/GCGDataset.py @@ -0,0 +1,364 @@ +import json +import logging +import os + +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset +from pycocotools import mask +import numpy as np +import torch.nn.functional as F +import copy + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset + +class GCGDataset(Dataset): + + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_data = self.json_file_preprocess(data_path) + json_data = DatasetDict({'train': HFDataset.from_list(json_data)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + 
else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def json_file_preprocess(self, data_path): + with open(data_path, 'r') as f: + json_data = json.load(f) + + # for quickly debug with mini split + if self.debug: + json_data = json_data[:100] + return json_data + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def __len__(self): + return len(self.text_data) * self.repeats + + def real_len(self): + return len(self.text_data) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + for seg in object_mask: + rles = mask.frPyObjects([seg], ori_height, ori_width) + m = mask.decode(rles) + m = m.astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = Image.open(os.path.join(self.image_folder, + image_file)).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width) + if data_dict['masks'] is None: + return self.__getitem__(0) + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + +class RefCOCOgGCGDataset(GCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + debug=False, + repeats=1,): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + debug=debug, + repeats=repeats, + ) + + def json_file_preprocess(self, data_path): + json_data = json.load(open(data_path)) + if self.debug: + json_data = json_data[:100] + + # convert {id: dict} to dict(..., id=xx) + for idx in range(len(json_data)): + id = list(json_data[idx].keys())[0] + 
json_data[idx] = json_data[idx][id] + json_data[idx].update({'id': id}) + return json_data + +class GranDfGCGDataset(GCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + ) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + + for rle in object_mask: + m = mask.decode(rle).astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + +class OpenPsgGCGDataset(GranDfGCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + ) + +class FlickrGCGDataset(GCGDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=4, + debug=False, + repeats=1,): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + ) + + def json_file_preprocess(self, data_path): + def filter_images(data_infos, min_size): + return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size] + + # convert {id: dict} to dict(..., id=xx) + from pycocotools.coco import COCO + self.coco = COCO(data_path) + self.image_ids = self.coco.getImgIds() + data_infos = [] + total_ann_ids = [] + removed_img_count = 0 + for img_id in self.image_ids: + info = self.coco.loadImgs([img_id])[0] + if 
len(info['caption'].split(' ')) < 3: + removed_img_count += 1 + continue + info['filename'] = info['file_name'].split('_')[-1] + info['height'] = int(info['height']) + info['width'] = int(info['width']) + data_infos.append(info) + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + total_ann_ids.extend(ann_ids) + assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!" + print(f'Removed {removed_img_count} images.') + data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)] + + # obtain_annotations + for data_info in data_infos: + ann_ids = self.coco.getAnnIds(imgIds=data_info['id']) + ann_info = self.coco.loadAnns(ann_ids) + data_info.update({'ann_info': ann_info}) + if self.debug: + data_infos = data_infos[:32] + return data_infos + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = mask.decode(object_mask).astype(np.uint8) + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks \ No newline at end of file diff --git a/omg_llava/dataset/LlavaDataset.py b/omg_llava/dataset/LlavaDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8158187e94a31fa58df07046a11e08988fa9c0 --- /dev/null +++ b/omg_llava/dataset/LlavaDataset.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import logging +import os + +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset + +from xtuner.registry import BUILDER +from xtuner.dataset.huggingface import process_hf_dataset +from .utils import expand2square +import copy + +class LLaVADataset(Dataset): + + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + debug=False): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + + self.tokenizer = tokenizer + if isinstance(tokenizer, dict) or isinstance( + tokenizer, Config) or isinstance(tokenizer, ConfigDict): + tokenizer_type = self.tokenizer['type'] + del self.tokenizer['type'] + self.tokenizer = tokenizer_type(**self.tokenizer) + self._add_special_tokens() + + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + self.text_data = load_from_disk(offline_processed_text_folder) + else: + json_data = json.load(open(data_path)) + if debug: + json_data = json_data[:10000] + for idx in range(len(json_data)): + if isinstance(json_data[idx]['id'], int): + json_data[idx]['id'] = str(json_data[idx]['id']) + json_data = DatasetDict({'train': HFDataset.from_list(json_data)}) + self.text_data = process_hf_dataset( + 
dataset=json_data, + tokenizer=self.tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=32, # because limited mem + ) + + self.image_folder = image_folder + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + return length_list + + def __len__(self): + return len(self.text_data) + + def __getitem__(self, index): + data_dict = copy.deepcopy(self.text_data[index]) + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = Image.open(os.path.join(self.image_folder, + image_file)).convert('RGB') + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + return data_dict + + def _add_special_tokens(self): + assert hasattr(self, "tokenizer") + # Adding special tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['
<p>', '</p>
'] + # add for visual prompt + region_tokens = [''] + point_tokens = [''] + special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens + self.tokenizer.add_tokens(special_tokens, special_tokens=True) + return \ No newline at end of file diff --git a/omg_llava/dataset/MDPVPointsDataset.py b/omg_llava/dataset/MDPVPointsDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fe146d75fa08473ae1d552303ec995ca08475acb --- /dev/null +++ b/omg_llava/dataset/MDPVPointsDataset.py @@ -0,0 +1,220 @@ +import json +import logging +import os +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset +from pycocotools import mask +import numpy as np +import torch.nn.functional as F + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask, expand2square_points +from xtuner.dataset.huggingface import process_hf_dataset +import copy + +class MDPVPointDetailedCaptionDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_data = self.json_file_preprocess(data_path) + self.json_data = json_data + hf_json_data = self.filter_hf_require_infos(json_data) + hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)}) + self.text_data = process_hf_dataset( + dataset=hf_json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def filter_hf_require_infos(self, dataset_infos): + ret = [] + for dataset_info in dataset_infos: + conversations = dataset_info["conversations"] + image = dataset_info['image'].split('/')[-1] + num_marks = len(dataset_info['points']) + required_info = {'image': image, + 'conversations': conversations, + 'num_marks': num_marks} + ret.append(required_info) + return ret + + def json_file_preprocess(self, data_path): + with open(data_path, 'r') as f: + json_file = json.load(f) + if self.debug: + json_file 
= json_file[:10000] + return json_file + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def __len__(self): + return len(self.text_data) * self.repeats + + def real_len(self): + return len(self.text_data) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + for seg in object_mask: + rles = mask.frPyObjects([seg], ori_height, ori_width) + m = mask.decode(rles) + m = m.astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.json_data[index]) + data_dict.update(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_path = os.path.join(self.image_folder, image_file) + if not os.path.exists(image_path) and "VG" in self.image_folder: + image_path = os.path.join(self.image_folder + "_2", image_file) + image = Image.open(image_path).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + points = data_dict["points"] + points = np.array(points) + if self.pad_image_to_square: + points = expand2square_points(points, height=ori_height, width=ori_width) + points[:, 0] = points[:, 0] / max(ori_height, ori_width) * self.image_w + points[:, 1] = points[:, 1] / max(ori_height, ori_width) * self.image_h + else: + points[:, 0] = points[:, 0] / ori_width * self.image_w + points[:, 1] = points[:, 1] / ori_height * self.image_h + data_dict['points'] = torch.from_numpy(points) + if data_dict['points'] is None: + return self.__getitem__(0) + data_dict['masks'] = None + data_dict['regions'] = None + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + data_dict['regions'] = None + data_dict['points'] = None + return data_dict + +class MDPVPointBriefCaptionDataset(MDPVPointDetailedCaptionDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + 
dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats + ) diff --git a/omg_llava/dataset/ReferringSegDataset.py b/omg_llava/dataset/ReferringSegDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..df83acb2cab8cc52e271eaa5c8990feb1631f957 --- /dev/null +++ b/omg_llava/dataset/ReferringSegDataset.py @@ -0,0 +1,380 @@ +import logging +import os +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset +from pycocotools import mask +import numpy as np +import torch.nn.functional as F + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset +from omg_llava.dataset.utils.refcoco_refer import REFER +import copy + +class RefcocoReferringSegDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1,): + self._set_attribute() + self.tokenizer = tokenizer + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_datas = self.json_file_preprocess(data_path) + self.json_datas = json_datas + json_datas = self.only_get_hf_map_infos() + json_data = DatasetDict({'train': HFDataset.from_list(json_datas)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def _set_attribute(self): + self.splitBy = "unc" + self.dataset_name = 'refcoco' + + def only_get_hf_map_infos(self): + ret = [] + for json_data in self.json_datas: + ret.append({'sampled_sents': json_data['selected_labels']}) + return ret + + def __len__(self): + return len(self.text_data) * self.repeats + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + 
length_list = length_list * self.repeats + return length_list + + def real_len(self): + return len(self.text_data) + + def json_file_preprocess(self, data_path): + splitBy = self.splitBy + dataset_name = self.dataset_name + refer_api = REFER(data_path, dataset_name, splitBy) + ref_ids_train = refer_api.getRefIds(split='train') + images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train) + refs_train = refer_api.loadRefs(ref_ids=ref_ids_train) + self.img2refs = self.create_img_to_refs_mapping(refs_train) + + image_infos = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + for item in loaded_images: + item = item.copy() + image_infos.append(item) + + self.annotations = refer_api.Anns + + refs = [self.img2refs[image_info['id']] for image_info in image_infos] + + ret = [] + for image_info, ref in zip(image_infos, refs): + if len(ref) == 0: + continue + + sents = [] + ann_ids = [] + for _ref in ref: + for sent in _ref["sentences"]: + text = sent["sent"] + sents.append(text) + ann_ids.append(_ref["ann_id"]) + if len(sents) >= 3: + sampled_inds = np.random.choice( + list(range(len(sents))), size=3, replace=False + ) + else: + sampled_inds = list(range(len(sents))) + sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() + sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds] + selected_labels = sampled_sents + ret.append( + {'image_info': image_info, + 'sampled_ann_id': sampled_ann_ids, + 'selected_labels': selected_labels, + 'image': image_info['file_name'] + } + ) + if self.debug: + return ret[:1000] + return ret + + def create_img_to_refs_mapping(self, refs_train): + img2refs = {} + for ref in refs_train: + img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ] + return img2refs + + def decode_mask(self, annotations_ids, image_info): + flag = False + masks = [] + + for ann_id in annotations_ids: + if isinstance(ann_id, list): + flag = True + if -1 in ann_id: + assert len(ann_id) == 1 + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + else: + m_final = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + for ann_id_i in ann_id: + ann = self.annotations[ann_id_i] + + if len(ann["segmentation"]) == 0: + m = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + else: + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"], ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + m_final = m_final | m + m = m_final + masks.append(m) + continue + + ann = self.annotations[ann_id] + + if len(ann["segmentation"]) == 0: + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + masks.append(m) + continue + + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + 
masks.append(m) + masks = np.stack(masks, axis=0) + + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.text_data[index]) + data_dict.update(self.json_datas[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_file = os.path.join(self.image_folder, image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info']) + data_dict['masks'] = masks + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + +class Refcoco_plus_ReferringSegDataset(RefcocoReferringSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1,): + + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats,) + + def _set_attribute(self): + self.splitBy = "unc" + self.dataset_name = 'refcoco+' + +class Refcocog_ReferringSegDataset(RefcocoReferringSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1,): + + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + ) + + def _set_attribute(self): + self.splitBy = "umd" + self.dataset_name = 'refcocog' + +class Refclef_ReferringSegDataset(RefcocoReferringSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1,): + + super().__init__( + 
image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + ) + + def _set_attribute(self): + self.splitBy = "unc" + self.dataset_name = 'refclef' diff --git a/omg_llava/dataset/RegionCaptionDataset.py b/omg_llava/dataset/RegionCaptionDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b4e870596dcfeb4526a35bc4fa5c9f1465e512 --- /dev/null +++ b/omg_llava/dataset/RegionCaptionDataset.py @@ -0,0 +1,356 @@ +import json +import logging +import os +import copy + +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image, ImageDraw +from torch.utils.data import Dataset +from pycocotools import mask +import numpy as np +import torch.nn.functional as F + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset + +class OspreyRegionCaptionDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_data = self.json_file_preprocess(data_path) + self.json_data = json_data + hf_json_data = self.filter_hf_require_infos(json_data) + hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)}) + self.text_data = process_hf_dataset( + dataset=hf_json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def filter_hf_require_infos(self, dataset_infos): + ret = [] + for dataset_info in dataset_infos: + description = dataset_info["description"] + image = dataset_info['file_name'] + required_info = {'image': image, 'description': description} + ret.append(required_info) + return ret 
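For reference, the decode_mask helpers in ReferringSegDataset and RegionCaptionDataset follow the same pycocotools recipe: convert each COCO-style polygon to RLE with frPyObjects, then decode it into a binary uint8 mask. A minimal, self-contained sketch of that pattern (the helper name, polygon coordinates, and image size below are made up for illustration):

import numpy as np
from pycocotools import mask as mask_utils

def polygon_to_binary_mask(polygon, height, width):
    # polygon is a flat [x0, y0, x1, y1, ...] list in absolute pixel coordinates
    rles = mask_utils.frPyObjects([polygon], height, width)  # polygon -> RLE
    m = mask_utils.decode(rles)                               # (H, W, 1) uint8 array
    return m.squeeze(-1).astype(np.uint8)

demo_mask = polygon_to_binary_mask([10.0, 10.0, 200.0, 10.0, 100.0, 150.0], 480, 640)
print(demo_mask.shape, int(demo_mask.sum()))  # (480, 640) and the foreground pixel count

On top of this, the dataset classes merge multi-part segmentations with np.sum(..., axis=2), optionally pad the masks to a square with expand2square_mask, and resize them with F.interpolate to the processor's crop size.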
+ + def json_file_preprocess(self, data_path): + with open(data_path, 'r') as f: + json_file = json.load(f) + + ret = [] + for item in json_file: + if len(item["description"]) != len(item["annotation"]): + print("The number of description is not equal to seg !!!") + else: + ret.append(item) + + if self.debug: + ret = ret[:10000] + return ret + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def __len__(self): + return len(self.text_data) * self.repeats + + def real_len(self): + return len(self.text_data) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + for seg in object_mask: + rles = mask.frPyObjects([seg], ori_height, ori_width) + m = mask.decode(rles) + m = m.astype(np.uint8) + binary_mask += m.squeeze() + + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.json_data[index]) + data_dict.update(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = Image.open(os.path.join(self.image_folder, + image_file)).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + annotations = data_dict['annotation'] + sampled_inds = data_dict['sampled_inds'] + annotations = [annotations[idx]['segmentation'] for idx in sampled_inds] + data_dict['regions'] = self.decode_mask(annotations, ori_height=ori_height, ori_width=ori_width) + + if data_dict['regions'] is None or len(data_dict['regions']) != len(sampled_inds): + print("Bad data item !!!") + return self.__getitem__(0) + seg_region_idx = data_dict['seg_region_idx'] + if len(seg_region_idx) == 0: + data_dict['masks'] = None + else: + data_dict['masks'] = data_dict['regions'][seg_region_idx] + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + data_dict['regions'] = None + return data_dict + +class OspreyRegionConversationDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=32, + debug=False, + repeats=1): + super().__init__() + + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 
'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_data = self.json_file_preprocess(data_path) + self.json_data = json_data + hf_json_data = self.filter_hf_require_infos(json_data) + hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)}) + self.text_data = process_hf_dataset( + dataset=hf_json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def filter_hf_require_infos(self, dataset_infos): + ret = [] + for dataset_info in dataset_infos: + conversations = dataset_info["conversations"] + image = dataset_info['file_name'] + num_regions = len(dataset_info['annotation']) + required_info = {'image': image, 'conversations': conversations, + 'num_regions': num_regions} + ret.append(required_info) + return ret + + def json_file_preprocess(self, data_path): + with open(data_path, 'r') as f: + json_file = json.load(f) + + # filter + ret = [] + for dataset_info in json_file: + if 'annotation' not in dataset_info or len(dataset_info['annotation']) == 0: + print("The annotation is not valid, filter out!!!") + continue + ret.append(dataset_info) + + if self.debug: + ret = ret[:10000] + return ret + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def __len__(self): + return len(self.text_data) * self.repeats + + def real_len(self): + return len(self.text_data) + + def decode_mask(self, object_masks, ori_height, ori_width): + binary_masks = [] + for object_mask in object_masks: + binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8) + for seg in object_mask: + rles = mask.frPyObjects([seg], ori_height, ori_width) + m = mask.decode(rles) + m = m.astype(np.uint8) + binary_mask += m.squeeze() + binary_masks.append(binary_mask) + if len(binary_masks) == 0: + return None + masks = np.stack(binary_masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.json_data[index]) + data_dict.update(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = 
Image.open(os.path.join(self.image_folder, + image_file)).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + annotations = data_dict['annotation'] + annotations = [annotations[idx]['segmentation'] for idx in range(len(annotations))] + data_dict['regions'] = self.decode_mask(annotations, ori_height=ori_height, ori_width=ori_width) + if data_dict['regions'] is None: + return self.__getitem__(0) + data_dict['masks'] = None + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + data_dict['regions'] = None + return data_dict \ No newline at end of file diff --git a/omg_llava/dataset/SemanticSegDataset.py b/omg_llava/dataset/SemanticSegDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..490f14d57818e34ac4bac06b287119515a5e0ead --- /dev/null +++ b/omg_llava/dataset/SemanticSegDataset.py @@ -0,0 +1,725 @@ +import random +import glob +import json +import logging +import os +import torch +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from mmengine.config import Config, ConfigDict +from PIL import Image +from torch.utils.data import Dataset +import numpy as np +import torch.nn.functional as F +from pycocotools.coco import COCO + +from xtuner.registry import BUILDER +from omg_llava.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset +from omg_llava.dataset.process_functions.semantic_seg_process import semantic_seg_conversations, semantic_seg_gcg_format_conversations +import copy + +class SemanticSegDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1, + gcg_format=False): + super().__init__() + self.tokenizer = tokenizer + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + self.image_label_datas = self.json_file_preprocess(data_path, image_folder) + if gcg_format: + conversations_datas = semantic_seg_gcg_format_conversations(self.classes) + else: + conversations_datas = semantic_seg_conversations(self.classes) + json_data = DatasetDict({'train': HFDataset.from_list(conversations_datas)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + 
with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.clsid2convs = self.construct_cls2convs_dict() + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def construct_cls2convs_dict(self): + ret = {} + for conv_item in self.text_data: + cls_id = conv_item['class_id'] + if cls_id in ret.keys(): + ret[cls_id].append(conv_item) + else: + ret[cls_id] = [conv_item] + return ret + + def json_file_preprocess(self, data_path, image_folder): + # ade20k + with open(data_path, 'r') as file: + ade20k_classes = json.load(file) + ade20k_image_dir = image_folder + ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if + img.endswith('.jpg')] + ade20k_labels = [img.replace(".jpg", ".png").replace("images", "annotations") for img in ade20k_images] + self.classes = np.array(ade20k_classes) + + ret = [] + for image, label in zip(ade20k_images, ade20k_labels): + ret.append({"image": image, "label": label}) + if self.debug: + return ret[:1000] + return ret + + def __len__(self): + return len(self.image_label_datas) * self.repeats + + @property + def modality_length(self): + length_list = [] + for data_dict in self.image_label_datas: + length_list.append(-100) + length_list = length_list * self.repeats + return length_list + + def real_len(self): + return len(self.image_label_datas) + + def decode_mask(self, label_path): + label = np.array(Image.open(label_path)) + + # ade 20k + label = np.where(label == 0, 255, label - 1) + unique_labels = [lbl for lbl in np.unique(label) if lbl != 255] + if not unique_labels: + return None, None + + # only choose 1 + selected_labels = np.random.choice( + unique_labels, 1, replace=False + ) + label = torch.from_numpy(label).long() + masks = torch.stack([label == class_id for class_id in selected_labels], dim=0) + + masks = masks.numpy() + if self.pad_image_to_square: + masks = expand2square_mask(masks) + + masks = torch.from_numpy(masks).to(torch.float32) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks, selected_labels[0] + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.image_label_datas[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + data_dict['masks'], class_id = self.decode_mask(data_dict['label']) + if class_id is None: + return self.__getitem__(0) + conv_datas = self.clsid2convs[class_id] + selected_idx = np.random.randint(0, len(conv_datas)) + data_dict.update(conv_datas[selected_idx]) + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = 
self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + +class ADE20kSemanticSegDataset(SemanticSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1, + gcg_format=False): + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + gcg_format=gcg_format, + ) + +class COCOStuffSemanticSegDataset(SemanticSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1, + label_path=None, + gcg_format=False,): + self.label_path = label_path + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + gcg_format=gcg_format, + ) + self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)} + + def json_file_preprocess(self, data_path, image_folder): + # coco stuff + assert self.label_path is not None + with open(data_path, 'r') as file: + cocostuff_classes = [line.strip().split(": ")[-1] for line in file.readlines()[1:]] + coco_stuff_image_dir = image_folder + coco_stuff_label_dir = self.label_path + coco_stuff_labels = glob.glob(os.path.join(coco_stuff_label_dir, "*.png")) + + coco_stuff_images = [label.replace(".png", ".jpg").replace(coco_stuff_label_dir, coco_stuff_image_dir) + for label in coco_stuff_labels] + + self.classes = np.array(cocostuff_classes) + + ret = [] + for image, label in zip(coco_stuff_images, coco_stuff_labels): + ret.append({"image": image, "label": label}) + if self.debug: + return ret[:1000] + return ret + + def decode_mask(self, label_path): + label = np.array(Image.open(label_path)) + + # coco stuff + ignored_classes = [index for class_name, index in self.cocostuff_class2index.items() if + "-" in class_name] + label = np.where(np.isin(label, ignored_classes), 255, label) + + unique_labels = [lbl for lbl in np.unique(label) if lbl != 255] + if not unique_labels: + print("No valid label !!!") + return None, None + + # only choose 1 + selected_labels = np.random.choice( + unique_labels, 1, replace=False + ) + label = torch.from_numpy(label).long() + masks = torch.stack([label == class_id for class_id in selected_labels], dim=0) + + masks = masks.numpy() + if self.pad_image_to_square: + masks = expand2square_mask(masks) + + masks = torch.from_numpy(masks).to(torch.float32) + masks = 
F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks, selected_labels[0] + +class MapillarySemanticSegDataset(SemanticSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1, + label_path=None, + gcg_format=False,): + self.label_path = label_path + super().__init__( + image_folder=image_folder, + image_processor=image_processor, + data_path=data_path, + tokenizer=tokenizer, + offline_processed_text_folder=offline_processed_text_folder, + max_dataset_length=max_dataset_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + max_length=max_length, + pad_image_to_square=pad_image_to_square, + num_proc=num_proc, + debug=debug, + repeats=repeats, + gcg_format=gcg_format, + ) + + def json_file_preprocess(self, data_path, image_folder): + assert self.label_path is not None + # mapillary + with open(data_path, 'r') as file: + mapillary_classes = json.load(file)["labels"] + mapillary_classes = [cls["readable"].lower() for cls in mapillary_classes] + + mapillary_labels = sorted( + glob.glob(os.path.join(self.label_path, "*.png"))) + mapillary_images = [ + label.replace(".png", ".jpg").replace(self.label_path, image_folder) + for label in mapillary_labels] + + self.classes = np.array(mapillary_classes) + + ret = [] + for image, label in zip(mapillary_images, mapillary_labels): + ret.append({"image": image, "label": label}) + if self.debug: + return ret[:1000] + return ret + + def decode_mask(self, label_path): + label = np.array(Image.open(label_path)) + + ignored_classes = [index for index, class_name in enumerate(self.classes) if + "-" in class_name or '(' in class_name or + 'unlabeled' in class_name] + label = np.where(np.isin(label, ignored_classes), 255, label) + unique_labels = [lbl for lbl in np.unique(label) if lbl != 255] + if not unique_labels: + print("No valid label !!!") + return None, None + # only choose 1 + selected_labels = np.random.choice( + unique_labels, 1, replace=False + ) + label = torch.from_numpy(label).long() + masks = torch.stack([label == class_id for class_id in selected_labels], dim=0) + + masks = masks.numpy() + if self.pad_image_to_square: + masks = expand2square_mask(masks) + + masks = torch.from_numpy(masks).to(torch.float32) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks, selected_labels[0] + +class PascalPartSemanticSegDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1): + super().__init__() + self.tokenizer = tokenizer + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + 
raise NotImplementedError + else: + json_datas = self.json_file_preprocess(data_path) + json_data = DatasetDict({'train': HFDataset.from_list(json_datas)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def json_file_preprocess(self, data_path): + pascal_part_api = COCO(data_path) + all_classes = pascal_part_api.loadCats(pascal_part_api.getCatIds()) + class_map_pascal_part = {} + for cat in all_classes: + cat_main, cat_part = cat["name"].strip().split(":") + name = (cat_main, cat_part) + class_map_pascal_part[cat["id"]] = name + img_ids = pascal_part_api.getImgIds() + self.classes = class_map_pascal_part + self.coco_api = pascal_part_api + + img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids] + valid_img_infos = [] + for img_info in img_infos: + annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"]) + annotations = self.coco_api.loadAnns(annotation_ids) + if not annotations: + continue + + # sampled to max number as 5 + sampled_anns = np.random.choice(annotations, 5, replace=False) if len( + annotations + ) >= 5 else annotations + + selected_labels = [] + for ann in sampled_anns: + category_id = ann["category_id"] + sampled_cls = self.classes[category_id] + if isinstance(sampled_cls, tuple): + obj, part = sampled_cls + name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}" + else: + name = sampled_cls + selected_labels.append(name) + + img_info.update({"annotations": sampled_anns, + "selected_labels": selected_labels}) + valid_img_infos.append(img_info) + + if self.debug: + return valid_img_infos[:1000] + return valid_img_infos + + def __len__(self): + return len(self.text_data) * self.repeats + + @property + def modality_length(self): + length_list = [] + for data_dict in self.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + length_list = length_list * self.repeats + return length_list + + def real_len(self): + return len(self.text_data) + + def decode_mask(self, annotations): + + try: + masks = [self.coco_api.annToMask(ann) for ann in annotations] + except Exception as e: + print(f"Error generating mask: {e}") + return None + + masks = np.stack(masks, axis=0) + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + return masks + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.text_data[index]) + + if data_dict.get('image', None) is not None: + image_file = 
data_dict['image'] + image_file = os.path.join(self.image_folder, image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + data_dict['masks'] = self.decode_mask(data_dict['annotations']) + if data_dict['masks'] is None: + return self.__getitem__(0) + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + +class PacoSemanticSegDataset(PascalPartSemanticSegDataset): + def __init__(self, + image_folder, + image_processor, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + max_dataset_length=None, + dataset_map_fn=None, + template_map_fn=None, + max_length=2048, + pad_image_to_square=False, + num_proc=8, + debug=False, + repeats=1,): + self.tokenizer = tokenizer + assert offline_processed_text_folder or (data_path and tokenizer) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_datas = self.json_file_preprocess(data_path) + self.json_datas = json_datas + json_datas = self.only_get_hf_map_infos() + json_data = DatasetDict({'train': HFDataset.from_list(json_datas)}) + self.text_data = process_hf_dataset( + dataset=json_data, + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + split='train', + max_dataset_length=max_dataset_length, + remove_unused_columns=False, + pack_to_max_length=False, + with_image_token=True, + map_num_proc=num_proc, # because limited mem + ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def only_get_hf_map_infos(self): + ret = [] + for json_data in self.json_datas: + ret.append({'file_name': json_data['file_name'], + 'selected_labels': json_data['selected_labels']}) + return ret + + def json_file_preprocess(self, data_path): + paco_api = COCO(data_path) + all_classes = paco_api.loadCats(paco_api.getCatIds()) + class_map_paco = {} + for cat in all_classes: + cat_split = cat["name"].strip().split(":") + if len(cat_split) == 1: + name = cat_split[0].split("_(")[0] + else: + assert len(cat_split) == 2 + obj, part = cat_split + obj = obj.split("_(")[0] + part = part.split("_(")[0] + name = (obj, part) + class_map_paco[cat["id"]] = name + + img_ids = paco_api.getImgIds() + self.classes = class_map_paco + self.coco_api = paco_api + + 
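+        # Collect per-image info below: keep at most five randomly sampled part annotations per image and phrase each as "<obj> <part>" or "the <part> of the <obj>" for the prompt text.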
img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids] + valid_img_infos = [] + for img_info in img_infos: + annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"]) + annotations = self.coco_api.loadAnns(annotation_ids) + if not annotations: + continue + + # sampled to max number as 5 + sampled_anns = np.random.choice(annotations, 5, replace=False) if len( + annotations + ) >= 5 else annotations + + selected_labels = [] + for ann in sampled_anns: + category_id = ann["category_id"] + sampled_cls = self.classes[category_id] + if isinstance(sampled_cls, tuple): + obj, part = sampled_cls + name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}" + else: + name = sampled_cls + selected_labels.append(name) + + img_info.update({"annotations": sampled_anns, + "selected_labels": selected_labels}) + valid_img_infos.append(img_info) + + if self.debug: + return valid_img_infos[:1000] + return valid_img_infos + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = copy.deepcopy(self.text_data[index]) + data_dict.update(self.json_datas[index]) + + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_file = os.path.join(self.image_folder, image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + data_dict['masks'] = self.decode_mask(data_dict['annotations']) + if data_dict['masks'] is None: + return self.__getitem__(0) + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + return data_dict + diff --git a/omg_llava/dataset/__init__.py b/omg_llava/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c30944a6f4c14880be8a354d3ea01d8987cd3397 --- /dev/null +++ b/omg_llava/dataset/__init__.py @@ -0,0 +1,29 @@ +from .CombineDataset import CombineDataset +from .GCGDataset import RefCOCOgGCGDataset, OpenPsgGCGDataset, GranDfGCGDataset, FlickrGCGDataset +from .SemanticSegDataset import SemanticSegDataset, ADE20kSemanticSegDataset,\ + COCOStuffSemanticSegDataset,MapillarySemanticSegDataset, PascalPartSemanticSegDataset,\ + PacoSemanticSegDataset +from .MDPVPointsDataset import MDPVPointDetailedCaptionDataset, MDPVPointBriefCaptionDataset +from .ReferringSegDataset import RefcocoReferringSegDataset, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset +from .RegionCaptionDataset import OspreyRegionCaptionDataset, OspreyRegionConversationDataset +from .LlavaDataset import LLaVADataset +from .DecoupledGCGDataset import DecoupledRefCOCOgGCGDataset, DecoupledOpenPsgGCGDataset,\ + DecoupledGranDfGCGDataset, DecoupledFlickrGCGDataset + + +from .process_functions import glamm_openpsg_map_fn, glamm_refcocog_map_fn,\ + glamm_granf_map_fn, glamm_flickr_map_fn,\ + semantic_seg_map_fn, pascal_part_map_fn,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_map_fn, referring_seg_gcg_format_map_fn,\ + osprey_region_caption_map_fn, osprey_region_caption_gcg_format_map_fn,\ + 
osprey_region_conversation_map_fn,\ + mdpv_points_map_fn + +from .process_functions import glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_description_map_fn, glamm_granf_decoupled_given_objects_map_fn,\ + glamm_flickr_decoupled_given_description_map_fn, glamm_flickr_decoupled_given_objects_map_fn,\ + glamm_openpsg_decoupled_given_objects_map_fn, glamm_openpsg_decoupled_given_description_map_fn + +from .collect_fns import omg_llava_collate_fn \ No newline at end of file diff --git a/omg_llava/dataset/__pycache__/CombineDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/CombineDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..602c2ab18ac02d169a8109617feac6518d732578 Binary files /dev/null and b/omg_llava/dataset/__pycache__/CombineDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/DecoupledGCGDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/DecoupledGCGDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a023792ea22e61af6d4de64ba2d7db66295233 Binary files /dev/null and b/omg_llava/dataset/__pycache__/DecoupledGCGDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/GCGDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/GCGDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e669e7fb6a22c2ab84fff6acf9cc3176512e7123 Binary files /dev/null and b/omg_llava/dataset/__pycache__/GCGDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/LlavaDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/LlavaDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91f90c3318f8c420c061f6bd8d372092d7b0fcf2 Binary files /dev/null and b/omg_llava/dataset/__pycache__/LlavaDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/MDPVPointsDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/MDPVPointsDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c8a6fdf042cd9aa337e9773dc0e14ea15b7e0b Binary files /dev/null and b/omg_llava/dataset/__pycache__/MDPVPointsDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/ReferringSegDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/ReferringSegDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2faab3204d0d9799daba53558f27ebb06d40c29 Binary files /dev/null and b/omg_llava/dataset/__pycache__/ReferringSegDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/RegionCaptionDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/RegionCaptionDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8f98ef78fd801007303d93f466b584c93338fa6 Binary files /dev/null and b/omg_llava/dataset/__pycache__/RegionCaptionDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/SemanticSegDataset.cpython-310.pyc b/omg_llava/dataset/__pycache__/SemanticSegDataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cbe248eafff4f12183811317e94560536821425 Binary files /dev/null and b/omg_llava/dataset/__pycache__/SemanticSegDataset.cpython-310.pyc differ diff --git a/omg_llava/dataset/__pycache__/__init__.cpython-310.pyc b/omg_llava/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..aaa223ba2be9f9bf3a5db8dfcea401ed2026ff7d Binary files /dev/null and b/omg_llava/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/dataset/collect_fns/__init__.py b/omg_llava/dataset/collect_fns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6381598f237a7b0bf7fcb51f541f994449d9612a --- /dev/null +++ b/omg_llava/dataset/collect_fns/__init__.py @@ -0,0 +1 @@ +from .omg_llava_collate_fn import omg_llava_collate_fn \ No newline at end of file diff --git a/omg_llava/dataset/collect_fns/__pycache__/__init__.cpython-310.pyc b/omg_llava/dataset/collect_fns/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f3b79e4cb069915394c6819bf6f27e8259494c2 Binary files /dev/null and b/omg_llava/dataset/collect_fns/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/dataset/collect_fns/__pycache__/omg_llava_collate_fn.cpython-310.pyc b/omg_llava/dataset/collect_fns/__pycache__/omg_llava_collate_fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9058bb77075487db5d03fbdd7c223db743fdfb6 Binary files /dev/null and b/omg_llava/dataset/collect_fns/__pycache__/omg_llava_collate_fn.cpython-310.pyc differ diff --git a/omg_llava/dataset/collect_fns/omg_llava_collate_fn.py b/omg_llava/dataset/collect_fns/omg_llava_collate_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..233699f4bec871fb4d9ad6332aa510d0699e7abc --- /dev/null +++ b/omg_llava/dataset/collect_fns/omg_llava_collate_fn.py @@ -0,0 +1,136 @@ +from typing import Dict, Sequence + +import torch +from torch.nn.utils.rnn import pad_sequence + +from xtuner.parallel.sequence import (get_sequence_parallel_world_size, + pad_for_sequence_parallel) +from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX + +def omg_llava_collate_fn(instances: Sequence[Dict], + pad_index: int = DEFAULT_PAD_TOKEN_INDEX, + return_hf_format: bool = False, + use_varlen_attn: bool = False): + seq_parallel_world_size = get_sequence_parallel_world_size() + + input_ids, labels = [], [] + has_image = any(inst.get('pixel_values') is not None for inst in instances) + has_mask = any(inst.get('masks') is not None for inst in instances) + has_region = any(inst.get('regions') is not None for inst in instances) + has_points = any(inst.get('points') is not None for inst in instances) + if use_varlen_attn: + position_ids, cumulative_len = [], [] + assert len(instances) == 1, ( + f'If utilizing varlen attention, the batch size should be' + f' set to 1, but got {len(instances)}') + assert not has_image, 'Currently, it is not configured to ' + 'accommodate the use of varlen Attention in multimodal training' + + if has_image: + pixel_values = [] + if has_mask: + object_masks = [] + if has_region: + object_regions = [] + if has_points: + prompt_points = [] + + for example in instances: + input_ids.append(torch.LongTensor(example['input_ids'])) + labels.append(torch.LongTensor(example['labels'])) + if use_varlen_attn: + cumulative_len.append(torch.IntTensor(example['cumulative_len'])) + position_ids.append(torch.LongTensor(example['position_ids'])) + + if has_image: + pixel_values.append(example['pixel_values']) + + if has_mask: + if 'masks' in example.keys() and example['masks'] is not None: + object_masks.append(example['masks']) + # object_masks.append(example['masks'] if 'masks' in example.keys() else None) + if has_region: + if 'regions' in example.keys() and 
example['regions'] is not None: + object_regions.append(example['regions']) + if has_points: + if 'points' in example.keys() and example['points'] is not None: + prompt_points.append(example['points']) + + ori_length = [len(ids) for ids in input_ids] + if len(instances) > 1: + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=pad_index) + labels = pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX) + else: + input_ids = torch.stack(input_ids) + labels = torch.stack(labels) + + if use_varlen_attn: + assert input_ids.size(1) % seq_parallel_world_size == 0 + attention_mask = None + position_ids = torch.stack(position_ids, dim=0) + else: + # Some tokenizers have the same eos token and pad token, so input_ids + # cannot be masked directly based on the pad token id. + attention_mask = torch.zeros_like(input_ids).bool() + for i in ori_length: + attention_mask[:i] = True + + bs, seq_len = input_ids.shape + position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1) + + if seq_parallel_world_size > 1: + input_ids = pad_for_sequence_parallel(input_ids, pad_index) + labels = pad_for_sequence_parallel(labels, IGNORE_INDEX) + position_ids = pad_for_sequence_parallel(position_ids, 0) + if attention_mask is not None: + attention_mask = pad_for_sequence_parallel(attention_mask, 0) + + if use_varlen_attn: + max_seqlen = ( + cumulative_len[0][1:] - # noqa: W504 + cumulative_len[0][:-1]).max().item() + data_dict = { + 'input_ids': input_ids, + 'cumulative_len': cumulative_len, + 'position_ids': position_ids, + 'labels': labels, + 'max_seqlen': max_seqlen + } + else: + data_dict = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'labels': labels + } + + if has_image: + pixel_values = torch.stack(pixel_values) + data_dict['pixel_values'] = pixel_values + + if has_mask: + if len(object_masks) == 0: + object_masks = torch.zeros((0, pixel_values.shape[-2], pixel_values.shape[-1])) + else: + object_masks = torch.cat(object_masks, dim=0) + data_dict['masks'] = object_masks + + if has_region: + if len(object_regions) == 0: + object_regions = torch.zeros((0, pixel_values.shape[-2], pixel_values.shape[-1])) + else: + object_regions = torch.cat(object_regions, dim=0) + data_dict['regions'] = object_regions + if has_points: + if len(prompt_points) == 0: + prompt_points = torch.zeros((0, 2)) + else: + prompt_points = torch.cat(prompt_points, dim=0) + data_dict['points'] = prompt_points + + if return_hf_format: + return data_dict + else: + return {'data': data_dict, 'data_samples': None} \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/__init__.py b/omg_llava/dataset/process_functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e956b36669f1e5914648d888564d43216fe0553 --- /dev/null +++ b/omg_llava/dataset/process_functions/__init__.py @@ -0,0 +1,9 @@ +from .gcg_process import glamm_refcocog_map_fn, glamm_granf_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn +from .mdpv_points_process import mdpv_points_map_fn +from .referring_seg_process import referring_seg_map_fn, referring_seg_gcg_format_map_fn +from .region_caption_process import osprey_region_caption_map_fn, osprey_region_caption_gcg_format_map_fn, osprey_region_conversation_map_fn +from .semantic_seg_process import semantic_seg_map_fn, pascal_part_map_fn, semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn +from .decoupled_gcg_process import glamm_openpsg_decoupled_given_objects_map_fn, 
glamm_openpsg_decoupled_given_description_map_fn,\ + glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\ + glamm_refcocog_decoupled_given_description_map_fn, glamm_refcocog_decoupled_given_objects_map_fn \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/__pycache__/__init__.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa9f2ca7167c62284ac854a4df428f4dfaa06455 Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/decoupled_gcg_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/decoupled_gcg_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7720304087456f3985c99a491312353bdbe7be2a Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/decoupled_gcg_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/gcg_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/gcg_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26f0e6e01ad9554994380e11d952b21cd8c27cbf Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/gcg_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/mdpv_points_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/mdpv_points_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e024d58af37f2ed8892cb3ad1e381147b22e6add Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/mdpv_points_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/referring_seg_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/referring_seg_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f2b69e238f1564aa2f7050f43f55a1a1ba90df Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/referring_seg_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/region_caption_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/region_caption_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6bbd06bca1c229c7655e993bb228e50e7142b03 Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/region_caption_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/__pycache__/semantic_seg_process.cpython-310.pyc b/omg_llava/dataset/process_functions/__pycache__/semantic_seg_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af55f374b571faa542706ddbd1fd99d52ea748ef Binary files /dev/null and b/omg_llava/dataset/process_functions/__pycache__/semantic_seg_process.cpython-310.pyc differ diff --git a/omg_llava/dataset/process_functions/decoupled_gcg_process.py b/omg_llava/dataset/process_functions/decoupled_gcg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..e951c5a24fd72b8b925df5ed3f06db16b68e68b8 --- /dev/null +++ 
b/omg_llava/dataset/process_functions/decoupled_gcg_process.py @@ -0,0 +1,512 @@ +import numpy as np +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +GCG_QUESTIONS = [ + DEFAULT_IMAGE_TOKEN + 'Here is the description of the image: {} Please insert interleaved segmentation masks for the objects present in the image described in the description.', + DEFAULT_IMAGE_TOKEN + 'Below is the image description: {} Kindly add interleaved segmentation masks for the objects mentioned in the description.', + DEFAULT_IMAGE_TOKEN + 'The image is described as follows: {} Please insert interleaved segmentation masks for the objects outlined in the description.', + DEFAULT_IMAGE_TOKEN + 'Here is a description of the image: {} Please include interleaved segmentation masks for the objects detailed in the description.', + DEFAULT_IMAGE_TOKEN + 'Here’s what the image depicts: {} Please add interleaved segmentation masks for the objects identified in the description.', + DEFAULT_IMAGE_TOKEN + 'The following is a description of the image: {} Please incorporate interleaved segmentation masks for the objects referenced in the description.', +] + +GCG_QUESTIONS_GIVEN_OBJECTS = [ + DEFAULT_IMAGE_TOKEN + 'Please generate the image description for these objects: {}. Please include the interleaved segmentation masks for the corresponding objects in the provided image description.', + DEFAULT_IMAGE_TOKEN + 'Please create an image description for the following objects: {}. Ensure the interleaved segmentation masks for the corresponding objects are included in the provided image description.', + DEFAULT_IMAGE_TOKEN + 'Kindly generate a description of the image for these objects: {}. Please incorporate the interleaved segmentation masks for the corresponding objects in the supplied image description.', + DEFAULT_IMAGE_TOKEN + 'Please provide an image description for the objects: {}. Include the interleaved segmentation masks for the corresponding objects in the given image description.', + DEFAULT_IMAGE_TOKEN + 'Could you generate a description of the image focusing on these objects: {}? Please add the interleaved segmentation masks for the relevant objects in the provided image description.', + DEFAULT_IMAGE_TOKEN + 'Please compose an image description that includes these objects: {}. 
Make sure to include the interleaved segmentation masks for the corresponding objects in the described image.', +] + +def refcocog_parse_annotations(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [], + 'file_name': example['img_file_name'], 'image': example['img_file_name']} + + orig_caption = example['caption'].strip('"').strip() + annotations['caption'] = orig_caption.lower() + + for detail in example['refs']: + phrase = detail['sentence'] + if phrase.lower() in annotations['caption']: + annotations['labels'].append(phrase) + index = annotations['caption'].find(phrase) + end_index = index + len(phrase) if index != -1 else -1 + annotations['tokens_positive'].append([index, end_index]) + # still polygon or rle + annotations['masks'].append(detail["segmentation"]) + + # Sort tokens_positive and corresponding lists + tokens_positive = annotations['tokens_positive'] + sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0]) + annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices] + annotations['masks'] = [annotations['masks'][i] for i in sorted_indices] + annotations['labels'] = [annotations['labels'][i] for i in sorted_indices] + + # Trimming overlapping intervals + for i in range(len(tokens_positive)): + for j in range(i + 1, len(tokens_positive)): + # If there is overlap + if tokens_positive[i][1] >= tokens_positive[j][0]: + # Modify the end index of phrase i to be one less than the start index of phrase j + tokens_positive[i][1] = tokens_positive[j][0] - 1 + # Modify the phrases to reflect the change in indices + annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1] + break # Exit inner loop since i was modified + + return annotations + +def refcocog_conversation_decoupled_given_description(caption, tokens_positive): + # insert
<p> </p>
and [seg] to caption and select a question + question = random.choice(GCG_QUESTIONS).strip().format(caption) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess_decoupled_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation_decoupled_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + + return example + +def refcocog_conversation_decoupled_given_objects(caption, tokens_positive): + # insert
<p> </p>
and [seg] to caption and select a question + object_tokens = '' + for i in range(len(tokens_positive)): + object_tokens = object_tokens + ' ' + object_tokens = object_tokens.strip() + + question = random.choice(GCG_QUESTIONS_GIVEN_OBJECTS).strip().format(object_tokens) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess_decoupled_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation_decoupled_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_refcocog_decoupled_given_description_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess_decoupled_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_refcocog_decoupled_given_objects_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess_decoupled_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def grandf_parse_annotations(example): + image_path = example['file_name'] + annotations = { + 'labels': [], 'caption': [], 'masks': [], + 'tokens_positive': [], 
'file_name': image_path, + 'image': image_path} + annotations['caption'] = example['caption'].strip('"').strip() + + for word, grounding in example["groundings"].items(): + if grounding is None: + continue + annotations['labels'].append(word) + annotations['tokens_positive'].append(grounding["token_positives"]) + annotations['masks'].append(grounding["rle_masks"]) + + return annotations + +def grandf_conversation_given_description(caption, tokens_positive): + question = random.choice(GCG_QUESTIONS).strip().format(caption) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def grandf_conversation_given_objects(caption, tokens_positive): + object_tokens = '' + for i in range(len(tokens_positive)): + object_tokens = object_tokens + ' ' + object_tokens = object_tokens.strip() + + question = random.choice(GCG_QUESTIONS_GIVEN_OBJECTS).strip().format(object_tokens) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def grandf_preprocess_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def grandf_preprocess_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_granf_decoupled_given_description_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_granf_decoupled_given_objects_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 
'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +glamm_openpsg_decoupled_given_objects_map_fn = glamm_granf_decoupled_given_objects_map_fn +glamm_openpsg_decoupled_given_description_map_fn = glamm_granf_decoupled_given_description_map_fn + +def flickr_parse_annotations(example): + annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [], + 'tokens_positive': [], 'image': example['file_name']} + ann_info = example["ann_info"] + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0)) + if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + annotations['bboxes'].append(bbox) + tokens_positive = ann['tokens_positive'] + gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive] + annotations['labels'].append(gt_label[0]) + annotations['tokens_positive'].append(tokens_positive[0]) + + rle = ann['sam_mask'] + annotations['masks'].append(rle) + + # Convert bounding boxes to numpy arrays + annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[ + 'bboxes'] else np.zeros((0, 4), dtype=np.float32) + annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[ + 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32) + return annotations + +def flickr_preprocess_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def flickr_preprocess_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort 
phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_flickr_decoupled_given_description_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_flickr_decoupled_given_objects_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + + + + diff --git a/omg_llava/dataset/process_functions/gcg_process.py b/omg_llava/dataset/process_functions/gcg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..8befcbcc48375184f4d557b38791cce5b78360d8 --- /dev/null +++ b/omg_llava/dataset/process_functions/gcg_process.py @@ -0,0 +1,297 @@ +import numpy as np +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +GCG_QUESTIONS = [ + DEFAULT_IMAGE_TOKEN + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + DEFAULT_IMAGE_TOKEN + 'Can you provide a thorough description of the this image? Please output with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Please describe in detail the contents of the image. 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + DEFAULT_IMAGE_TOKEN + 'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.', +] + +def refcocog_parse_annotations(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [], + 'file_name': example['img_file_name'], 'image': example['img_file_name']} + + orig_caption = example['caption'].strip('"').strip() + annotations['caption'] = orig_caption.lower() + + for detail in example['refs']: + phrase = detail['sentence'] + if phrase.lower() in annotations['caption']: + annotations['labels'].append(phrase) + index = annotations['caption'].find(phrase) + end_index = index + len(phrase) if index != -1 else -1 + annotations['tokens_positive'].append([index, end_index]) + # still polygon or rle + annotations['masks'].append(detail["segmentation"]) + + # Sort tokens_positive and corresponding lists + tokens_positive = annotations['tokens_positive'] + sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0]) + annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices] + annotations['masks'] = [annotations['masks'][i] for i in sorted_indices] + annotations['labels'] = [annotations['labels'][i] for i in sorted_indices] + + # Trimming overlapping intervals + for i in range(len(tokens_positive)): + for j in range(i + 1, len(tokens_positive)): + # If there is overlap + if tokens_positive[i][1] >= tokens_positive[j][0]: + # Modify the end index of phrase i to be one less than the start index of phrase j + tokens_positive[i][1] = tokens_positive[j][0] - 1 + # Modify the phrases to reflect the change in indices + annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1] + break # Exit inner loop since i was modified + + return annotations + +def refcocog_conversation(caption, tokens_positive): + # insert
<p> </p>
and [seg] to caption and select a question + question = random.choice(GCG_QUESTIONS).strip() + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + + return example + +def glamm_refcocog_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def grandf_parse_annotations(example): + image_path = example['file_name'] + annotations = { + 'labels': [], 'caption': [], 'masks': [], + 'tokens_positive': [], 'file_name': image_path, + 'image': image_path} + annotations['caption'] = example['caption'].strip('"').strip() + + for word, grounding in example["groundings"].items(): + if grounding is None: + continue + annotations['labels'].append(word) + annotations['tokens_positive'].append(grounding["token_positives"]) + annotations['masks'].append(grounding["rle_masks"]) + + return annotations + +def grandf_conversation(caption, tokens_positive): + question = random.choice(GCG_QUESTIONS).strip() + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations +def grandf_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_granf_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +glamm_openpsg_map_fn = glamm_granf_map_fn + +def flickr_parse_annotations(example): + annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [], + 'tokens_positive': [], 'image': example['file_name']} + ann_info = example["ann_info"] + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0)) + if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + annotations['bboxes'].append(bbox) + tokens_positive = ann['tokens_positive'] + gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive] + annotations['labels'].append(gt_label[0]) + annotations['tokens_positive'].append(tokens_positive[0]) + + rle = ann['sam_mask'] + annotations['masks'].append(rle) + + # Convert bounding boxes to numpy arrays + annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[ + 'bboxes'] else np.zeros((0, 4), dtype=np.float32) + annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[ + 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32) + return annotations + +def 
flickr_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_flickr_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + + + + diff --git a/omg_llava/dataset/process_functions/mdpv_points_process.py b/omg_llava/dataset/process_functions/mdpv_points_process.py new file mode 100644 index 0000000000000000000000000000000000000000..4f966703b53356e2da6e604441d931b175477619 --- /dev/null +++ b/omg_llava/dataset/process_functions/mdpv_points_process.py @@ -0,0 +1,52 @@ +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +def mdpv_points_preprocess(example): + conversations = example['conversations'] + num_marks = example['num_marks'] + + for i, conversation in enumerate(conversations): + if i == 0: + role = conversation['from'] + assert role == 'human' + question = DEFAULT_IMAGE_TOKEN + 'There are some marks:' + for i in range(num_marks): + question = question + ' Mark {} '.format(i + 1) + if i + 1 == num_marks: + question = question + '.\n' + else: + question = question + ',' + question = question + conversation['value'].replace('<', '').replace('>', '') + conversation['value'] = question + else: + conversation['value'] = conversation['value'].replace('<', '').replace('>', '') + + example['conversations'] = conversations + return example + +def mdpv_points_map_fn(example): + # examples {'image', 'conversations'} + example = mdpv_points_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = 
DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/referring_seg_process.py b/omg_llava/dataset/process_functions/referring_seg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..679f7a32ffba957a8f27ee8bf2e4b7715859dd0c --- /dev/null +++ b/omg_llava/dataset/process_functions/referring_seg_process.py @@ -0,0 +1,135 @@ +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +SEG_QUESTIONS = [ + "Can you segment the {class_name} in this image?", + "Please segment {class_name} in this image.", + "What is {class_name} in this image? Please respond with segmentation mask.", + "What is {class_name} in this image? Please output segmentation mask.", + + "Can you segment the {class_name} in this image", + "Please segment {class_name} in this image", + "What is {class_name} in this image? Please respond with segmentation mask", + "What is {class_name} in this image? Please output segmentation mask", + + "Could you provide a segmentation mask for the {class_name} in this image?", + "Please identify and segment the {class_name} in this image.", + "Where is the {class_name} in this picture? Please respond with a segmentation mask.", + "Can you highlight the {class_name} in this image with a segmentation mask?", + + "Could you provide a segmentation mask for the {class_name} in this image", + "Please identify and segment the {class_name} in this image", + "Where is the {class_name} in this picture? Please respond with a segmentation mask", + "Can you highlight the {class_name} in this image with a segmentation mask", +] + +ANSWER_LIST = [ + "It is [SEG].", + "Sure, [SEG].", + "Sure, it is [SEG].", + "Sure, the segmentation result is [SEG].", + "[SEG].", +] + +ANSWER_LIST_GCG_FORMAT = [ + "
<p> {} </p>
[SEG].", +] + +def referring_seg_conversations(labels): + questions = [] + answers = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + question_template = random.choice(SEG_QUESTIONS) + questions.append(question_template.format(class_name=label.lower())) + answers.append(random.choice(ANSWER_LIST)) + ret = [] + for i, (question, answer) in enumerate(zip(questions, answers)): + if i == 0: + ret.append( + {'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question} + ) + else: + ret.append( + {'from': 'human', 'value': question} + ) + ret.append( + {'from': 'gpt', 'value': answer} + ) + return ret + +def referring_seg_map_fn(example): + # example {'sampled_sents'} + messages = referring_seg_conversations(example['sampled_sents']) + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def referring_seg_gcg_format_conversations(labels): + questions = [] + answers = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + question_template = random.choice(SEG_QUESTIONS) + questions.append(question_template.format(class_name=label.lower())) + answers.append(random.choice(ANSWER_LIST_GCG_FORMAT).format(label.lower().capitalize())) + ret = [] + for i, (question, answer) in enumerate(zip(questions, answers)): + if i == 0: + ret.append( + {'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question} + ) + else: + ret.append( + {'from': 'human', 'value': question} + ) + ret.append( + {'from': 'gpt', 'value': answer} + ) + return ret + +def referring_seg_gcg_format_map_fn(example): + # example {'sampled_sents'} + + messages = referring_seg_gcg_format_conversations(example['sampled_sents']) + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/region_caption_process.py b/omg_llava/dataset/process_functions/region_caption_process.py new file mode 100644 index 0000000000000000000000000000000000000000..847abefeaac32351cd3e0849ed2f588f2901f7d3 --- /dev/null +++ b/omg_llava/dataset/process_functions/region_caption_process.py @@ -0,0 +1,223 @@ +import numpy as np +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN +import re + +REGION_QUESTIONS = [ + 'Can you provide me with a detailed description of the region in the picture marked by ?', + "I'm curious about the 
region represented by in the picture. Could you describe it in detail?", + 'What can you tell me about the region indicated by in the image?', + "I'd like to know more about the area in the photo labeled . Can you give me a detailed description?", + 'Could you describe the region shown as in the picture in great detail?', + 'What details can you give me about the region outlined by in the photo?', + 'Please provide me with a comprehensive description of the region marked with in the image.', + 'Can you give me a detailed account of the region labeled as in the picture?', + "I'm interested in learning more about the region represented by in the photo. Can you describe it in detail?", + 'What is the region outlined by in the picture like? Could you give me a detailed description?', + 'Can you provide me with a detailed description of the region in the picture marked by , please?', + "I'm curious about the region represented by in the picture. Could you describe it in detail, please?", + 'What can you tell me about the region indicated by in the image, exactly?', + "I'd like to know more about the area in the photo labeled , please. Can you give me a detailed description?", + 'Could you describe the region shown as in the picture in great detail, please?', + 'What details can you give me about the region outlined by in the photo, please?', + 'Please provide me with a comprehensive description of the region marked with in the image, please.', + 'Can you give me a detailed account of the region labeled as in the picture, please?', + "I'm interested in learning more about the region represented by in the photo. Can you describe it in detail, please?", + 'What is the region outlined by in the picture like, please? Could you give me a detailed description?', +] + +def region_caption_conversation(descriptions): + questions = [] + answers = [] + for i, description in enumerate(descriptions): + question = random.choice(REGION_QUESTIONS).strip().replace('', f'region{i + 1} ') + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + questions.append(question) + answers.append(description.replace('', f'region{i + 1}')) + + # seg qa + selected_seg_idx = 1 + np.random.randint(0, len(descriptions)) + question = "Please segment the region{}.".format(selected_seg_idx) + answer = "Sure, it is [SEG]." + questions.append(question) + answers.append(answer) + + conversations = [] + for question, answer in zip(questions, answers): + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations, [selected_seg_idx - 1] + +def region_caption_gcg_format_conversation(descriptions): + questions = [] + answers = [] + for i, description in enumerate(descriptions): + question = random.choice(REGION_QUESTIONS).strip().replace('', f'region{i + 1} ') + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + questions.append(question) + answers.append(description.replace('', f'region{i + 1}')) + + # seg qa + selected_seg_idx = 1 + np.random.randint(0, len(descriptions)) + question = "Please segment the region{}.".format(selected_seg_idx) + answer = "
<p> Region{} </p>
[SEG].".format(selected_seg_idx) + questions.append(question) + answers.append(answer) + + conversations = [] + for question, answer in zip(questions, answers): + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations, [selected_seg_idx - 1] + +def region_caption_preprocess(example): + descriptions = example['description'] + + # random select some labels + if len(descriptions) >= 3: + sampled_inds = np.random.choice( + list(range(len(descriptions))), size=3, replace=False + ) + else: + sampled_inds = list(range(len(descriptions))) + + selected_descriptions = [descriptions[idx] for idx in sampled_inds] + selected_descriptions = [re.sub(r'<[^>]*>', '', item) for item in selected_descriptions] + + conversations, selected_seg_idx = region_caption_conversation(selected_descriptions) + example['conversations'] = conversations + example['sampled_inds'] = sampled_inds + example['seg_region_idx'] = selected_seg_idx + return example + +def region_caption_gcg_format_preprocess(example): + descriptions = example['description'] + + # random select some labels + if len(descriptions) >= 3: + sampled_inds = np.random.choice( + list(range(len(descriptions))), size=3, replace=False + ) + else: + sampled_inds = list(range(len(descriptions))) + + selected_descriptions = [descriptions[idx] for idx in sampled_inds] + selected_descriptions = [re.sub(r'<[^>]*>', '', item) for item in selected_descriptions] + + conversations, selected_seg_idx = region_caption_gcg_format_conversation(selected_descriptions) + example['conversations'] = conversations + example['sampled_inds'] = sampled_inds + example['seg_region_idx'] = selected_seg_idx + return example + +def osprey_region_caption_map_fn(example): + # examples {'image', 'description'} + example = region_caption_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def osprey_region_caption_gcg_format_map_fn(example): + # examples {'image', 'description'} + example = region_caption_gcg_format_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def region_conversations_preprocess(example): + conversations = example['conversations'] + num_regions 
= example['num_regions'] + + for i, conversation in enumerate(conversations): + if i == 0: + role = conversation['from'] + assert role == 'human' + question = DEFAULT_IMAGE_TOKEN + 'There are some regions:' + for i in range(num_regions): + question = question + ' region{} '.format(i + 1) + if i + 1 == num_regions: + question = question + '.\n' + else: + question = question + ',' + question = question + conversation['value'].replace('<', '').replace('>', '').\ + replace("regin", "region") + conversation['value'] = question + else: + conversation['value'] = conversation['value'].replace('<', '').replace('>', ''). \ + replace("regin", "region") + + example['conversations'] = conversations + return example + + +def osprey_region_conversation_map_fn(example): + # examples {'image', 'conversations'} + example = region_conversations_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/semantic_seg_process.py b/omg_llava/dataset/process_functions/semantic_seg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..7e764975a1b6703b0e3f9145e1ed3c28db069f15 --- /dev/null +++ b/omg_llava/dataset/process_functions/semantic_seg_process.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +SEG_QUESTIONS = [ + "Can you segment the {class_name} in this image?", + "Please segment {class_name} in this image.", + "What is {class_name} in this image? Please respond with segmentation mask.", + "What is {class_name} in this image? Please output segmentation mask.", + + "Can you segment the {class_name} in this image", + "Please segment {class_name} in this image", + "What is {class_name} in this image? Please respond with segmentation mask", + "What is {class_name} in this image? Please output segmentation mask", + + "Could you provide a segmentation mask for the {class_name} in this image?", + "Please identify and segment the {class_name} in this image.", + "Where is the {class_name} in this picture? Please respond with a segmentation mask.", + "Can you highlight the {class_name} in this image with a segmentation mask?", + + "Could you provide a segmentation mask for the {class_name} in this image", + "Please identify and segment the {class_name} in this image", + "Where is the {class_name} in this picture? Please respond with a segmentation mask", + "Can you highlight the {class_name} in this image with a segmentation mask", +] + +ANSWER_LIST = [ + "It is [SEG].", + "Sure, [SEG].", + "Sure, it is [SEG].", + "Sure, the segmentation result is [SEG].", + "[SEG].", +] + +ANSWER_LIST_GCG_FORMAT = [ + "
<p> {} </p>
[SEG].", +] + +def semantic_seg_conversations(labels): + ret = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + for question_template in SEG_QUESTIONS: + for answer_template in ANSWER_LIST: + item = {} + item['conversations'] = [{'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question_template.format(class_name=label.lower())}, + {'from': 'gpt', 'value': answer_template}] + item['class_id'] = i + ret.append(item) + return ret + +def semantic_seg_map_fn(example): + # example {'conversations', 'class_id'} + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def pascal_part_conversation(selected_labels): + conversations = [] + for i, selected_label in enumerate(selected_labels): + question = random.choice(SEG_QUESTIONS).format(class_name=selected_label.lower()).strip() + answer = random.choice(ANSWER_LIST) + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations + +def pascal_part_preprocess(example): + selected_labels = example["selected_labels"] + conversations = pascal_part_conversation(selected_labels) + example['conversations'] = conversations + return example + +def pascal_part_map_fn(example): + example = pascal_part_preprocess(example) + example['image'] = example["file_name"] + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + +def semantic_seg_gcg_format_conversations(labels): + ret = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + for question_template in SEG_QUESTIONS: + for answer_template in ANSWER_LIST_GCG_FORMAT: + item = {} + item['conversations'] = [{'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question_template.format(class_name=label.lower())}, + {'from': 'gpt', 'value': answer_template.format(label.lower().capitalize())}] + item['class_id'] = i + ret.append(item) + return ret + +def semantic_seg_gcg_format_map_fn(example): + # example {'conversations', 'class_id'} + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + 
for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def pascal_part_gcg_format_conversation(selected_labels): + conversations = [] + for i, selected_label in enumerate(selected_labels): + question = random.choice(SEG_QUESTIONS).format(class_name=selected_label.lower()).strip() + answer = random.choice(ANSWER_LIST).format(selected_label.lower().capitalize()) + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations + +def pascal_part_gcg_format_preprocess(example): + selected_labels = example["selected_labels"] + conversations = pascal_part_gcg_format_conversation(selected_labels) + example['conversations'] = conversations + return example + +def pascal_part_gcg_format_map_fn(example): + example = pascal_part_gcg_format_preprocess(example) + example['image'] = example["file_name"] + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + diff --git a/omg_llava/dataset/utils/__init__.py b/omg_llava/dataset/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..085858dbf4fcac7d3432bd9d1e308e33b0cd3e23 --- /dev/null +++ b/omg_llava/dataset/utils/__init__.py @@ -0,0 +1 @@ +from .utils import expand2square, expand2square_mask, expand2square_points, expand2square_bbox \ No newline at end of file diff --git a/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..361a6b6b14324d36eae3cfd6d4a9fec5b3acdc1f Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff42d24c2743cb09f65512ccea331bad4f88708e Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc differ diff --git a/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94360a333f042ea28f815f77b477a83f2c688743 Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc differ diff --git 
a/omg_llava/dataset/utils/ade20k_classes.json b/omg_llava/dataset/utils/ade20k_classes.json new file mode 100644 index 0000000000000000000000000000000000000000..1f96e616bc3fd2f8c0ec4caea975d77c680f44bb --- /dev/null +++ b/omg_llava/dataset/utils/ade20k_classes.json @@ -0,0 +1,30 @@ +[ + "wall", "building", "sky", "floor", "tree", "ceiling", "road", + "bed", "windowpane", "grass", "cabinet", "sidewalk", + "person", "earth", "door", "table", "mountain", "plant", + "curtain", "chair", "car", "water", "painting", "sofa", + "shelf", "house", "sea", "mirror", "rug", "field", "armchair", + "seat", "fence", "desk", "rock", "wardrobe", "lamp", + "bathtub", "railing", "cushion", "base", "box", "column", + "signboard", "chest of drawers", "counter", "sand", "sink", + "skyscraper", "fireplace", "refrigerator", "grandstand", + "path", "stairs", "runway", "case", "pool table", "pillow", + "screen door", "stairway", "river", "bridge", "bookcase", + "blind", "coffee table", "toilet", "flower", "book", "hill", + "bench", "countertop", "stove", "palm", "kitchen island", + "computer", "swivel chair", "boat", "bar", "arcade machine", + "hovel", "bus", "towel", "light", "truck", "tower", + "chandelier", "awning", "streetlight", "booth", + "television receiver", "airplane", "dirt track", "apparel", + "pole", "land", "bannister", "escalator", "ottoman", "bottle", + "buffet", "poster", "stage", "van", "ship", "fountain", + "conveyer belt", "canopy", "washer", "plaything", + "swimming pool", "stool", "barrel", "basket", "waterfall", + "tent", "bag", "minibike", "cradle", "oven", "ball", "food", + "step", "tank", "trade name", "microwave", "pot", "animal", + "bicycle", "lake", "dishwasher", "screen", "blanket", + "sculpture", "hood", "sconce", "vase", "traffic light", + "tray", "ashcan", "fan", "pier", "crt screen", "plate", + "monitor", "bulletin board", "shower", "radiator", "glass", + "clock", "flag" +] \ No newline at end of file diff --git a/omg_llava/dataset/utils/cocostuff_classes.txt b/omg_llava/dataset/utils/cocostuff_classes.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d5a692b83ac8eead2bfffa805e1115cef737bae --- /dev/null +++ b/omg_llava/dataset/utils/cocostuff_classes.txt @@ -0,0 +1,183 @@ +0: unlabeled +1: person +2: bicycle +3: car +4: motorcycle +5: airplane +6: bus +7: train +8: truck +9: boat +10: traffic light +11: fire hydrant +12: street sign +13: stop sign +14: parking meter +15: bench +16: bird +17: cat +18: dog +19: horse +20: sheep +21: cow +22: elephant +23: bear +24: zebra +25: giraffe +26: hat +27: backpack +28: umbrella +29: shoe +30: eye glasses +31: handbag +32: tie +33: suitcase +34: frisbee +35: skis +36: snowboard +37: sports ball +38: kite +39: baseball bat +40: baseball glove +41: skateboard +42: surfboard +43: tennis racket +44: bottle +45: plate +46: wine glass +47: cup +48: fork +49: knife +50: spoon +51: bowl +52: banana +53: apple +54: sandwich +55: orange +56: broccoli +57: carrot +58: hot dog +59: pizza +60: donut +61: cake +62: chair +63: couch +64: potted plant +65: bed +66: mirror +67: dining table +68: window +69: desk +70: toilet +71: door +72: tv +73: laptop +74: mouse +75: remote +76: keyboard +77: cell phone +78: microwave +79: oven +80: toaster +81: sink +82: refrigerator +83: blender +84: book +85: clock +86: vase +87: scissors +88: teddy bear +89: hair drier +90: toothbrush +91: hair brush +92: banner +93: blanket +94: branch +95: bridge +96: building-other +97: bush +98: cabinet +99: cage +100: cardboard +101: carpet +102: 
ceiling-other +103: ceiling-tile +104: cloth +105: clothes +106: clouds +107: counter +108: cupboard +109: curtain +110: desk-stuff +111: dirt +112: door-stuff +113: fence +114: floor-marble +115: floor-other +116: floor-stone +117: floor-tile +118: floor-wood +119: flower +120: fog +121: food-other +122: fruit +123: furniture-other +124: grass +125: gravel +126: ground-other +127: hill +128: house +129: leaves +130: light +131: mat +132: metal +133: mirror-stuff +134: moss +135: mountain +136: mud +137: napkin +138: net +139: paper +140: pavement +141: pillow +142: plant-other +143: plastic +144: platform +145: playingfield +146: railing +147: railroad +148: river +149: road +150: rock +151: roof +152: rug +153: salad +154: sand +155: sea +156: shelf +157: sky +158: skyscraper +159: snow +160: solid-other +161: stairs +162: stone +163: straw +164: structural-other +165: table +166: tent +167: textile-other +168: towel +169: tree +170: vegetable +171: wall-brick +172: wall-concrete +173: wall-other +174: wall-panel +175: wall-stone +176: wall-tile +177: wall-wood +178: water-other +179: waterdrops +180: window-blind +181: window-other +182: wood diff --git a/omg_llava/dataset/utils/grefer.py b/omg_llava/dataset/utils/grefer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c881c5860a2bbfc89eb91b8fcf91cc32c27fbbf --- /dev/null +++ b/omg_llava/dataset/utils/grefer.py @@ -0,0 +1,352 @@ +""" +grefer v0.1 +This interface provides access to gRefCOCO. + +The following API functions are defined: +G_REFER - REFER api class +getRefIds - get ref ids that satisfy given filter conditions. +getAnnIds - get ann ids that satisfy given filter conditions. +getImgIds - get image ids that satisfy given filter conditions. +getCatIds - get category ids that satisfy given filter conditions. +loadRefs - load refs with the specified ref ids. +loadAnns - load anns with the specified ann ids. +loadImgs - load images with the specified image ids. +loadCats - load category names with the specified category ids. +getRefBox - get ref's bounding box [x, y, w, h] given the ref_id +showRef - show image, segmentation or box of the referred object with the ref +getMaskByRef - get mask and area of the referred object given ref or ref ids +getMask - get mask and area of the referred object given ref +showMask - show mask of the referred object given ref +""" + +import itertools +import json +import os.path as osp +import pickle +import time + +import matplotlib.pyplot as plt +import numpy as np +import skimage.io as io +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from pycocotools import mask + + +class G_REFER: + def __init__(self, data_root, dataset="grefcoco", splitBy="unc"): + # provide data_root folder which contains grefcoco + print("loading dataset %s into memory..." 
% dataset) + self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) + self.DATA_DIR = osp.join(data_root, dataset) + if dataset in ["grefcoco"]: + self.IMAGE_DIR = osp.join(data_root, "images/train2014") + else: + raise KeyError("No refer dataset is called [%s]" % dataset) + + tic = time.time() + + # load refs from data/dataset/refs(dataset).json + self.data = {} + self.data["dataset"] = dataset + + ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).p") + if osp.exists(ref_file): + self.data["refs"] = pickle.load(open(ref_file, "rb"), fix_imports=True) + else: + ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).json") + if osp.exists(ref_file): + self.data["refs"] = json.load(open(ref_file, "rb")) + else: + raise FileNotFoundError("JSON file not found") + + # load annotations from data/dataset/instances.json + instances_file = osp.join(self.DATA_DIR, "instances.json") + instances = json.load(open(instances_file, "r")) + self.data["images"] = instances["images"] + self.data["annotations"] = instances["annotations"] + self.data["categories"] = instances["categories"] + + # create index + self.createIndex() + print("DONE (t=%.2fs)" % (time.time() - tic)) + + @staticmethod + def _toList(x): + return x if isinstance(x, list) else [x] + + @staticmethod + def match_any(a, b): + a = a if isinstance(a, list) else [a] + b = b if isinstance(b, list) else [b] + return set(a) & set(b) + + def createIndex(self): + # create sets of mapping + # 1) Refs: {ref_id: ref} + # 2) Anns: {ann_id: ann} + # 3) Imgs: {image_id: image} + # 4) Cats: {category_id: category_name} + # 5) Sents: {sent_id: sent} + # 6) imgToRefs: {image_id: refs} + # 7) imgToAnns: {image_id: anns} + # 8) refToAnn: {ref_id: ann} + # 9) annToRef: {ann_id: ref} + # 10) catToRefs: {category_id: refs} + # 11) sentToRef: {sent_id: ref} + # 12) sentToTokens: {sent_id: tokens} + print("creating index...") + # fetch info from instances + Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} + Anns[-1] = None + for ann in self.data["annotations"]: + Anns[ann["id"]] = ann + imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] + for img in self.data["images"]: + Imgs[img["id"]] = img + for cat in self.data["categories"]: + Cats[cat["id"]] = cat["name"] + + # fetch info from refs + Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} + Sents, sentToRef, sentToTokens = {}, {}, {} + availableSplits = [] + for ref in self.data["refs"]: + # ids + ref_id = ref["ref_id"] + ann_id = ref["ann_id"] + category_id = ref["category_id"] + image_id = ref["image_id"] + + if ref["split"] not in availableSplits: + availableSplits.append(ref["split"]) + + # add mapping related to ref + if ref_id in Refs: + print("Duplicate ref id") + Refs[ref_id] = ref + imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] + + category_id = self._toList(category_id) + added_cats = [] + for cat in category_id: + if cat not in added_cats: + added_cats.append(cat) + catToRefs[cat] = catToRefs.get(cat, []) + [ref] + + ann_id = self._toList(ann_id) + refToAnn[ref_id] = [Anns[ann] for ann in ann_id] + for ann_id_n in ann_id: + annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref] + + # add mapping of sent + for sent in ref["sentences"]: + Sents[sent["sent_id"]] = sent + sentToRef[sent["sent_id"]] = ref + sentToTokens[sent["sent_id"]] = sent["tokens"] + + # create class members + self.Refs = Refs + self.Anns = Anns + self.Imgs = Imgs + self.Cats = Cats + self.Sents = Sents + self.imgToRefs = imgToRefs + self.imgToAnns = imgToAnns + self.refToAnn = refToAnn + 
self.annToRef = annToRef + self.catToRefs = catToRefs + self.sentToRef = sentToRef + self.sentToTokens = sentToTokens + self.availableSplits = availableSplits + print("index created.") + + def getRefIds(self, image_ids=[], cat_ids=[], split=[]): + image_ids = self._toList(image_ids) + cat_ids = self._toList(cat_ids) + split = self._toList(split) + + for s in split: + if s not in self.availableSplits: + raise ValueError(f"Invalid split name: {s}") + + refs = self.data["refs"] + + if len(image_ids) > 0: + lists = [self.imgToRefs[image_id] for image_id in image_ids] + refs = list(itertools.chain.from_iterable(lists)) + if len(cat_ids) > 0: + refs = [ref for ref in refs if self.match_any(ref["category_id"], cat_ids)] + if len(split) > 0: + refs = [ref for ref in refs if ref["split"] in split] + + ref_ids = [ref["ref_id"] for ref in refs] + return ref_ids + + def getAnnIds(self, image_ids=[], ref_ids=[]): + image_ids = self._toList(image_ids) + ref_ids = self._toList(ref_ids) + + if any([len(image_ids), len(ref_ids)]): + if len(image_ids) > 0: + lists = [ + self.imgToAnns[image_id] + for image_id in image_ids + if image_id in self.imgToAnns + ] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.data["annotations"] + ann_ids = [ann["id"] for ann in anns] + if len(ref_ids) > 0: + lists = [self.Refs[ref_id]["ann_id"] for ref_id in ref_ids] + anns_by_ref_id = list(itertools.chain.from_iterable(lists)) + ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id))) + else: + ann_ids = [ann["id"] for ann in self.data["annotations"]] + + return ann_ids + + def getImgIds(self, ref_ids=[]): + ref_ids = self._toList(ref_ids) + + if len(ref_ids) > 0: + image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) + else: + image_ids = self.Imgs.keys() + return image_ids + + def getCatIds(self): + return self.Cats.keys() + + def loadRefs(self, ref_ids=[]): + return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)] + + def loadAnns(self, ann_ids=[]): + if isinstance(ann_ids, str): + ann_ids = int(ann_ids) + return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)] + + def loadImgs(self, image_ids=[]): + return [self.Imgs[image_id] for image_id in self._toList(image_ids)] + + def loadCats(self, cat_ids=[]): + return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)] + + def getRefBox(self, ref_id): + anns = self.refToAnn[ref_id] + return [ann["bbox"] for ann in anns] # [x, y, w, h] + + def showRef(self, ref, seg_box="seg"): + ax = plt.gca() + # show image + image = self.Imgs[ref["image_id"]] + I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"])) + ax.imshow(I) + # show refer expression + for sid, sent in enumerate(ref["sentences"]): + print("%s. 
%s" % (sid + 1, sent["sent"])) + # show segmentations + if seg_box == "seg": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + polygons = [] + color = [] + c = "none" + if type(ann["segmentation"][0]) == list: + # polygon used for refcoco* + for seg in ann["segmentation"]: + poly = np.array(seg).reshape((len(seg) / 2, 2)) + polygons.append(Polygon(poly, True, alpha=0.4)) + color.append(c) + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 1, 0, 0), + linewidths=3, + alpha=1, + ) + ax.add_collection(p) # thick yellow polygon + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 0, 0, 0), + linewidths=1, + alpha=1, + ) + ax.add_collection(p) # thin red polygon + else: + # mask used for refclef + rle = ann["segmentation"] + m = mask.decode(rle) + img = np.ones((m.shape[0], m.shape[1], 3)) + color_mask = np.array([2.0, 166.0, 101.0]) / 255 + for i in range(3): + img[:, :, i] = color_mask[i] + ax.imshow(np.dstack((img, m * 0.5))) + # show bounding-box + elif seg_box == "box": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + bbox = self.getRefBox(ref["ref_id"]) + box_plot = Rectangle( + (bbox[0], bbox[1]), + bbox[2], + bbox[3], + fill=False, + edgecolor="green", + linewidth=3, + ) + ax.add_patch(box_plot) + + def getMask(self, ann): + if not ann: + return None + if ann["iscrowd"]: + raise ValueError("Crowd object") + image = self.Imgs[ann["image_id"]] + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"]) + else: + rle = ann["segmentation"] + + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + # compute area + area = sum(mask.area(rle)) # should be close to ann['area'] + return {"mask": m, "area": area} + + def getMaskByRef(self, ref=None, ref_id=None, merge=False): + if not ref and not ref_id: + raise ValueError + if ref: + ann_ids = ref["ann_id"] + ref_id = ref["ref_id"] + else: + ann_ids = self.getAnnIds(ref_ids=ref_id) + + if ann_ids == [-1]: + img = self.Imgs[self.Refs[ref_id]["image_id"]] + return { + "mask": np.zeros([img["height"], img["width"]], dtype=np.uint8), + "empty": True, + } + + anns = self.loadAnns(ann_ids) + mask_list = [self.getMask(ann) for ann in anns if not ann["iscrowd"]] + + if merge: + merged_masks = sum([mask["mask"] for mask in mask_list]) + merged_masks[np.where(merged_masks > 1)] = 1 + return {"mask": merged_masks, "empty": False} + else: + return mask_list + + def showMask(self, ref): + M = self.getMask(ref) + msk = M["mask"] + ax = plt.gca() + ax.imshow(msk) diff --git a/omg_llava/dataset/utils/refcoco_refer.py b/omg_llava/dataset/utils/refcoco_refer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4cea716e40e73d0b5aa118143eb076392f5eb1 --- /dev/null +++ b/omg_llava/dataset/utils/refcoco_refer.py @@ -0,0 +1,391 @@ +__author__ = "licheng" + +""" +This interface provides access to four datasets: +1) refclef +2) refcoco +3) refcoco+ +4) refcocog +split by unc and google + +The following API functions are defined: +REFER - REFER api class +getRefIds - get ref ids that satisfy given filter conditions. +getAnnIds - get ann ids that satisfy given filter conditions. +getImgIds - get image ids that satisfy given filter conditions. +getCatIds - get category ids that satisfy given filter conditions. +loadRefs - load refs with the specified ref ids. +loadAnns - load anns with the specified ann ids. 
+loadImgs - load images with the specified image ids. +loadCats - load category names with the specified category ids. +getRefBox - get ref's bounding box [x, y, w, h] given the ref_id +showRef - show image, segmentation or box of the referred object with the ref +getMask - get mask and area of the referred object given ref +showMask - show mask of the referred object given ref +""" + +import itertools +import json +import os.path as osp +import pickle +import sys +import time +from pprint import pprint + +import matplotlib.pyplot as plt +import numpy as np +import skimage.io as io +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from pycocotools import mask + + +class REFER: + def __init__(self, data_root, dataset="refcoco", splitBy="unc"): + # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog + # also provide dataset name and splitBy information + # e.g., dataset = 'refcoco', splitBy = 'unc' + print("loading dataset %s into memory..." % dataset) + self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) + self.DATA_DIR = osp.join(data_root, dataset) + if dataset in ["refcoco", "refcoco+", "refcocog"]: + self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014") + elif dataset == "refclef": + self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12") + else: + print("No refer dataset is called [%s]" % dataset) + sys.exit() + + self.dataset = dataset + + # load refs from data/dataset/refs(dataset).json + tic = time.time() + + ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p") + print("ref_file: ", ref_file) + self.data = {} + self.data["dataset"] = dataset + self.data["refs"] = pickle.load(open(ref_file, "rb")) + + # load annotations from data/dataset/instances.json + instances_file = osp.join(self.DATA_DIR, "instances.json") + instances = json.load(open(instances_file, "rb")) + self.data["images"] = instances["images"] + self.data["annotations"] = instances["annotations"] + self.data["categories"] = instances["categories"] + + # create index + self.createIndex() + print("DONE (t=%.2fs)" % (time.time() - tic)) + + def createIndex(self): + # create sets of mapping + # 1) Refs: {ref_id: ref} + # 2) Anns: {ann_id: ann} + # 3) Imgs: {image_id: image} + # 4) Cats: {category_id: category_name} + # 5) Sents: {sent_id: sent} + # 6) imgToRefs: {image_id: refs} + # 7) imgToAnns: {image_id: anns} + # 8) refToAnn: {ref_id: ann} + # 9) annToRef: {ann_id: ref} + # 10) catToRefs: {category_id: refs} + # 11) sentToRef: {sent_id: ref} + # 12) sentToTokens: {sent_id: tokens} + print("creating index...") + # fetch info from instances + Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} + for ann in self.data["annotations"]: + Anns[ann["id"]] = ann + imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] + for img in self.data["images"]: + Imgs[img["id"]] = img + for cat in self.data["categories"]: + Cats[cat["id"]] = cat["name"] + + # fetch info from refs + Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} + Sents, sentToRef, sentToTokens = {}, {}, {} + for ref in self.data["refs"]: + # ids + ref_id = ref["ref_id"] + ann_id = ref["ann_id"] + category_id = ref["category_id"] + image_id = ref["image_id"] + + # add mapping related to ref + Refs[ref_id] = ref + imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] + catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] + refToAnn[ref_id] = Anns[ann_id] + annToRef[ann_id] = ref + + # add mapping of sent + for sent in 
ref["sentences"]: + Sents[sent["sent_id"]] = sent + sentToRef[sent["sent_id"]] = ref + sentToTokens[sent["sent_id"]] = sent["tokens"] + + # create class members + self.Refs = Refs + self.Anns = Anns + self.Imgs = Imgs + self.Cats = Cats + self.Sents = Sents + self.imgToRefs = imgToRefs + self.imgToAnns = imgToAnns + self.refToAnn = refToAnn + self.annToRef = annToRef + self.catToRefs = catToRefs + self.sentToRef = sentToRef + self.sentToTokens = sentToTokens + print("index created.") + + def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: + refs = self.data["refs"] + else: + if not len(image_ids) == 0: + refs = [self.imgToRefs[image_id] for image_id in image_ids] + else: + refs = self.data["refs"] + if not len(cat_ids) == 0: + refs = [ref for ref in refs if ref["category_id"] in cat_ids] + if not len(ref_ids) == 0: + refs = [ref for ref in refs if ref["ref_id"] in ref_ids] + if not len(split) == 0: + if split in ["testA", "testB", "testC"]: + refs = [ + ref for ref in refs if split[-1] in ref["split"] + ] # we also consider testAB, testBC, ... + elif split in ["testAB", "testBC", "testAC"]: + refs = [ + ref for ref in refs if ref["split"] == split + ] # rarely used I guess... + elif split == "test": + refs = [ref for ref in refs if "test" in ref["split"]] + elif split == "train" or split == "val": + refs = [ref for ref in refs if ref["split"] == split] + else: + print("No such split [%s]" % split) + sys.exit() + ref_ids = [ref["ref_id"] for ref in refs] + return ref_ids + + def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: + ann_ids = [ann["id"] for ann in self.data["annotations"]] + else: + if not len(image_ids) == 0: + lists = [ + self.imgToAnns[image_id] + for image_id in image_ids + if image_id in self.imgToAnns + ] # list of [anns] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.data["annotations"] + if not len(cat_ids) == 0: + anns = [ann for ann in anns if ann["category_id"] in cat_ids] + ann_ids = [ann["id"] for ann in anns] + if not len(ref_ids) == 0: + ids = set(ann_ids).intersection( + set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]) + ) + return ann_ids + + def getImgIds(self, ref_ids=[]): + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if not len(ref_ids) == 0: + image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) + else: + image_ids = self.Imgs.keys() + return image_ids + + def getCatIds(self): + return self.Cats.keys() + + def loadRefs(self, ref_ids=[]): + if type(ref_ids) == list: + return [self.Refs[ref_id] for ref_id in ref_ids] + elif type(ref_ids) == int: + return [self.Refs[ref_ids]] + + def loadAnns(self, ann_ids=[]): + if type(ann_ids) == list: + return [self.Anns[ann_id] for ann_id in ann_ids] + elif type(ann_ids) == int or type(ann_ids) == unicode: + return [self.Anns[ann_ids]] + + def loadImgs(self, image_ids=[]): + if type(image_ids) == list: + return [self.Imgs[image_id] for image_id in image_ids] + elif type(image_ids) == int: + return [self.Imgs[image_ids]] + + 
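+    # Illustrative usage sketch of the REFER API defined in this file (added for clarity;
+    # the data_root path below is an assumption based on the `./data/ref_seg/` layout used
+    # elsewhere in this patch, and any of the supported refcoco splits works the same way):
+    #
+    #   refer = REFER('./data/ref_seg', dataset='refcoco', splitBy='unc')
+    #   ref_ids = refer.getRefIds(split='train')
+    #   ref = refer.loadRefs(ref_ids[0])[0]            # dict with 'sentences', 'ann_id', ...
+    #   m = refer.getMask(ref)['mask']                 # (H, W) uint8 binary mask
+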
def loadCats(self, cat_ids=[]): + if type(cat_ids) == list: + return [self.Cats[cat_id] for cat_id in cat_ids] + elif type(cat_ids) == int: + return [self.Cats[cat_ids]] + + def getRefBox(self, ref_id): + ref = self.Refs[ref_id] + ann = self.refToAnn[ref_id] + return ann["bbox"] # [x, y, w, h] + + def showRef(self, ref, seg_box="seg"): + ax = plt.gca() + # show image + image = self.Imgs[ref["image_id"]] + I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"])) + ax.imshow(I) + # show refer expression + for sid, sent in enumerate(ref["sentences"]): + print("%s. %s" % (sid + 1, sent["sent"])) + # show segmentations + if seg_box == "seg": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + polygons = [] + color = [] + c = "none" + if type(ann["segmentation"][0]) == list: + # polygon used for refcoco* + for seg in ann["segmentation"]: + poly = np.array(seg).reshape((len(seg) / 2, 2)) + polygons.append(Polygon(poly, True, alpha=0.4)) + color.append(c) + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 1, 0, 0), + linewidths=3, + alpha=1, + ) + ax.add_collection(p) # thick yellow polygon + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 0, 0, 0), + linewidths=1, + alpha=1, + ) + ax.add_collection(p) # thin red polygon + else: + # mask used for refclef + rle = ann["segmentation"] + m = mask.decode(rle) + img = np.ones((m.shape[0], m.shape[1], 3)) + color_mask = np.array([2.0, 166.0, 101.0]) / 255 + for i in range(3): + img[:, :, i] = color_mask[i] + ax.imshow(np.dstack((img, m * 0.5))) + # show bounding-box + elif seg_box == "box": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + bbox = self.getRefBox(ref["ref_id"]) + box_plot = Rectangle( + (bbox[0], bbox[1]), + bbox[2], + bbox[3], + fill=False, + edgecolor="green", + linewidth=3, + ) + ax.add_patch(box_plot) + + def getMask(self, ref): + # return mask, area and mask-center + ann = self.refToAnn[ref["ref_id"]] + image = self.Imgs[ref["image_id"]] + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"]) + else: + rle = ann["segmentation"] + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + # compute area + area = sum(mask.area(rle)) # should be close to ann['area'] + return {"mask": m, "area": area} + # # position + # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style) + # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style) + # # mass position (if there were multiple regions, we use the largest one.) 
+ # label_m = label(m, connectivity=m.ndim) + # regions = regionprops(label_m) + # if len(regions) > 0: + # largest_id = np.argmax(np.array([props.filled_area for props in regions])) + # largest_props = regions[largest_id] + # mass_y, mass_x = largest_props.centroid + # else: + # mass_x, mass_y = position_x, position_y + # # if centroid is not in mask, we find the closest point to it from mask + # if m[mass_y, mass_x] != 1: + # print('Finding closes mask point ...') + # kernel = np.ones((10, 10),np.uint8) + # me = cv2.erode(m, kernel, iterations = 1) + # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style + # points = np.array(points) + # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) + # id = np.argsort(dist)[0] + # mass_y, mass_x = points[id] + # # return + # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} + # # show image and mask + # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) + # plt.figure() + # plt.imshow(I) + # ax = plt.gca() + # img = np.ones( (m.shape[0], m.shape[1], 3) ) + # color_mask = np.array([2.0,166.0,101.0])/255 + # for i in range(3): + # img[:,:,i] = color_mask[i] + # ax.imshow(np.dstack( (img, m*0.5) )) + # plt.show() + + def showMask(self, ref): + M = self.getMask(ref) + msk = M["mask"] + ax = plt.gca() + ax.imshow(msk) + + +if __name__ == "__main__": + refer = REFER(dataset="refcocog", splitBy="google") + ref_ids = refer.getRefIds() + print(len(ref_ids)) + + print(len(refer.Imgs)) + print(len(refer.imgToRefs)) + + ref_ids = refer.getRefIds(split="train") + print("There are %s training referred objects." % len(ref_ids)) + + for ref_id in ref_ids: + ref = refer.loadRefs(ref_id)[0] + if len(ref["sentences"]) < 2: + continue + + pprint(ref) + print("The label is %s." 
% refer.Cats[ref["category_id"]]) + plt.figure() + refer.showRef(ref, seg_box="box") + plt.show() + + # plt.figure() + # refer.showMask(ref) + # plt.show() diff --git a/omg_llava/dataset/utils/utils.py b/omg_llava/dataset/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4a6cdd89e09c4f3a78fc2270b6c12cc5b6d77d --- /dev/null +++ b/omg_llava/dataset/utils/utils.py @@ -0,0 +1,71 @@ +import numpy as np +from PIL import Image + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +def expand2square_mask(mask): + # mask (n, h, w) + n_mask, width, height = mask.shape + if width == height: + return mask + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + pad_mask_1 = np.zeros((n_mask, width, n_pad_1), dtype=np.uint8) + pad_mask_2 = np.zeros((n_mask, width, n_pad_2), dtype=np.uint8) + result = np.concatenate([pad_mask_1, mask, pad_mask_2], axis=2) + return result + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + pad_mask_1 = np.zeros((n_mask, n_pad_1, height), dtype=np.uint8) + pad_mask_2 = np.zeros((n_mask, n_pad_2, height), dtype=np.uint8) + result = np.concatenate([pad_mask_1, mask, pad_mask_2], axis=1) + return result + +def expand2square_bbox(bboxes, width, height): + bboxes = np.array(bboxes) + if width == height: + return bboxes + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + bboxes[:, 1] += n_pad_1 + return bboxes + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + bboxes[:, 0] += n_pad_1 + return bboxes + +def expand2square_points(points, width, height): + if width == height: + return points + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + points[:, 1] += n_pad_1 + return points + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + points[:, 0] += n_pad_1 + return points + diff --git a/omg_llava/engine/__init__.py b/omg_llava/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d1e8d0a1ceb1bd7a89f55661902b5f3c895191 --- /dev/null +++ b/omg_llava/engine/__init__.py @@ -0,0 +1,2 @@ +from .dataset_info_hook import DatasetInfoHook_withSpecoalTokens +from .evaluate_chat_hook import EvaluateChatHook_withSpecialTokens \ No newline at end of file diff --git a/omg_llava/engine/dataset_info_hook.py b/omg_llava/engine/dataset_info_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7aa45dafbf97a231c139c8a19caaed32232db5 --- /dev/null +++ b/omg_llava/engine/dataset_info_hook.py @@ -0,0 +1,17 @@ +from xtuner.registry import BUILDER +from xtuner.engine.hooks import DatasetInfoHook + +class DatasetInfoHook_withSpecoalTokens(DatasetInfoHook): + def __init__(self, tokenizer, is_intern_repo_dataset=False): + self.tokenizer = BUILDER.build(tokenizer) + self.is_intern_repo_dataset = is_intern_repo_dataset + # add special tokens + # Adding special tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['

<p>', '</p>
'] + # add for visual prompt + region_tokens = [''] + point_tokens = [''] + special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens + self.tokenizer.add_tokens(special_tokens, special_tokens=True) \ No newline at end of file diff --git a/omg_llava/engine/evaluate_chat_hook.py b/omg_llava/engine/evaluate_chat_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8449f75d058415ed8099c9465987bf1aeeb4f4 --- /dev/null +++ b/omg_llava/engine/evaluate_chat_hook.py @@ -0,0 +1,164 @@ +import torch +from xtuner.dataset.utils import expand2square +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) + +import warnings +from mmengine.utils.misc import get_object_from_string +from transformers import GenerationConfig, StoppingCriteriaList +from xtuner.dataset.utils import load_image +from xtuner.registry import BUILDER +from xtuner.utils import StopWordStoppingCriteria +from xtuner.engine.hooks import EvaluateChatHook + + + +class EvaluateChatHook_withSpecialTokens(EvaluateChatHook): + priority = 'LOW' + def __init__(self, + tokenizer, + evaluation_inputs, + evaluation_images=None, + image_processor=None, + system='', + prompt_template=None, + every_n_iters=None, + max_new_tokens=600, + stop_word=None, + stop_words=[]): + self.evaluation_inputs = evaluation_inputs + if isinstance(self.evaluation_inputs, str): + self.evaluation_inputs = [self.evaluation_inputs] + self.evaluation_images = evaluation_images + if isinstance(self.evaluation_images, str): + self.evaluation_images = [self.evaluation_images] + if self.evaluation_images is not None: + assert len( + self.evaluation_images) in [1, len(self.evaluation_inputs)] + if len(self.evaluation_images) == 1: + self.evaluation_images = [self.evaluation_images[0]] * len( + self.evaluation_inputs) + self.evaluation_images = [ + load_image(img) for img in self.evaluation_images + ] + if prompt_template is None: + instruction = '{input}' + else: + if isinstance(prompt_template, str): # for resume + prompt_template = get_object_from_string(prompt_template) + instruction = prompt_template.get('INSTRUCTION', '{input}') + if system != '': + system = prompt_template.get( + 'SYSTEM', '{system}\n').format(system=system) + stop_words += prompt_template.get('STOP_WORDS', []) + if stop_word is not None: + # TODO: deprecation, v0.3.0 + warnings.warn( + ('The `stop_word` argument is deprecated and will be removed ' + 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning) + stop_words.append(stop_word) + self.instruction = instruction + self.system = system + self.every_n_iters = every_n_iters + self.max_new_tokens = max_new_tokens + self.tokenizer = BUILDER.build(tokenizer) + self._add_special_tokens() + if image_processor is not None: + self.image_processor = BUILDER.build(image_processor) + self.stop_criteria = StoppingCriteriaList() + # default generation config + self.gen_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=0.1, + top_p=0.75, + top_k=40, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None else + self.tokenizer.eos_token_id, + ) + self.stop_criteria = StoppingCriteriaList() + for word in stop_words: + self.stop_criteria.append( + StopWordStoppingCriteria(self.tokenizer, word)) + + self.is_first_run = True + + def _add_special_tokens(self): + assert hasattr(self, "tokenizer") + # Adding special 
tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['

<p>', '</p>
'] + # add for visual prompt + region_tokens = [''] + point_tokens = [''] + special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens + self.tokenizer.add_tokens(special_tokens, special_tokens=True) + return + + def _eval_images(self, + runner, + model, + device, + max_new_tokens=None, + save_eval_output=False): + if save_eval_output: + eval_outputs = [] + + for sample_image, sample_input in zip(self.evaluation_images, + self.evaluation_inputs): + image = expand2square( + sample_image, + tuple(int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.to(device) + sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input + inputs = (self.system + self.instruction).format( + input=sample_input, round=1, **runner.cfg) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = self.tokenizer.encode(chunk) + else: + cur_encode = self.tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + input_ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + input_ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + input_ids.append(IMAGE_TOKEN_INDEX) + input_ids = torch.tensor(input_ids).to(device) + visual_outputs = model.visual_encoder( + image.unsqueeze(0).to(model.visual_encoder.dtype), + output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = model.projector(visual_outputs) + else: + pixel_values = model.projector( + visual_outputs.hidden_states[model.visual_select_layer][:, 1:]) + + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=model.llm, + input_ids=input_ids.unsqueeze(0), + pixel_values=pixel_values) + + generation_output = model.generate( + **mm_inputs, + max_new_tokens=max_new_tokens, + generation_config=self.gen_config, + bos_token_id=self.tokenizer.bos_token_id, + stopping_criteria=self.stop_criteria) + generation_output = self.tokenizer.decode(generation_output[0]) + runner.logger.info(f'Sample output:\n' + f'{inputs + generation_output}\n') + if save_eval_output: + eval_outputs.append(f'{inputs + generation_output}\n') + + if save_eval_output: + self._save_eval_output(runner, eval_outputs) diff --git a/omg_llava/model/__init__.py b/omg_llava/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c02ac89dc9b664083e5616216ee72aed95316b5 --- /dev/null +++ b/omg_llava/model/__init__.py @@ -0,0 +1,4 @@ +from .convnext_clip import OpenCLIPBackbone, OpenCLIPBackbone_omgseg +from .modules import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA +from .omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead +from .omg_llava import OMG_LLaVA \ No newline at end of file diff --git a/omg_llava/model/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df5368bd73c65d0fb454858a8b6d36bd2715419 Binary files /dev/null and b/omg_llava/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/__pycache__/omg_llava.cpython-310.pyc b/omg_llava/model/__pycache__/omg_llava.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe862ee01d0e419fc09b6d1844729c5b21b55e55 Binary files /dev/null and 
b/omg_llava/model/__pycache__/omg_llava.cpython-310.pyc differ diff --git a/omg_llava/model/__pycache__/utils.cpython-310.pyc b/omg_llava/model/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2338c0e095dce9c84b5ba92ea5bb68935285810b Binary files /dev/null and b/omg_llava/model/__pycache__/utils.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/__init__.py b/omg_llava/model/convnext_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44800fd1af53693e198bb5db5d8ccae990fd5aee --- /dev/null +++ b/omg_llava/model/convnext_clip/__init__.py @@ -0,0 +1 @@ +from .openclip_backbone import OpenCLIPBackbone, OpenCLIPBackbone_omgseg \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/convnext_clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0e420d620b09fbeed06420533afcad978463368 Binary files /dev/null and b/omg_llava/model/convnext_clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/__pycache__/openclip_backbone.cpython-310.pyc b/omg_llava/model/convnext_clip/__pycache__/openclip_backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6227ba58d22c8ac087dbb8eaf6d881be11a265a4 Binary files /dev/null and b/omg_llava/model/convnext_clip/__pycache__/openclip_backbone.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__init__.py b/omg_llava/model/convnext_clip/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb1199b8aa87a919abff1bd0020c6624757ac62 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/__init__.py @@ -0,0 +1,15 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b08da39546905af3bdc4b4e8a0bc06fd4e825d1 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/coca_model.cpython-310.pyc 
b/omg_llava/model/convnext_clip/open_clip/__pycache__/coca_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75f8ec317b20789f9e0c3866d5b50743dbb637c4 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/coca_model.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/constants.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1529a6521fa5287ba038893aef2726a3487baea5 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/constants.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/factory.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a9957d381b4645d31ebfa1cfa4d48fb4fe5fe73 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/factory.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_configs.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885a5e9ba32a45eccfe073e6f2b1f3da99d77c61 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_configs.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_model.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e949207ce672517a59ea45f07ac1193f5e8a5e0 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/hf_model.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/loss.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69a4a00367104825c4c38705a6c7577a49d0aed8 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/loss.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/model.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eb050f1da42f1ec9c541424de2a52cd13b75d25 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/model.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/modified_resnet.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/modified_resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..694a475cf7946c1c2d8c08e524e57c3c086aa010 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/modified_resnet.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/openai.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/openai.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9229a7aa902831b25041a427121a8b77c1e6c6 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/openai.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/pretrained.cpython-310.pyc 
b/omg_llava/model/convnext_clip/open_clip/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c92e566a4ba0063f97b109946832e5969217c5 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/pretrained.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..498040231065f9a9696c8c18a3c79257e3b65d49 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/timm_model.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/timm_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..913676ef89fd9ff5523c47811f443f3f13d80a57 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/timm_model.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/tokenizer.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f275f926b23342314780f37c5aa2dbf69f95cf Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/transform.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bf2b6b4813d1f2678d7e59292609591be9eb603 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/transform.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/transformer.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f1e8ecfa096392ff201f3070c3c1b87d62f3e7a Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/transformer.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/utils.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8d422b71625f2a8aef3a521b4c266eaa02caee Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/utils.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/version.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3a911efa331065050aea154fb1ef896cecff16 Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/version.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_classifier.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_classifier.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6945171910cf23419943144ab651c5d4cd1b790c Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_classifier.cpython-310.pyc differ diff --git 
a/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_metadata.cpython-310.pyc b/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_metadata.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6bf5ade35ef928bfffe3e2032b6d8ebfaac666c Binary files /dev/null and b/omg_llava/model/convnext_clip/open_clip/__pycache__/zero_shot_metadata.cpython-310.pyc differ diff --git a/omg_llava/model/convnext_clip/open_clip/bpe_simple_vocab_16e6.txt.gz b/omg_llava/model/convnext_clip/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/omg_llava/model/convnext_clip/open_clip/coca_model.py b/omg_llava/model/convnext_clip/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text 
= _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." 
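+        # The branches below dispatch on `generation_type`: "beam_search" delegates to
+        # self._generate_beamsearch (grouped beam search via HuggingFace's BeamSearchScorer),
+        # while "top_p" / "top_k" run the autoregressive sampling loop with the matching
+        # TopPLogitsWarper / TopKLogitsWarper. When not supplied, the sot/eos token ids
+        # default to the CLIP BPE ids 49406 / 49407 and pad_token_id falls back to self.pad_id.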
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + 
num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = 
torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/omg_llava/model/convnext_clip/open_clip/constants.py b/omg_llava/model/convnext_clip/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/omg_llava/model/convnext_clip/open_clip/factory.py b/omg_llava/model/convnext_clip/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..a01afac050fc8812c58c039ba608268703cd2a0c --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/factory.py @@ -0,0 +1,388 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = 
[] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + print("Loadding CLIP from {} !!!".format(checkpoint_path)) + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + logger: logging.Logger = logging, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + 
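# plain (non hf-hub) model name: the architecture config is looked up from the local registry below, and pretrained weights, if requested, are resolved afterwards +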
pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logger.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logger.info(f'Loaded {model_name} model config.') + else: + logger.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logger.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + logger=logger, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + 
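# remaining flags are forwarded to create_model (which is called with require_pretrained=True below) +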
force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + logger=logger, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/omg_llava/model/convnext_clip/open_clip/generation_utils.py b/omg_llava/model/convnext_clip/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/model/convnext_clip/open_clip/hf_configs.py b/omg_llava/model/convnext_clip/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..13c9bfd8c660eac59f1fbc1912b9fccc9c0c625a --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/hf_configs.py @@ -0,0 +1,56 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, +} diff --git a/omg_llava/model/convnext_clip/open_clip/hf_model.py b/omg_llava/model/convnext_clip/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..08dbdbcde02b550ca765ca9bcb0b667be2c0443d --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/hf_model.py @@ -0,0 +1,193 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers 
(https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = torch.tensor(0) + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, 
student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/omg_llava/model/convnext_clip/open_clip/model.py b/omg_llava/model/convnext_clip/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f85b68ba23117cb65d082cf5cd4cf7528bab4619 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model.py @@ -0,0 +1,473 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + output_tokens: bool = False + + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. 
# head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = 
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = 
_build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = 
to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14-plus.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..73f46a71e664fce987218b8eb48903e7bd895f41 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-B-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..3f92357287e1f6600da1e7f391cb6370d7f66de4 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14-plus.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..e250c2a404c86ff168c54cfcf71bc2492be1b74c --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + 
"heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14-336.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb07f3c082fd88c4e86131b272163aaacfaef9e --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN101-quickgelu.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN101.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + 
"vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN50-quickgelu.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN50.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x4.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x64.json b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus-240.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 
0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-plus-256.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-quickgelu.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + 
"text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-280.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-336.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16-320.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 
0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16-alt.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32-alt.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 
49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16-alt.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32-alt.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-bigG-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-e-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 
0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-g-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-L-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/coca_base.json b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + 
"heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/coca_roberta-ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w_320.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git 
a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d_320.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_small.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_tiny.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + 
"timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xlarge.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge_320.json b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-base-ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-xl-ViT-H-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ 
b/omg_llava/model/convnext_clip/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/roberta-ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/swin_base_patch4_window7_224.json b/omg_llava/model/convnext_clip/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/vit_medium_patch16_gap_256.json b/omg_llava/model/convnext_clip/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/omg_llava/model/convnext_clip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json 
new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/omg_llava/model/convnext_clip/open_clip/modified_resnet.py b/omg_llava/model/convnext_clip/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/omg_llava/model/convnext_clip/open_clip/openai.py b/omg_llava/model/convnext_clip/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c0235245c2e4f1217b3b2bfaf2acf78e74981 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/openai.py @@ -0,0 +1,90 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
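+Exposes list_openai_models() and load_openai_model(); the loader handles both TorchScript (JIT) archives and plain state-dict checkpoints.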
+""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + state_dict = torch.load(model_path, map_location="cpu") + + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': + model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) + + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD + return model diff --git a/omg_llava/model/convnext_clip/open_clip/pretrained.py b/omg_llava/model/convnext_clip/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1465a2325652be7e7a1d7563698e38b9ec408cc6 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/pretrained.py @@ -0,0 +1,427 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = 
partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + 
commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + 
commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": 
_xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ) +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + 
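+ # Reuse a cached file only if its SHA256 digest starts with the hash embedded in the URL (openaipublic) or in the filename (mlfoundations); otherwise warn and re-download.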
if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
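+ # With a trailing slash, os.path.split() yields (model_id, ''), so the default weights filename is used below.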
+ model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/omg_llava/model/convnext_clip/open_clip/push_to_hf_hub.py b/omg_llava/model/convnext_clip/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6271da1d35e36ea22e92d339dc9465d0793249 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/push_to_hf_hub.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + list_repo_files, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + safe_serialization: Union[bool, str] = False, + skip_weights : bool = False, +): + config_filename = HF_CONFIG_NAME + + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, str] = False, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, 
repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + safe_serialization=safe_serialization, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + precision: str = 'fp32', + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + precision=precision, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + safe_serialization='both', + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- clip\n" + readme_text += "library_name: open_clip\n" + readme_text += "pipeline_tag: zero-shot-image-classification\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not 
isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + precision=args.precision, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/omg_llava/model/convnext_clip/open_clip/timm_model.py b/omg_llava/model/convnext_clip/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3f595d67cdedd142b6312d26924e8e58c67086 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/timm_model.py @@ -0,0 +1,149 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
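+Wires a timm backbone ('trunk') to an optional attention-pool layer and a linear or MLP projection head.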
+""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if not proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + self.trunk = timm.create_model( + model_name, + num_classes=embed_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + else: + assert not proj, f'Unknown projection type {proj}.' 
+ + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/omg_llava/model/convnext_clip/open_clip/tokenizer.py b/omg_llava/model/convnext_clip/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, 
str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/omg_llava/model/convnext_clip/open_clip/transform.py b/omg_llava/model/convnext_clip/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/omg_llava/model/convnext_clip/open_clip/transformer.py b/omg_llava/model/convnext_clip/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a30e94664a2dd890a373eb0a0f640818836baaa --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/transformer.py @@ -0,0 +1,726 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from 
torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. 
+ ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = 
norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, 
None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. 
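+        # (Currently a no-op: only the commented-out block below is provided. It mirrors
+        # the scaled-normal initialization used in TextTransformer.init_parameters later
+        # in this file.)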
+ + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, 
t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, 
cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/omg_llava/model/convnext_clip/open_clip/utils.py b/omg_llava/model/convnext_clip/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, 
include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/omg_llava/model/convnext_clip/open_clip/version.py b/omg_llava/model/convnext_clip/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a910817da22d06aa0244c6d488b40d30da2bfb7e --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.20.0' diff --git a/omg_llava/model/convnext_clip/open_clip/zero_shot_classifier.py b/omg_llava/model/convnext_clip/open_clip/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..12b58f65bb0875b164946a9ee73e938aef255382 --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/zero_shot_classifier.py @@ -0,0 +1,110 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
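+    Returns:
+        Tensor of shape [embed_dim, num_classes]; each column is the re-normalized
+        mean text embedding over all templates for one class.
+
+    Example (illustrative sketch; assumes an instantiated CLIP `model` and matching
+    `tokenizer`, a batch of preprocessed `images`, and the constants exported from
+    zero_shot_metadata.py):
+
+        classifier = build_zero_shot_classifier(
+            model, tokenizer,
+            classnames=IMAGENET_CLASSNAMES,
+            templates=OPENAI_IMAGENET_TEMPLATES,
+            num_classes_per_batch=10,
+            device='cuda',
+        )
+        image_features = F.normalize(model.encode_image(images), dim=-1)
+        logits = 100. * image_features @ classifier  # [batch_size, num_classes]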
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + texts = tokenizer(texts).to(device) + class_embeddings = F.normalize(model.encode_text(texts), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/omg_llava/model/convnext_clip/open_clip/zero_shot_metadata.py b/omg_llava/model/convnext_clip/open_clip/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb452bbb6e27b71cff1dd27e2bb263259b9363f --- /dev/null +++ b/omg_llava/model/convnext_clip/open_clip/zero_shot_metadata.py @@ -0,0 +1,266 @@ + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda 
c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer 
whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", 
"mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid 
camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's 
slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +) + diff --git a/omg_llava/model/convnext_clip/openclip_backbone.py b/omg_llava/model/convnext_clip/openclip_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..3a753bb9d50adf00ab57edb2a8608dcb63559677 --- /dev/null +++ b/omg_llava/model/convnext_clip/openclip_backbone.py @@ -0,0 +1,769 @@ +from typing import Optional, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from mmdet.registry import MODELS + +from mmengine.model import BaseModule +from mmengine.dist import get_dist_info +from mmengine.logging import MMLogger +from mmengine.runner.checkpoint import CheckpointLoader +from timm.layers import resample_abs_pos_embed + +from . import open_clip +class Data: + hidden_size = 1024 + +class Output: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + +def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. + Defaults to None. + logger: logger + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix: + return state_dict + if not prefix.endswith('.'): + prefix += '.' + prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + +def flatten_permute(x): + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + return x + + +@MODELS.register_module() +class OpenCLIPBackbone(BaseModule): + """OpenCLIPBackbone, + Please refer to: + https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface + for the supported models and checkpoints. + """ + STAGES = 4 + + def __init__( + self, + img_size: int = 1024, + model_name: str = '', + fix: bool = True, + fix_layers: Optional[List] = None, + init_cfg=None, + dtype=torch.float16, + **kwargs, + ): + assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \ + f"{init_cfg['type']} is not supported." 
+ pretrained = init_cfg['checkpoint'] + super().__init__(init_cfg=None) + self.init_cfg = init_cfg + self.logger = MMLogger.get_current_instance() + rank, world_size = get_dist_info() + + if world_size > 1: + if rank == 0: + if init_cfg['type'] == 'clip_pretrain': + _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + + else: + pass + dist.barrier() + + # Get the clip model + if init_cfg['type'] == 'clip_pretrain': + clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + elif init_cfg['type'] == 'Pretrained': + clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger) + else: + raise NotImplementedError + + self.out_indices = (0, 1, 2, 3) + model_name_lower = model_name.lower() + if 'convnext_' in model_name_lower: + model_type = 'convnext' + if '_base' in model_name_lower: + output_channels = [128, 256, 512, 1024] + feat_size = 0 + elif '_large' in model_name_lower: + output_channels = [192, 384, 768, 1536] + feat_size = 0 + elif '_xxlarge' in model_name_lower: + output_channels = [384, 768, 1536, 3072] + feat_size = 0 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif 'rn' in model_name_lower: + model_type = 'resnet' + if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']: + output_channels = [256, 512, 1024, 2048] + feat_size = 7 + elif model_name_lower == 'rn50x4': + output_channels = [320, 640, 1280, 2560] + feat_size = 9 + elif model_name_lower == 'rn50x16': + output_channels = [384, 768, 1536, 3072] + feat_size = 12 + elif model_name_lower == 'rn50x64': + output_channels = [512, 1024, 2048, 4096] + feat_size = 14 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif "vit" in model_name_lower: + model_type = 'vit' + if model_name_lower == 'vit-l-14': + output_channels = [1024, 1024, 1024, 1024] + feat_size = 0 + assert not clip_model.visual.input_patchnorm + assert clip_model.visual.attn_pool is None + elif model_name_lower == 'vit-b-32': + output_channels = [768, 768, 768, 768] + feat_size = 0 + assert not clip_model.visual.input_patchnorm + assert clip_model.visual.attn_pool is None + else: + raise NotImplementedError(f"{model_name} not supported yet.") + else: + raise NotImplementedError(f"{model_name} not supported yet.") + + self.model_name = model_name + self.fix = fix + self.model_type = model_type + self.output_channels = output_channels + self.feat_size = feat_size + + self.config = Data + # self.config.hidden_size = output_channels[-2] + self.config.hidden_size = output_channels[-2] + output_channels[-1] + + # Get the visual model + if self.model_type == 'resnet': + self.stem = nn.Sequential(*[ + clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1, + clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2, + clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3, + ]) + elif self.model_type == 'convnext': + self.stem = clip_model.visual.trunk.stem + elif self.model_type == 'vit': + self.stem = clip_model.visual.conv1 + else: + raise ValueError + + if self.model_type == 'resnet': + self.avgpool = 
clip_model.visual.avgpool + elif self.model_type == 'convnext': + self.avgpool = nn.Identity() + elif self.model_type == 'vit': + self.avgpool = flatten_permute + else: + raise ValueError + + self.res_layers = [] + if self.model_type in ['vit']: + self.t_class_embedding = clip_model.visual.class_embedding + self.t_positional_embedding = clip_model.visual.positional_embedding + self.t_ln_pre_trans = clip_model.visual.ln_pre + self.t_transformer = clip_model.visual.transformer + else: + for i in range(self.STAGES): + if self.model_type == 'resnet': + layer_name = f'layer{i + 1}' + layer = getattr(clip_model.visual, layer_name) + elif self.model_type == 'convnext': + layer_name = f'layer{i + 1}' + layer = clip_model.visual.trunk.stages[i] + else: + raise ValueError + self.add_module(layer_name, layer) + self.res_layers.append(layer_name) + + if self.model_type == 'resnet': + self.norm_pre = nn.Identity() + elif self.model_type == 'convnext': + self.norm_pre = clip_model.visual.trunk.norm_pre + elif self.model_type == 'vit': + self.norm_pre = nn.Identity() + + if self.model_type == 'resnet': + self.head = clip_model.visual.attnpool + elif self.model_type == 'convnext': + self.head = nn.Sequential(*[ + clip_model.visual.trunk.head, + clip_model.visual.head, + ]) + elif self.model_type == 'vit': + self.head = clip_model.visual.ln_post + self.proj = clip_model.visual.proj + + if self.init_cfg['type'] == 'Pretrained': + checkpoint_path = pretrained + state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix']) + self.load_state_dict(state_dict, strict=True) + + self.fix_layers = fix_layers + + if not self.fix: + self.train() + for name, param in self.norm_pre.named_parameters(): + param.requires_grad = False + for name, param in self.head.named_parameters(): + param.requires_grad = False + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + for name, param in res_layer.named_parameters(): + param.requires_grad = False + if i == 0: + for name, param in self.stem.named_parameters(): + param.requires_grad = False + + if self.fix: + self.train(mode=False) + for name, param in self.named_parameters(): + param.requires_grad = False + + self.dtype = dtype + self.backbone_type = None + + self.enable_output_gradient = False + + def init_weights(self): + self.logger.info(f"Init Config for {self.model_name}") + self.logger.info(self.init_cfg) + + def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module: + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + if self.fix: + super().train(mode=False) + else: + super().train(mode=mode) + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + res_layer.train(mode=False) + if i == 0: + self.stem.train(mode=False) + return self + + def forward_func(self, x): + x = self.stem(x) + h, w = x.shape[-2:] + x = self.avgpool(x) + outs = [] + if self.model_type == 'vit': + x = torch.cat( + [self.t_class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1 + ) # shape = [*, grid ** 2 + 1, width] + new_pos_embed = resample_abs_pos_embed( + self.t_positional_embedding[None], + [h, w], + num_prefix_tokens=1 + ) + x = x + new_pos_embed.to(x.dtype) + x = self.t_ln_pre_trans(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.t_transformer(x) + 
x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:] + x = x.permute(0, 2, 1).unflatten(2, (h, w)) # BCHW + for i in range(self.STAGES): + outs.append( + F.interpolate( + x, + scale_factor=2 ** (2 - i), + mode='bilinear', + align_corners=False + ) + ) + else: + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x).contiguous() + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def get_clip_feature(self, backbone_feat): + if self.model_type == 'resnet': + return backbone_feat + elif self.model_type == 'convnext': + return self.norm_pre(backbone_feat) + elif self.model_type == 'vit': + return backbone_feat + raise NotImplementedError + + def forward_feat(self, features): + if self.model_type == 'convnext': + batch, num_query, channel = features.shape + features = features.reshape(batch * num_query, channel, 1, 1) + features = self.head(features) + return features.view(batch, num_query, features.shape[-1]) + elif self.model_type == 'resnet': + num_query, channel, seven, seven = features.shape + features = self.head(features) + return features + elif self.model_type == 'vit': + return (self.head(features) @ self.proj)[:, 0] # should return n x c + + def forward(self, x, output_hidden_states=True): + if self.backbone_type is None: + self.backbone_type = [p.dtype for p in self.parameters()][0] + x = x.to(self.backbone_type) + if self.fix: + with torch.no_grad(): + outs = self.forward_func(x) + else: + outs = self.forward_func(x) + + # print([item.shape for item in outs]) + # outs = outs[-2].flatten(2).permute(0, 2, 1) + + # ms + second_outs = outs[-2] + second_shape = second_outs.shape[2:] + last_outs = outs[-1] + last_outs = F.interpolate( + last_outs, + size=second_shape, + mode='bilinear', + align_corners=False + ) + outs = torch.cat([second_outs, last_outs], dim=1).flatten(2).permute(0, 2, 1) + + outs = self.set_output_gradient(outs) + images_feat = torch.cat([outs[:, :1, :], outs], dim=1) + hidden_states = [images_feat, images_feat] + output = Output(hidden_states=hidden_states) + return output + + def enable_input_require_grads(self): + self.enable_output_gradient = True + return + + def set_output_gradient(self, output): + output.requires_grad_(self.enable_output_gradient) + return output + + def requires_grad_(self, state): + if state: + print("Not Frozen the Visual Encoder !") + else: + print("Frozen the Visual Encoder !") + for p in self.parameters(): + p.requires_grad_(state) + return + + # def forward(self, x): + # if self.fix: + # with torch.no_grad(): + # outs = self.forward_func(x) + # else: + # outs = self.forward_func(x) + # return outs + + def get_text_model(self): + return OpenCLIPBackboneText( + self.model_name, + init_cfg=self.init_cfg + ) + +@MODELS.register_module() +class OpenCLIPBackbone_omgseg(BaseModule): + """OpenCLIPBackbone, + Please refer to: + https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface + for the supported models and checkpoints. + """ + STAGES = 4 + + def __init__( + self, + img_size: int = 1024, + model_name: str = '', + fix: bool = True, + fix_layers: Optional[List] = None, + init_cfg=None, + dtype=torch.float16, + **kwargs, + ): + assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \ + f"{init_cfg['type']} is not supported." 
+ pretrained = init_cfg['checkpoint'] + super().__init__(init_cfg=None) + self.init_cfg = init_cfg + self.logger = MMLogger.get_current_instance() + rank, world_size = get_dist_info() + + if world_size > 1: + if rank == 0: + if init_cfg['type'] == 'clip_pretrain': + _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + + else: + pass + dist.barrier() + + # Get the clip model + if init_cfg['type'] == 'clip_pretrain': + clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + elif init_cfg['type'] == 'Pretrained': + clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger) + else: + raise NotImplementedError + + self.out_indices = (0, 1, 2, 3) + model_name_lower = model_name.lower() + if 'convnext_' in model_name_lower: + model_type = 'convnext' + if '_base' in model_name_lower: + output_channels = [128, 256, 512, 1024] + feat_size = 0 + elif '_large' in model_name_lower: + output_channels = [192, 384, 768, 1536] + feat_size = 0 + elif '_xxlarge' in model_name_lower: + output_channels = [384, 768, 1536, 3072] + feat_size = 0 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif 'rn' in model_name_lower: + model_type = 'resnet' + if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']: + output_channels = [256, 512, 1024, 2048] + feat_size = 7 + elif model_name_lower == 'rn50x4': + output_channels = [320, 640, 1280, 2560] + feat_size = 9 + elif model_name_lower == 'rn50x16': + output_channels = [384, 768, 1536, 3072] + feat_size = 12 + elif model_name_lower == 'rn50x64': + output_channels = [512, 1024, 2048, 4096] + feat_size = 14 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif "vit" in model_name_lower: + model_type = 'vit' + if model_name_lower == 'vit-l-14': + output_channels = [1024, 1024, 1024, 1024] + feat_size = 0 + assert not clip_model.visual.input_patchnorm + assert clip_model.visual.attn_pool is None + elif model_name_lower == 'vit-b-32': + output_channels = [768, 768, 768, 768] + feat_size = 0 + assert not clip_model.visual.input_patchnorm + assert clip_model.visual.attn_pool is None + else: + raise NotImplementedError(f"{model_name} not supported yet.") + else: + raise NotImplementedError(f"{model_name} not supported yet.") + + self.model_name = model_name + self.fix = fix + self.model_type = model_type + self.output_channels = output_channels + self.feat_size = feat_size + + self.config = Data + # self.config.hidden_size = output_channels[-2] + self.config.hidden_size = output_channels[-1] + output_channels[-2] + + # Get the visual model + if self.model_type == 'resnet': + self.stem = nn.Sequential(*[ + clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1, + clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2, + clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3, + ]) + elif self.model_type == 'convnext': + self.stem = clip_model.visual.trunk.stem + elif self.model_type == 'vit': + self.stem = clip_model.visual.conv1 + else: + raise ValueError + + if self.model_type == 'resnet': + self.avgpool = 
clip_model.visual.avgpool + elif self.model_type == 'convnext': + self.avgpool = nn.Identity() + elif self.model_type == 'vit': + self.avgpool = flatten_permute + else: + raise ValueError + + self.res_layers = [] + if self.model_type in ['vit']: + self.t_class_embedding = clip_model.visual.class_embedding + self.t_positional_embedding = clip_model.visual.positional_embedding + self.t_ln_pre_trans = clip_model.visual.ln_pre + self.t_transformer = clip_model.visual.transformer + else: + for i in range(self.STAGES): + if self.model_type == 'resnet': + layer_name = f'layer{i + 1}' + layer = getattr(clip_model.visual, layer_name) + elif self.model_type == 'convnext': + layer_name = f'layer{i + 1}' + layer = clip_model.visual.trunk.stages[i] + else: + raise ValueError + self.add_module(layer_name, layer) + self.res_layers.append(layer_name) + + if self.model_type == 'resnet': + self.norm_pre = nn.Identity() + elif self.model_type == 'convnext': + self.norm_pre = clip_model.visual.trunk.norm_pre + elif self.model_type == 'vit': + self.norm_pre = nn.Identity() + + if self.model_type == 'resnet': + self.head = clip_model.visual.attnpool + elif self.model_type == 'convnext': + self.head = nn.Sequential(*[ + clip_model.visual.trunk.head, + clip_model.visual.head, + ]) + elif self.model_type == 'vit': + self.head = clip_model.visual.ln_post + self.proj = clip_model.visual.proj + + if self.init_cfg['type'] == 'Pretrained': + checkpoint_path = pretrained + state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix']) + self.load_state_dict(state_dict, strict=True) + + self.fix_layers = fix_layers + + if not self.fix: + self.train() + for name, param in self.norm_pre.named_parameters(): + param.requires_grad = False + for name, param in self.head.named_parameters(): + param.requires_grad = False + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + for name, param in res_layer.named_parameters(): + param.requires_grad = False + if i == 0: + for name, param in self.stem.named_parameters(): + param.requires_grad = False + + if self.fix: + self.train(mode=False) + for name, param in self.named_parameters(): + param.requires_grad = False + + self.dtype = dtype + self.backbone_type = None + + self.enable_output_gradient = False + + def init_weights(self): + self.logger.info(f"Init Config for {self.model_name}") + self.logger.info(self.init_cfg) + + def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module: + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + if self.fix: + super().train(mode=False) + else: + super().train(mode=mode) + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + res_layer.train(mode=False) + if i == 0: + self.stem.train(mode=False) + return self + + def forward_func(self, x): + x = self.stem(x) + h, w = x.shape[-2:] + x = self.avgpool(x) + outs = [] + if self.model_type == 'vit': + x = torch.cat( + [self.t_class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1 + ) # shape = [*, grid ** 2 + 1, width] + new_pos_embed = resample_abs_pos_embed( + self.t_positional_embedding[None], + [h, w], + num_prefix_tokens=1 + ) + x = x + new_pos_embed.to(x.dtype) + x = self.t_ln_pre_trans(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.t_transformer(x) + 
x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:] + x = x.permute(0, 2, 1).unflatten(2, (h, w)) # BCHW + for i in range(self.STAGES): + outs.append( + F.interpolate( + x, + scale_factor=2 ** (2 - i), + mode='bilinear', + align_corners=False + ) + ) + else: + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x).contiguous() + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def get_clip_feature(self, backbone_feat): + if self.model_type == 'resnet': + return backbone_feat + elif self.model_type == 'convnext': + return self.norm_pre(backbone_feat) + elif self.model_type == 'vit': + return backbone_feat + raise NotImplementedError + + def forward_feat(self, features): + if self.model_type == 'convnext': + batch, num_query, channel = features.shape + features = features.reshape(batch * num_query, channel, 1, 1) + features = self.head(features) + return features.view(batch, num_query, features.shape[-1]) + elif self.model_type == 'resnet': + num_query, channel, seven, seven = features.shape + features = self.head(features) + return features + elif self.model_type == 'vit': + return (self.head(features) @ self.proj)[:, 0] # should return n x c + + def forward(self, x, output_hidden_states=True): + if self.backbone_type is None: + self.backbone_type = [p.dtype for p in self.parameters()][0] + x = x.to(self.backbone_type) + if self.fix: + with torch.no_grad(): + outs = self.forward_func(x) + else: + outs = self.forward_func(x) + + return outs + + def get_text_model(self): + return OpenCLIPBackboneText( + self.model_name, + init_cfg=self.init_cfg + ) + +@MODELS.register_module() +class OpenCLIPBackboneText(BaseModule): + def __init__( + self, + model_name: str = '', + init_cfg=None, + ): + assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported." 
+ pretrained = init_cfg['checkpoint'] + super().__init__(init_cfg=None) + self.init_cfg = init_cfg + self.logger = MMLogger.get_current_instance() + rank, world_size = get_dist_info() + + if world_size > 1: + if rank == 0: + _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False, + logger=self.logger) + else: + pass + dist.barrier() + + # Get the clip model + clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False, + logger=self.logger) + + # Get the textual model + self.text_tokenizer = open_clip.get_tokenizer(model_name) + self.text_transformer = clip_model.transformer + self.text_token_embedding = clip_model.token_embedding + self.text_pe = clip_model.positional_embedding + self.text_ln_final = clip_model.ln_final + self.text_proj = clip_model.text_projection + + self.register_buffer('text_attn_mask', clip_model.attn_mask) + + self.param_dtype = torch.float32 + self.model_name = model_name + + def init_weights(self): + self.logger.info(f"Init Config for {self.model_name}") + self.logger.info(self.init_cfg) + + # Copied from + # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343 + @torch.no_grad() + def forward(self, text): + text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device) + x = self.text_token_embedding(text_tokens).to(self.param_dtype) + x = x + self.text_pe.to(self.param_dtype) + x = x.permute(1, 0, 2) + x = self.text_transformer(x, attn_mask=self.text_attn_mask) + x = x.permute(1, 0, 2) + x = self.text_ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj + return x diff --git a/omg_llava/model/modules/__init__.py b/omg_llava/model/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e16337dd947c9cab1d884b5785f220a1a65e6055 --- /dev/null +++ b/omg_llava/model/modules/__init__.py @@ -0,0 +1,4 @@ +from xtuner.model import * +from .projector import ProjectorModel_OMG_LLaVA, ProjectorConfig_OMG_LLaVA + +__all__ = ['ProjectorConfig_OMG_LLaVA', 'ProjectorModel_OMG_LLaVA', ] diff --git a/omg_llava/model/modules/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab29ae30c474212cffb9946bc02cb1e359c3dcb9 Binary files /dev/null and b/omg_llava/model/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/modules/projector/__init__.py b/omg_llava/model/modules/projector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86b24864d2c6c4dc48b5a9319417f091d648b9e4 --- /dev/null +++ b/omg_llava/model/modules/projector/__init__.py @@ -0,0 +1,9 @@ +from xtuner.model.modules.projector import * +from transformers import AutoConfig, AutoModel +from .configuration_projector import ProjectorConfig_OMG_LLaVA +from .modeling_projector import ProjectorModel_OMG_LLaVA + +AutoConfig.register('projector', ProjectorConfig_OMG_LLaVA) +AutoModel.register(ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA) + +__all__ = ['ProjectorConfig_OMG_LLaVA', 'ProjectorModel_OMG_LLaVA'] diff --git a/omg_llava/model/modules/projector/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/modules/projector/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1da382abcb0ec0277db2476d902fd74e5455552c Binary files /dev/null and b/omg_llava/model/modules/projector/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/modules/projector/__pycache__/configuration_projector.cpython-310.pyc b/omg_llava/model/modules/projector/__pycache__/configuration_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8fe2a061299b4990e4b08271712e89a4758f1f Binary files /dev/null and b/omg_llava/model/modules/projector/__pycache__/configuration_projector.cpython-310.pyc differ diff --git a/omg_llava/model/modules/projector/__pycache__/modeling_projector.cpython-310.pyc b/omg_llava/model/modules/projector/__pycache__/modeling_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3afb207597fbe2546cfb2f4247d2272f34b721 Binary files /dev/null and b/omg_llava/model/modules/projector/__pycache__/modeling_projector.cpython-310.pyc differ diff --git a/omg_llava/model/modules/projector/__pycache__/modeling_projector_seperate.cpython-310.pyc b/omg_llava/model/modules/projector/__pycache__/modeling_projector_seperate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a35b9b9d479d1a0deef78e2fceb5405a5716f148 Binary files /dev/null and b/omg_llava/model/modules/projector/__pycache__/modeling_projector_seperate.cpython-310.pyc differ diff --git a/omg_llava/model/modules/projector/configuration_projector.py b/omg_llava/model/modules/projector/configuration_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f13aee6ba2b5649899782bbb6f335be75d91a0 --- /dev/null +++ b/omg_llava/model/modules/projector/configuration_projector.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from transformers import PretrainedConfig + +class ProjectorConfig_OMG_LLaVA(PretrainedConfig): + model_type = 'projector' + _auto_class = 'AutoConfig' + + def __init__( + self, + visual_hidden_size=4096, + llm_hidden_size=4096, + depth=2, + hidden_act='gelu', + bias=True, + query_channels=256, + feat_channels=1536, + pixel_shuffle_ratio=None, + additional_bg_tokens=10, + visual_prompt_proj=False, + add_cross_attn_layer=False, + **kwargs, + ): + self.visual_hidden_size = visual_hidden_size + self.llm_hidden_size = llm_hidden_size + self.depth = depth + self.hidden_act = hidden_act + self.bias = bias + self.query_channels=query_channels + self.feat_channels=feat_channels + if pixel_shuffle_ratio is not None: + self.feat_channels = self.feat_channels * pixel_shuffle_ratio * pixel_shuffle_ratio + self.additional_bg_tokens = additional_bg_tokens + self.visual_prompt_proj = visual_prompt_proj + self.add_cross_attn_layer = add_cross_attn_layer + super().__init__(**kwargs) \ No newline at end of file diff --git a/omg_llava/model/modules/projector/modeling_projector.py b/omg_llava/model/modules/projector/modeling_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..980d64129290e4ea4d0adf3cd3ebf8cab341f40b --- /dev/null +++ b/omg_llava/model/modules/projector/modeling_projector.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
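+#
+# Module overview: `Naive_Proj` consumes the tuple (clip_feature, query_feat,
+# attention_mask) coming from the visual encoder. The object queries are first
+# projected to the CLIP feature width, then both streams are mapped to the LLM
+# hidden size by small MLPs; for each image the projector returns
+# torch.cat([pixel tokens, a learned separator embedding, the valid
+# (non-padded) query tokens], dim=0), unless `rm_query` is set.
+# `ProjectorModel_OMG_LLaVA` wraps this in a `PreTrainedModel` so it supports
+# gradient checkpointing and `enable_input_require_grads`. The attention and
+# FFN layers defined below are standard pre-/post-norm transformer sublayers.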
+import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from typing import Optional +from torch import Tensor +from torch.nn import functional as F + +from .configuration_projector import ProjectorConfig_OMG_LLaVA + +class Naive_Proj(nn.Module): + def __init__(self, config, rm_prior_embedding=False, rm_query=False): + super().__init__() + query_channels = config.query_channels + self.query_channels = query_channels + feat_channels = config.feat_channels + if isinstance(query_channels, tuple): + query_channels = query_channels[0] + if isinstance(feat_channels, tuple): + feat_channels = feat_channels[0] + + query_channels = query_channels * 2 # feat + embed + + self.query_proj = nn.Linear(query_channels, feat_channels) + + modules = [ + nn.Linear( + feat_channels, + config.llm_hidden_size, + bias=config.bias) + ] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.model = nn.Sequential(*modules) + + modules = [ + nn.Linear( + feat_channels + query_channels, + config.llm_hidden_size, + bias=config.bias) + ] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.model_feat = nn.Sequential(*modules) + + self.seperate_embed = nn.Embedding(1, config.llm_hidden_size) + + self.rm_prior_embedding = rm_prior_embedding + self.rm_query = rm_query + + def forward(self, x): + clip_feature, query_feat, attention_mask = x + # clip feature (bs, hw, c + 2 * q_c) + # query_feat (bs, q, c) + # attention_mask (bs, q, hw) + + if self.rm_prior_embedding: + clip_feature_feat = clip_feature[:, :, :-512] + clip_feature_query = clip_feature[:, :, -512:] * 0.0 + clip_feature = torch.cat([clip_feature_feat, clip_feature_query], dim=-1) + + query_feat = self.query_proj(query_feat) + + valid_mask = attention_mask.sum(dim=-1) < attention_mask.shape[-1] # (bs, q) + # valid_mask # (bs, q) + # query_feat (bs, q, c) + # clip_feature (bs, hw, c) + # attn_map (bs, q, hw) + bs, n_q = query_feat.shape[:2] + + layer_outputs = self.model(query_feat) + + # filter + clip_feature_out = clip_feature + clip_feature_out = self.model_feat(clip_feature_out) + ret = [] + + valid_queries_embeddings = [] + for layer_output, keep in zip(layer_outputs, valid_mask): + valid_queries_embeddings.append(layer_output[keep]) + self.valid_queries_embeddings = valid_queries_embeddings + + for clip_feat, layer_output, keep in zip(clip_feature_out, layer_outputs, valid_mask): + if self.rm_query: + ret.append(clip_feat + torch.mean(self.seperate_embed.weight) * 0.0 + torch.mean(layer_output[keep]) * 0.0) + else: + ret.append(torch.cat([clip_feat, self.seperate_embed.weight, layer_output[keep]], dim=0)) + return ret + +class ProjectorModel_OMG_LLaVA(PreTrainedModel): + _auto_class = 'AutoModel' + config_class = ProjectorConfig_OMG_LLaVA + base_model_prefix = 'model' + supports_gradient_checkpointing = True + + def __init__(self, config: ProjectorConfig_OMG_LLaVA) -> None: + super().__init__(config) + self.gradient_checkpointing = False + + self.rm_prior_embedding = False + self.rm_query = False + self.model = Naive_Proj(config, ) + + def enable_input_require_grads(self): + + def make_inputs_require_grad(module, input, output): + if isinstance(output, torch.Tensor): + output.requires_grad_(True) + else: + for item in 
output: + item.requires_grad_(True) + + self.model.register_forward_hook(make_inputs_require_grad) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ProjectorConfig_OMG_LLaVA): + module.gradient_checkpointing = value + + def forward(self, x): + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x) + else: + layer_outputs = self.model(x) + return layer_outputs + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = 
self.norm(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/omg_llava/model/modules/projector/modeling_projector_seperate.py b/omg_llava/model/modules/projector/modeling_projector_seperate.py new file mode 100644 index 0000000000000000000000000000000000000000..b20f0f47689687c8299aee8bd454ba971f31a1db --- /dev/null +++ b/omg_llava/model/modules/projector/modeling_projector_seperate.py @@ -0,0 +1,427 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
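+#
+# Variant of modeling_projector.py carrying the extra options used in the
+# projector ablations: the same `Naive_Proj` pipeline, plus an optional
+# cross-attention refinement (`add_cross_attn_layer`: query tokens attend to
+# the projected CLIP features through CrossAttentionLayer + FFNLayer) and an
+# optional dedicated branch for visual-prompt embeddings
+# (`visual_prompt_proj`). The prompt branch can be initialised from the query
+# branch via `init_visual_prompt_weights()`, and
+# `forward_visual_prompts_embeddings()` projects visual-prompt queries with
+# whichever branch is enabled.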
+import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from typing import Optional +from torch import Tensor +from torch.nn import functional as F + +from .configuration_projector import ProjectorConfig_OMG_LLaVA + +class Naive_Proj(nn.Module): + def __init__(self, config, rm_prior_embedding=False, + rm_query=False): + super().__init__() + query_channels = config.query_channels + self.query_channels = query_channels + feat_channels = config.feat_channels + if isinstance(query_channels, tuple): + query_channels = query_channels[0] + if isinstance(feat_channels, tuple): + feat_channels = feat_channels[0] + + add_cross_attn_layer = config.add_cross_attn_layer + self.add_cross_attn_layer = config.add_cross_attn_layer + + query_channels = query_channels * 2 # feat + embed + + self.query_proj = nn.Linear(query_channels, feat_channels) + + modules = [ + nn.Linear( + feat_channels, + config.llm_hidden_size, + bias=config.bias) + ] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.model = nn.Sequential(*modules) + + if add_cross_attn_layer: + print("Using Cross Attention Layer at Projector !!!") + self.query_cross_attn = CrossAttentionLayer( + d_model=config.llm_hidden_size, + nhead=32, + ) + self.query_ffn = FFNLayer( + d_model=config.llm_hidden_size, + dim_feedforward=4096, + ) + else: + self.query_cross_attn = None + self.query_ffn = None + + modules = [ + nn.Linear( + feat_channels + query_channels, + config.llm_hidden_size, + bias=config.bias) + ] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.model_feat = nn.Sequential(*modules) + + self.seperate_embed = nn.Embedding(1, config.llm_hidden_size) + + self.rm_prior_embedding = rm_prior_embedding + self.rm_query = rm_query + + visual_prompt_proj = config.visual_prompt_proj + self.visual_prompt_proj = visual_prompt_proj + if not visual_prompt_proj: + self.visual_prompt_query_proj = None + self.visual_prompt_query_model = None + self.visual_prompt_query_cross_attn = None + self.visual_prompt_query_ffn = None + else: + print("Initialized all Layers for Visual Prompt in Projector !!!") + self.visual_prompt_query_proj = nn.Linear(query_channels, feat_channels) + modules = [ + nn.Linear( + feat_channels, + config.llm_hidden_size, + bias=config.bias) + ] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.visual_prompt_query_model = nn.Sequential(*modules) + + if add_cross_attn_layer: + self.visual_prompt_query_cross_attn = CrossAttentionLayer( + d_model=config.llm_hidden_size, + nhead=32, + ) + self.visual_prompt_query_ffn = FFNLayer( + d_model=config.llm_hidden_size, + dim_feedforward=4096, + ) + else: + self.visual_prompt_query_cross_attn = None + self.visual_prompt_query_ffn = None + + def forward(self, x): + clip_feature, query_feat, attention_mask = x + query_feat_copy = query_feat[0, :1] # (1, 1, c) + # clip feature (bs, hw, c + 2 * q_c) + # query_feat (bs, q, c) + # attention_mask (bs, q, hw) + + if self.rm_prior_embedding: + clip_feature_feat = clip_feature[:, :, :-512] + clip_feature_query = clip_feature[:, :, -512:] * 0.0 + clip_feature = 
torch.cat([clip_feature_feat, clip_feature_query], dim=-1) + + query_feat = self.query_proj(query_feat) + + valid_mask = attention_mask.sum(dim=-1) < attention_mask.shape[-1] # (bs, q) + # valid_mask # (bs, q) + # query_feat (bs, q, c) + # clip_feature (bs, hw, c) + # attn_map (bs, q, hw) + bs, n_q = query_feat.shape[:2] + + layer_outputs = self.model(query_feat) + + # filter + clip_feature_out = clip_feature + clip_feature_out = self.model_feat(clip_feature_out) + ret = [] + + valid_queries_embeddings = [] + for layer_output, keep in zip(layer_outputs, valid_mask): + valid_queries_embeddings.append(layer_output[keep]) + self.valid_queries_embeddings = valid_queries_embeddings + + self.last_clip_feature = clip_feature_out + + for clip_feat, layer_output, keep in zip(clip_feature_out, layer_outputs, valid_mask): + valid_layer_output = layer_output[keep] + if self.add_cross_attn_layer: + valid_layer_output = self.query_cross_attn( + valid_layer_output.unsqueeze(1), clip_feat.unsqueeze(1), + )[:, 0] + valid_layer_output = self.query_ffn(valid_layer_output) + if self.rm_query: + ret.append(clip_feat + torch.mean(self.seperate_embed.weight) * 0.0 + torch.mean(valid_layer_output) * 0.0) + else: + ret.append(torch.cat([clip_feat, self.seperate_embed.weight, valid_layer_output], dim=0)) + + # generate zero using visual prompt projector if valid + if self.visual_prompt_proj: + visual_prompt_embeddings = query_feat_copy.to(self.visual_prompt_query_proj.weight.dtype) + visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings) + visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings) # (B, C) + if self.add_cross_attn_layer: + clip_feat = self.last_clip_feature[0] # (B, HW, C) + visual_prompt_embeddings = self.visual_prompt_query_cross_attn( + visual_prompt_embeddings.unsqueeze(1), clip_feat.unsqueeze(1), + )[:, 0] + visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings) + self.visual_prompt_zero = visual_prompt_embeddings.sum() * 0.0 + else: + self.visual_prompt_zero = 0.0 + return ret + + def forward_visual_prompts_embeddings(self, visual_prompt_embeddings, batch_idxs): + if self.visual_prompt_proj: + visual_prompt_embeddings = visual_prompt_embeddings.to(self.visual_prompt_query_proj.weight.dtype) + visual_prompt_embeddings = self.visual_prompt_query_proj(visual_prompt_embeddings) + visual_prompt_embeddings = self.visual_prompt_query_model(visual_prompt_embeddings) # (B, C) + if self.add_cross_attn_layer: + clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2) # (B, HW, C) + visual_prompt_embeddings = self.visual_prompt_query_cross_attn( + visual_prompt_embeddings.unsqueeze(0), clip_feat, + )[0, :] + visual_prompt_embeddings = self.visual_prompt_query_ffn(visual_prompt_embeddings) + else: + visual_prompt_embeddings = visual_prompt_embeddings.to(self.query_proj.weight.dtype) + visual_prompt_embeddings = self.query_proj(visual_prompt_embeddings) + visual_prompt_embeddings = self.model(visual_prompt_embeddings) # (B, C) + if self.add_cross_attn_layer: + clip_feat = self.last_clip_feature[batch_idxs].permute(1, 0, 2) # (B, HW, C) + visual_prompt_embeddings = self.query_cross_attn( + visual_prompt_embeddings.unsqueeze(0), clip_feat, + )[0, :] + visual_prompt_embeddings = self.query_ffn(visual_prompt_embeddings) + return visual_prompt_embeddings + + def init_visual_prompt_weights(self): + if self.visual_prompt_query_proj is not None: + self.visual_prompt_query_proj.load_state_dict(self.query_proj.state_dict()) + if 
self.visual_prompt_query_model is not None: + self.visual_prompt_query_model.load_state_dict(self.model.state_dict()) + if self.visual_prompt_query_cross_attn is not None: + self.visual_prompt_query_cross_attn.load_state_dict( + self.query_cross_attn.state_dict() + ) + if self.visual_prompt_query_ffn is not None: + self.visual_prompt_query_ffn.load_state_dict( + self.query_ffn.state_dict() + ) + return + +class ProjectorModel_OMG_LLaVA(PreTrainedModel): + _auto_class = 'AutoModel' + config_class = ProjectorConfig_OMG_LLaVA + base_model_prefix = 'model' + supports_gradient_checkpointing = True + + def __init__(self, config: ProjectorConfig_OMG_LLaVA) -> None: + super().__init__(config) + self.gradient_checkpointing = False + + self.rm_prior_embedding = False + self.rm_query = False + self.model = Naive_Proj(config, ) + + def enable_input_require_grads(self): + + def make_inputs_require_grad(module, input, output): + if isinstance(output, torch.Tensor): + output.requires_grad_(True) + else: + for item in output: + item.requires_grad_(True) + + self.model.register_forward_hook(make_inputs_require_grad) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ProjectorConfig_OMG_LLaVA): + module.gradient_checkpointing = value + + def forward(self, x): + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x) + else: + layer_outputs = self.model(x) + return layer_outputs + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + 
self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/omg_llava/model/omg_llava.py b/omg_llava/model/omg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..121f327ba175944b47453f69985d1abb4fd59861 --- /dev/null +++ b/omg_llava/model/omg_llava.py @@ -0,0 +1,783 @@ +from collections import OrderedDict +import torch +import torch.nn as nn +from 
mmengine.config import Config, ConfigDict +from mmengine.model import BaseModel +from peft import get_peft_model, prepare_model_for_kbit_training + +from xtuner.registry import BUILDER +# from .modules import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA +from .modules.projector.modeling_projector_seperate import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA +from xtuner.model.modules import ProjectorModel, ProjectorConfig +from xtuner.model.modules import dispatch_modules +from .utils import (LoadWoInit, find_all_linear_names, + get_peft_model_state_dict, guess_load_checkpoint, + make_inputs_require_grad, + traverse_dict, + prepare_inputs_labels_for_multimodal_with_visual_prompts) +from .convnext_clip import OpenCLIPBackbone +from .omg_seg import OMGSegVisualEncoder + +class OMG_LLaVA(BaseModel): + def __init__(self, + llm, + visual_encoder, + visual_select_layer=-2, + freeze_llm=False, + freeze_visual_encoder=False, + require_omg_decoder=False, + pretrained_pth=None, + llm_lora=None, + visual_encoder_lora=None, + use_activation_checkpointing=True, + projector_depth=2, + text2vision_projector=False, + tokenizer=None, + keep_omg_decoder_frozen=False, + add_seg_pretrain=False, + additional_cross_attn_layers=False, + pixel_shuffle_ratio=None, + train_vocabulary=False, + freeze_llm_with_lora=False, + freeze_visual_projector=False, + rm_prior_embedding=False, + rm_query=False, + clip_feat_channel=1536, + # for [SEG] + using_multilayer_states=False, + seg_token_merge_type='mean', + selected_layers=32, + # for proj ablation + visual_prompt_proj=False, + add_cross_attn_layer=False, + ): + super().__init__() + + self.freeze_llm_with_lora = freeze_llm_with_lora + self.freeze_visual_projector = freeze_visual_projector + + self.freeze_llm = freeze_llm + self.freeze_visual_encoder = freeze_visual_encoder + with LoadWoInit(): + self.llm = self._build_from_cfg_or_module(llm) + if visual_encoder.type == OpenCLIPBackbone or visual_encoder.type == OMGSegVisualEncoder: + self.visual_encoder = visual_encoder.type(**visual_encoder) + else: + self.visual_encoder = self._build_from_cfg_or_module( + visual_encoder) + self.llm.config.use_cache = False + dispatch_modules(self.llm) + + projector_config = ProjectorConfig_OMG_LLaVA( + query_channels=256, + feat_channels=clip_feat_channel, + llm_hidden_size=self.llm.config.hidden_size, + depth=projector_depth, + pixel_shuffle_ratio=pixel_shuffle_ratio, + visual_prompt_proj=visual_prompt_proj, + add_cross_attn_layer=add_cross_attn_layer, + ) + self.projector = ProjectorModel_OMG_LLaVA(projector_config).to( + self.visual_encoder.dtype) + + self.text2vision_projector = text2vision_projector + if text2vision_projector: + projector_config = ProjectorConfig( + visual_hidden_size=self.llm.config.hidden_size, + llm_hidden_size=256 * 2, + depth=projector_depth) + self.projector_text2vision = ProjectorModel(projector_config).to( + self.visual_encoder.dtype) + + + if rm_query: + self.projector.model.rm_query = rm_query + if rm_prior_embedding: + self.projector.model.rm_prior_embedding = rm_prior_embedding + + if self.freeze_llm: + self.llm.requires_grad_(False) + if self.freeze_visual_encoder: + self.visual_encoder.requires_grad_(False) + + self.use_activation_checkpointing = use_activation_checkpointing + if use_activation_checkpointing: + # For backward compatibility + if hasattr(self.llm, 'enable_input_require_grads'): + self.llm.enable_input_require_grads() + else: + self.llm.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + if 
hasattr(self.visual_encoder, 'enable_input_require_grads'): + self.visual_encoder.enable_input_require_grads() + else: + self.visual_encoder.get_input_embeddings( + ).register_forward_hook(make_inputs_require_grad) + self.projector.enable_input_require_grads() + if text2vision_projector: + self.projector_text2vision.enable_input_require_grads() + + # enable gradient (activation) checkpointing for memory efficiency + self.gradient_checkpointing_enable() + + # resize input embed before add llm lora + self.added_special_token = False + if tokenizer is not None: + self.tokenizer = tokenizer + tokenizer_type = self.tokenizer['type'] + del self.tokenizer['type'] + self.tokenizer = tokenizer_type(**self.tokenizer) + self._add_special_tokens() + + self.use_llm_lora = llm_lora is not None + self.use_visual_encoder_lora = visual_encoder_lora is not None + + if self.use_llm_lora: + self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing) + if self.freeze_llm_with_lora: + for name, param in self.llm.named_parameters(): + param.requires_grad_(False) + else: + if train_vocabulary: + # train vocabulary embedding and logit head when pretrain + for name, param in self.named_parameters(): + if 'tok_' in name or 'lm_head' in name: + print("Unfrozen {} !!!".format(name)) + param.requires_grad_(True) + if 'output.' in name and 'llm' in name and 'lora' not in name: + print("Unfrozen {} !!!".format(name)) + param.requires_grad_(True) + + if self.use_visual_encoder_lora: + self._prepare_visual_encoder_for_lora( + visual_encoder_lora, use_activation_checkpointing) + + if pretrained_pth is not None: + pretrained_state_dict = guess_load_checkpoint(pretrained_pth) + self.load_state_dict(pretrained_state_dict, strict=False) + print(f'Load pretrained weight from {pretrained_pth}') + + if visual_prompt_proj: + print("Initialize the visual prompt projection weights with query projection weights !!! 
") + self.projector.model.init_visual_prompt_weights() + + self.visual_select_layer = visual_select_layer + + self._is_init = True + + self.require_omg_decoder = require_omg_decoder + if require_omg_decoder: + self.visual_encoder.init_new_decoder() + if keep_omg_decoder_frozen: + for name, param in self.visual_encoder.panoptic_head.transformer_decoder_llm.named_parameters(): + param.requires_grad_(False) + print("Frozen all the omg seg decoder !!!") + + self.additional_cross_attn_layers = additional_cross_attn_layers + if self.additional_cross_attn_layers: + self.visual_encoder.init_cross_attn_layer() + + if self.freeze_visual_projector: + for name, param in self.projector.named_parameters(): + param.requires_grad_(False) + + self.add_seg_pretrain = add_seg_pretrain + + if text2vision_projector is False: + using_multilayer_states = False + self.using_multilayer_states = using_multilayer_states + self.seg_token_merge_type = seg_token_merge_type + self.selected_layers = selected_layers + if self.using_multilayer_states: + assert self.seg_token_merge_type in ['mean', 'cat', 'linear_cat'] + if self.seg_token_merge_type == 'cat': + self.seg_token_proj_cat = nn.Linear( + self.llm.config.hidden_size * self.selected_layers, + self.llm.config.hidden_size + ) + elif self.seg_token_merge_type == 'linear_cat': + self.seg_token_proj_linear_cat = nn.ModuleList() + self.seg_token_proj_linear_cat.append( + nn.Linear( + self.llm.config.hidden_size, + 196, + ) + ) + self.seg_token_proj_linear_cat.append( + nn.Linear( + 196 * self.selected_layers, + self.llm.config.hidden_size, + ) + ) + + + def _add_special_tokens(self): + assert hasattr(self, "tokenizer") + + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['

<p>', '</p>']
+        # add for visual prompt
+        region_tokens = ['<region>']
+        point_tokens = ['<mark>']
+        special_tokens = segmentation_tokens + phrase_tokens + region_tokens
+        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
+
+        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
+        self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
+        self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
+        self.region_token_idx = self.tokenizer("<region>", add_special_tokens=False).input_ids[0]
+        # self.mark_token_idx = self.tokenizer("<mark>", add_special_tokens=False).input_ids[0]
+
+        self.llm.resize_token_embeddings(len(self.tokenizer))
+
+        self.tokenizer.add_tokens(point_tokens, special_tokens=True)
+        self.mark_token_idx = self.tokenizer("<mark>", add_special_tokens=False).input_ids[0]
+        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
+            self.llm.enable_input_require_grads()
+        self.added_special_token = True
+        print("[SEG]: {}, <p>: {}, </p>: {}, <region>: {}, <mark>: {}" \
+              .format(self.seg_token_idx, self.bop_token_idx,
+                      self.eop_token_idx, self.region_token_idx, self.mark_token_idx))
+        print('****************************Add special tokens ********************************************')
+        return
+
+    def _parse_lora_config(self, lora_config):
+        if isinstance(lora_config, dict) or isinstance(
+                lora_config, Config) or isinstance(lora_config, ConfigDict):
+            lora_config = BUILDER.build(lora_config)
+        return lora_config
+
+    def _prepare_llm_for_lora(self,
+                              lora_config,
+                              use_activation_checkpointing=True):
+        lora_config = self._parse_lora_config(lora_config)
+        self.llm = prepare_model_for_kbit_training(
+            self.llm, use_activation_checkpointing)
+        if lora_config.target_modules is None:
+            modules = find_all_linear_names(self.llm)
+            lora_config.target_modules = modules
+        self.llm = get_peft_model(self.llm, lora_config)
+        for name, param in self.named_parameters():
+            if 'tok_' in name or 'lm_head' in name:
+                print("Unfrozen {} !!!".format(name))
+                param.requires_grad_(True)
+            if 'output.' in name and 'llm' in name and 'lora' not in name:
+                print("Unfrozen {} !!!".format(name))
+                param.requires_grad_(True)
+
+    def _prepare_visual_encoder_for_lora(self,
+                                         lora_config,
+                                         use_activation_checkpointing=True):
+        lora_config = self._parse_lora_config(lora_config)
+        if lora_config.target_modules is None:
+            modules = find_all_linear_names(self.visual_encoder)
+            lora_config.target_modules = modules
+        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
+
+    def gradient_checkpointing_enable(self):
+        self.activation_checkpointing_enable()
+
+    def activation_checkpointing_enable(self):
+        self.llm.gradient_checkpointing_enable()
+        if hasattr(self.visual_encoder, 'gradient_checkpointing_enable'):
+            self.visual_encoder.gradient_checkpointing_enable()
+        elif hasattr(self.visual_encoder, 'clip_model'):
+            if self.visual_encoder.clip_model is not None:
+                self.visual_encoder.clip_model.gradient_checkpointing_enable()
+        if hasattr(self.projector, 'gradient_checkpointing_enable'):
+            self.projector.gradient_checkpointing_enable()
+        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_enable'):
+            self.projector_text2vision.gradient_checkpointing_enable()
+
+    def gradient_checkpointing_disable(self):
+        self.activation_checkpointing_disable()
+
+    def activation_checkpointing_disable(self):
+        self.llm.gradient_checkpointing_disable()
+        if hasattr(self.visual_encoder, 'gradient_checkpointing_disable'):
+            self.visual_encoder.gradient_checkpointing_disable()
+        if hasattr(self.projector, 'gradient_checkpointing_disable'):
+            self.projector.gradient_checkpointing_disable()
+        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_disable'):
+            self.projector_text2vision.gradient_checkpointing_disable()
+
+    def init_weights(self):
+        pass
+
+    def state_dict(self, *args, **kwargs):
+        state_dict = super().state_dict(*args, **kwargs)
+
+        to_return = OrderedDict()
+
+        # # vocabulary embedding
+        # to_return.update(
+        #     {'special_' + k: v for k, v in state_dict.items() if 'tok_embeddings' in k}
+        # )
+        # # logit head
+        # to_return.update(
+        #     {'special_' + k: v for k, v in state_dict.items() if 'output.' in k and 'llm' in k and 'lora' not in k}
+        # )
+
+        # vocabulary embedding
+        to_return.update(
+            {k: v for k, v in state_dict.items() if 'tok_embeddings' in k}
+        )
+        # logit head
+        to_return.update(
+            {k: v
+             for k, v in state_dict.items() if 'output.'
in k and 'llm' in k and 'lora' not in k} + ) + + # Step 1. visual_encoder + if self.use_visual_encoder_lora: + to_return.update( + get_peft_model_state_dict( + self.visual_encoder, state_dict=state_dict)) + elif not self.freeze_visual_encoder: + to_return.update({ + k: v + for k, v in state_dict.items() if 'visual_encoder.' in k + }) + # Step 2. LLM + if self.use_llm_lora: + to_return.update( + get_peft_model_state_dict(self.llm, state_dict=state_dict)) + elif not self.freeze_llm: + to_return.update( + {k: v + for k, v in state_dict.items() if 'llm.' in k}) + # Step 3. Projector + to_return.update( + {k: v + for k, v in state_dict.items() if 'projector.' in k}) + # projector text2vision + to_return.update( + {k: v + for k, v in state_dict.items() if 'projector_text2vision' in k}) + + # visual_encoder.adapter_proj + if self.freeze_visual_encoder: + to_return.update( + {k: v + for k, v in state_dict.items() if 'visual_encoder.adapter_proj' in k}) + + # git_clip lora + if hasattr(self.visual_encoder, 'clip_model'): + if self.visual_encoder.clip_lora is not None: + to_return.update( + get_peft_model_state_dict(self.visual_encoder.clip_model, + state_dict=state_dict)) + # omg decoder for llm + if self.require_omg_decoder: + to_return.update( + {k: v + for k, v in state_dict.items() + if 'visual_encoder.panoptic_head.transformer_decoder_llm' in k or + 'visual_encoder.panoptic_head.mask_embed_llm' in k or + 'visual_encoder.panoptic_head.pixel_decoder_llm' in k or + 'visual_encoder.panoptic_head.additional_cross_attn_layers' in k or + 'visual_encoder.panoptic_head.additional_ffn' in k or + 'visual_encoder.downsample_layer' in k + }) + + # seg tokens hidden states merge proj + if self.require_omg_decoder: + to_return.update( + {k: v + for k, v in state_dict.items() + if 'seg_token_proj' in k + }) + return to_return + + def _build_from_cfg_or_module(self, cfg_or_mod): + if isinstance(cfg_or_mod, nn.Module): + return cfg_or_mod + elif isinstance(cfg_or_mod, dict): + traverse_dict(cfg_or_mod) + return BUILDER.build(cfg_or_mod) + else: + raise NotImplementedError + + def forward(self, data, data_samples=None, mode='loss'): + + if 'pixel_values' in data: + if 'masks' in data: + masks = data['masks'] + del data['masks'] + else: + masks = None + if 'regions' in data: + regions = data['regions'] + del data['regions'] + else: + regions = None + if 'points' in data: + points = data['points'] + del data['points'] + else: + points = None + + visual_outputs = self.visual_encoder( + data['pixel_values'].to(self.visual_encoder.dtype), + output_hidden_states=True) + + if self.add_seg_pretrain: + pred_obj_query, gt_obj_query = prepare_seg_pretrain_data( + visual_outputs, + [self.projector.model.query_proj, self.projector.model.model], + self.projector_text2vision.model + ) + + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = self.projector(visual_outputs) + else: + pixel_values = self.projector( + visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) + + if regions is not None: + region_embeddings, region_success = self.get_region_embeddings( + regions, data['input_ids'], + ) + del regions + else: + region_success = True + region_embeddings = [] + + if points is not None: + points_mark_embedding, mark_success = self.get_points_embeddings( + points, data['input_ids'], + width=data['pixel_values'].shape[-1], + height=data['pixel_values'].shape[-2], + ) + else: + points_mark_embedding = [] + mark_success = True + + 
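+            # NOTE (schematic sketch, not the actual implementation): at this point
+            # `pixel_values` holds the projected image tokens, while `region_embeddings`
+            # and `points_mark_embedding` hold one LLM-space vector per <region> / <mark>
+            # token found in `input_ids`. The helper below is expected to splice these
+            # vectors into the input embedding sequence at their placeholder positions,
+            # conceptually something like:
+            #     emb = llm.get_input_embeddings()(input_ids)              # (B, L, C)
+            #     emb[input_ids == self.region_token_idx] = region_embeddings
+            #     emb[input_ids == self.mark_token_idx] = points_mark_embedding
+            # The real merging (including image-token insertion and label alignment)
+            # lives in prepare_inputs_labels_for_multimodal_with_visual_prompts.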
data['pixel_values'] = pixel_values + data = prepare_inputs_labels_for_multimodal_with_visual_prompts( + llm=self.llm, region_id=self.region_token_idx, + regions_feats=region_embeddings, + mark_id=self.mark_token_idx, + mark_feats=points_mark_embedding, + **data) + else: + masks = None + + if mode == 'loss': + if self.add_seg_pretrain: + return self.compute_loss(data, data_samples, masks=masks, region_success=region_success, + pred_gt_obj_query=(pred_obj_query, gt_obj_query), + mark_success=mark_success) + else: + return self.compute_loss(data, data_samples, masks=masks, + pred_gt_obj_query=None, + region_success=region_success, + mark_success=mark_success) + elif mode == 'predict': + return self.predict(data, data_samples) + elif mode == 'tensor': + return self._forward(data, data_samples) + else: + raise NotImplementedError + + def _forward(self, data, data_samples=None): + + outputs = self.llm(**data) + + return outputs + + def predict(self, data, data_samples=None): + outputs = self.llm(**data) + logits_dict = [{'logits': logits} for logits in outputs.logits] + return logits_dict + + def compute_loss(self, data, data_samples=None, masks=None, pred_gt_obj_query=None, + region_success=True, mark_success=True): + if 'original_labels' in data.keys(): + input_ids = data['original_labels'] + del data['original_labels'] + else: + input_ids = data['labels'] + outputs = self.llm(**data, output_hidden_states=True) + + if self.using_multilayer_states: + loss_dice, loss_mask = self.compute_seg_loss_multiple_states( + input_ids, outputs.hidden_states, masks, merge_type=self.seg_token_merge_type) + else: + loss_dice, loss_mask = self.compute_seg_loss( + input_ids, outputs.hidden_states[-1], masks) + + if pred_gt_obj_query is not None: + pred_obj_query, gt_obj_query = pred_gt_obj_query + proj_loss = torch.mean((pred_obj_query - gt_obj_query) ** 2) * 10 + else: + proj_loss = 0 + + if not region_success: + loss = outputs.loss * 0 + else: + loss = outputs.loss + + if not mark_success: + loss = outputs.loss * 0 + + loss = loss + self.get_visual_prompts_projector_zero() + + loss_dict = {'loss': loss, 'loss_dice': outputs.loss* 0 + loss_dice * 0.1, + 'loss_mask': outputs.loss * 0 + loss_mask * 0.4, + 'loss_proj': outputs.loss * 0 + proj_loss} + return loss_dict + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.llm, name) + + def get_region_embeddings(self, regions, input_ids): + success = True + if regions is None or len(regions) == 0: + return [], success + else: + region_token_mask = input_ids == self.region_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[region_token_mask] # (N, ) batch_size number + if len(regions) != len(batch_idxs): + # There is a bug !!! skip it. 
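+                # (i.e. the number of region masks does not match the number of <region>
+                #  tokens in input_ids; `success` is set to False so that compute_loss
+                #  multiplies the language loss by 0 for this batch, and the regions are
+                #  truncated or padded below only to keep tensor shapes consistent.)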
+ success = False + if len(regions) > len(batch_idxs): + regions = regions[:len(batch_idxs)] + else: + n_pad = len(batch_idxs) - len(regions) + pad_region = regions[:1].repeat(n_pad, 1, 1) + regions = torch.cat([pad_region, regions]) + + regions_embeddings = self.visual_encoder.forward_region_sam( + regions, batch_idxs + )[:, 0] # (N, C) + + # regions_embeddings = regions_embeddings.to(self.projector.model.query_proj.weight.dtype) + # regions_embeddings = self.projector.model.query_proj(regions_embeddings) + # regions_embeddings = self.projector.model.model(regions_embeddings) + regions_embeddings = self.projector.model.forward_visual_prompts_embeddings( + regions_embeddings, batch_idxs) + return regions_embeddings, success # (N, C) + + def get_points_embeddings(self, points, input_ids, width, height): + success = True + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == self.mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + if len(points) != len(batch_idxs): + # There is a bug !!! skip it. + success = False + if len(points) > len(batch_idxs): + points = points[:len(batch_idxs)] + else: + n_pad = len(batch_idxs) - len(points) + pad_region = points[:1].repeat(n_pad, 1, 1) + points = torch.cat([pad_region, points]) + + marks_embeddings = self.visual_encoder.forward_point_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + + # marks_embeddings = marks_embeddings.to(self.projector.model.query_proj.weight.dtype) + # marks_embeddings = self.projector.model.query_proj(marks_embeddings) + # marks_embeddings = self.projector.model.model(marks_embeddings) + + marks_embeddings = self.projector.model.forward_visual_prompts_embeddings( + marks_embeddings, batch_idxs) + return marks_embeddings, success # (N, C) + + def get_visual_prompts_projector_zero(self): + return self.projector.model.visual_prompt_zero + + def compute_seg_loss(self, input_ids, hidden_states, gt_masks): + if not self.text2vision_projector or self.add_seg_pretrain: + return 0.0, 0.0 + success = True + if gt_masks is None or len(gt_masks) == 0: + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[0, :1] # (N, ) batch_size number + gt_masks = [None] + hidden_states = hidden_states[0, :1] + hidden_states = self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + return dice_loss * 0.0, mask_loss * 0.0 + + + seg_tokens_mask = input_ids == self.seg_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(seg_tokens_mask.device) + + ori_hidden_states = hidden_states + hidden_states = hidden_states[seg_tokens_mask] + batch_idxs = batch_idxs[seg_tokens_mask] # (N, ) batch_size number + + if len(hidden_states) != len(gt_masks) or len(hidden_states) == 0: + # drop this batch + print("Drop the batch because the number of [SEG] and masks not equal !!!") + hidden_states = ori_hidden_states + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[0, :1] # (N, ) batch_size number + gt_masks = [None] + hidden_states = hidden_states[0, :1] + hidden_states = 
self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + return dice_loss * 0.0, mask_loss * 0.0 + + assert len(hidden_states) == len(gt_masks), "expect [seg] number equal to mask number, but get {} [seg] and {} masks".format(len(hidden_states), len(gt_masks)) + hidden_states = self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + if not success: + return dice_loss * 0.0, mask_loss * 0.0 + + return dice_loss, mask_loss + + def process_seg_tokens(self, multi_layers_hidden_states, seg_tokens_mask, merge_type): + multi_layers_hidden_states = [single_layer_hidden_states[seg_tokens_mask] \ + for single_layer_hidden_states in + multi_layers_hidden_states] + if merge_type == 'mean': + hidden_states = torch.stack(multi_layers_hidden_states, dim=0) + hidden_states = torch.mean(hidden_states, dim=0) + elif merge_type == 'cat': + hidden_states = multi_layers_hidden_states[-self.selected_layers:] + hidden_states = torch.cat(hidden_states, dim=-1) + hidden_states = self.seg_token_proj_cat(hidden_states / self.selected_layers) + elif merge_type == 'linear_cat': + hidden_states = multi_layers_hidden_states[-self.selected_layers:] + hidden_states = torch.stack(hidden_states, dim=1) + hidden_states = self.seg_token_proj_linear_cat[0](hidden_states) + hidden_states = hidden_states.flatten(1) + hidden_states = self.seg_token_proj_linear_cat[1](hidden_states) + else: + raise NotImplementedError + # hidden states (N, C) + return hidden_states + + def process_unvalid_tokens(self, multi_layers_hidden_states, merge_type): + multi_layers_hidden_states = [item[0, :1] for item in multi_layers_hidden_states] + if merge_type == 'mean': + hidden_states = torch.stack(multi_layers_hidden_states, dim=0) + hidden_states = torch.mean(hidden_states, dim=0) + elif merge_type == 'cat': + hidden_states = multi_layers_hidden_states[-self.selected_layers:] + hidden_states = torch.cat(hidden_states, dim=-1) + hidden_states = self.seg_token_proj_cat(hidden_states / self.selected_layers) + elif merge_type == 'linear_cat': + hidden_states = multi_layers_hidden_states[-self.selected_layers:] + hidden_states = torch.stack(hidden_states, dim=1) + hidden_states = self.seg_token_proj_linear_cat[0](hidden_states) + hidden_states = hidden_states.flatten(1) + hidden_states = self.seg_token_proj_linear_cat[1](hidden_states) + else: + raise NotImplementedError + # hidden states (1, C) + return hidden_states + + def compute_seg_loss_multiple_states(self, input_ids, multi_layers_hidden_states, gt_masks, merge_type='mean'): + if not self.text2vision_projector or self.add_seg_pretrain: + return 0.0, 0.0 + success = True + if gt_masks is None or len(gt_masks) == 0: + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat( + 1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[0, :1] # (N, ) batch_size number + gt_masks = [None] + hidden_states = self.process_unvalid_tokens(multi_layers_hidden_states, + merge_type=merge_type) + hidden_states = self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + return dice_loss * 
0.0, mask_loss * 0.0 + + + seg_tokens_mask = input_ids == self.seg_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(seg_tokens_mask.device) + + ori_multi_layers_hidden_states = multi_layers_hidden_states + + hidden_states = self.process_seg_tokens( + multi_layers_hidden_states, + seg_tokens_mask, merge_type=merge_type) + + batch_idxs = batch_idxs[seg_tokens_mask] # (N, ) batch_size number + + if len(hidden_states) != len(gt_masks) or len(hidden_states) == 0: + # drop this batch + print("Drop the batch because the number of [SEG] and masks not equal !!!") + hidden_states = self.process_unvalid_tokens( + ori_multi_layers_hidden_states, + merge_type=merge_type) + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat( + 1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[0, :1] # (N, ) batch_size number + gt_masks = [None] + hidden_states = self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + return dice_loss * 0.0, mask_loss * 0.0 + + assert len(hidden_states) == len(gt_masks), "expect [seg] number equal to mask number, but get {} [seg] and {} masks".format(len(hidden_states), len(gt_masks)) + hidden_states = self.projector_text2vision(hidden_states) # (N, C) + + pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) + dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) + + if not success: + return dice_loss * 0.0, mask_loss * 0.0 + + return dice_loss, mask_loss + +def prepare_seg_pretrain_data(visual_outputs, + query_in_proj, query_out_proj): + clip_feature, query_feat, attention_mask = visual_outputs + # clip feature (bs, hw, c + 2 * q_c) + # query_feat (bs, q, 2c) + # attention_mask (bs, q, hw) + bs, q, _ = query_feat.shape + pred_query_embed = [] + gt_query_embed = [] + for i in range(bs): + valid = attention_mask[i].sum(-1) > 0 + valid_query_feat = query_feat[i][valid] # (n, 2c) + gt_query_embed.append(valid_query_feat) + + if isinstance(query_in_proj, list): + llm_query = valid_query_feat + for proj in query_in_proj: + llm_query = proj(llm_query) + else: + llm_query = query_in_proj(valid_query_feat) + + pred_query_embed.append(query_out_proj(llm_query)) + + pred_query_embed = torch.cat(pred_query_embed, dim=0) + gt_query_embed = torch.cat(gt_query_embed, dim=0) + return pred_query_embed, gt_query_embed + diff --git a/omg_llava/model/omg_seg/__init__.py b/omg_llava/model/omg_seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e5bdf507012da9a8c94ea1cfb4f81d0214eb53 --- /dev/null +++ b/omg_llava/model/omg_seg/__init__.py @@ -0,0 +1,2 @@ +from .omg_seg_visual_encoder import OMGSegVisualEncoder +from .mask2former_vid_semanticsam import Mask2FormerVideoSemSamHead \ No newline at end of file diff --git a/omg_llava/model/omg_seg/__pycache__/__init__.cpython-310.pyc b/omg_llava/model/omg_seg/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbbd36e059c802c1ca543f3bb84f78a2e7030489 Binary files /dev/null and b/omg_llava/model/omg_seg/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/model/omg_seg/__pycache__/mask2former_vid.cpython-310.pyc b/omg_llava/model/omg_seg/__pycache__/mask2former_vid.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1759447f77216595149a1975865c53d3ef0aae51 Binary files /dev/null and b/omg_llava/model/omg_seg/__pycache__/mask2former_vid.cpython-310.pyc differ diff --git a/omg_llava/model/omg_seg/__pycache__/mask2former_vid_semanticsam.cpython-310.pyc b/omg_llava/model/omg_seg/__pycache__/mask2former_vid_semanticsam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ab3a063732e6472def32c75d4c6b3a9656b8619 Binary files /dev/null and b/omg_llava/model/omg_seg/__pycache__/mask2former_vid_semanticsam.cpython-310.pyc differ diff --git a/omg_llava/model/omg_seg/__pycache__/omg_seg_visual_encoder.cpython-310.pyc b/omg_llava/model/omg_seg/__pycache__/omg_seg_visual_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e305d23dcf40eeb6058be3025ca7853772caba3 Binary files /dev/null and b/omg_llava/model/omg_seg/__pycache__/omg_seg_visual_encoder.cpython-310.pyc differ diff --git a/omg_llava/model/omg_seg/__pycache__/utils.cpython-310.pyc b/omg_llava/model/omg_seg/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd154c92c469d2cadcc7b883d9c127c0e2215e6 Binary files /dev/null and b/omg_llava/model/omg_seg/__pycache__/utils.cpython-310.pyc differ diff --git a/omg_llava/model/omg_seg/mask2former_vid.py b/omg_llava/model/omg_seg/mask2former_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5564209c0b76b36f69d8a348a16bff9bc3a545 --- /dev/null +++ b/omg_llava/model/omg_seg/mask2former_vid.py @@ -0,0 +1,319 @@ +# Copied from OMG-Seg +from typing import Dict, List, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor +import torch.nn.functional as F + +from mmdet.registry import MODELS +from mmdet.structures import SampleList, OptSampleList, TrackDataSample +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmdet.models.detectors.single_stage import SingleStageDetector + +from .utils import mask_pool + + +@MODELS.register_module() +class Mask2formerVideo(SingleStageDetector): + r"""Implementation of `Per-Pixel Classification is + NOT All You Need for Semantic Segmentation + `_.""" + OVERLAPPING = None + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + inference_sam: bool = False, + init_cfg: OptMultiConfig = None + ): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head_cfg = panoptic_head_ + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.alpha = 0.4 + self.beta = 0.8 + + self.inference_sam = inference_sam + + def loss(self, 
batch_inputs: Tensor, + batch_data_samples: SampleList) -> Dict[str, Tensor]: + """ + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." + + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + x = self.extract_feat(x) + else: + x = self.extract_feat(batch_inputs) + losses = self.panoptic_head.loss(x, batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + num_frames = 0 + bs = batch_inputs.shape[0] + feats = self.extract_feat(batch_inputs) + + # in case no queries are provided for prompt. 
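+        # (in SAM-style interactive inference the prompts arrive via gt_instances;
+        #  when a sample carries no prompt boxes or points at all, prediction is
+        #  short-circuited with an empty InstanceData per sample.)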
+ if self.inference_sam and len(batch_data_samples[0].gt_instances) == 0: + for idx, data_sample in enumerate(batch_data_samples): + results = InstanceData() + data_sample.pred_instances = results + return batch_data_samples + + mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples) + + if self.OVERLAPPING is not None: + assert len(self.OVERLAPPING) == self.num_classes + mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results) + + if self.inference_sam: + for idx, data_sample in enumerate(batch_data_samples): + results = InstanceData() + mask = mask_pred_results[idx] + img_height, img_width = data_sample.metainfo['img_shape'][:2] + mask = mask[:, :img_height, :img_width] + ori_height, ori_width = data_sample.metainfo['ori_shape'][:2] + mask = F.interpolate( + mask[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + results.masks = mask.sigmoid() > 0.5 + data_sample.pred_instances = results + return batch_data_samples + + if num_frames > 0: + for frame_id in range(num_frames): + results_list_img = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results[:, :, frame_id], + [batch_data_samples[idx][frame_id] for idx in range(bs)], + rescale=rescale + ) + _ = self.add_track_pred_to_datasample( + [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img + ) + results = batch_data_samples + else: + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + iou_results=iou_results, + rescale=rescale + ) + results = self.add_pred_to_datasample(batch_data_samples, results_list) + + return results + + def add_pred_to_datasample(self, data_samples: SampleList, + results_list: List[dict]) -> SampleList: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (List[dict]): Instance segmentation, segmantic + segmentation and panoptic segmentation results. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). 
+ """ + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + data_sample.pred_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + data_sample.pred_instances = pred_results['ins_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def add_track_pred_to_datasample(self, data_samples: SampleList, results_list: List[dict]) -> SampleList: + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + assert self.num_stuff_classes > 0 + pred_results['pan_results'].sem_seg = pred_results['pan_results'].sem_seg.cpu() + data_sample.pred_track_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + bboxes = pred_results['ins_results']['bboxes'] + labels = pred_results['ins_results']['labels'] + track_ids = torch.arange(len(bboxes), dtype=labels.dtype, device=bboxes.device) + 1 + pred_results['ins_results']['instances_id'] = track_ids + data_sample.pred_track_instances = pred_results['ins_results'] + + if 'pro_results' in pred_results: + data_sample.pred_track_proposal = pred_results['pro_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + tuple[List[Tensor]]: A tuple of features from ``panoptic_head`` + forward. + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." 
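+            # For video clips the time axis is folded into the batch so the backbone
+            # sees (bs * num_frames, 3, h, w); e.g. bs=2 clips of 5 frames are run as
+            # 10 independent images, with per-frame outputs regrouped downstream
+            # (cf. the frame loop in predict).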
+ + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + feats = self.extract_feat(batch_inputs) + results = self.panoptic_head.forward(feats, batch_data_samples) + return results + + def open_voc_inference(self, feats, mask_cls_results, mask_pred_results): + if len(mask_pred_results.shape) == 5: + batch_size = mask_cls_results.shape[0] + num_frames = mask_pred_results.shape[2] + mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1) + else: + batch_size = mask_cls_results.shape[0] + num_frames = 0 + clip_feat = self.backbone.get_clip_feature(feats[-1]) + clip_feat_mask = F.interpolate( + mask_pred_results, + size=clip_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + if num_frames > 0: + clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + instance_feat = mask_pool(clip_feat, clip_feat_mask) + instance_feat = self.backbone.forward_feat(instance_feat) + clip_logit = self.panoptic_head.forward_logit(instance_feat) + clip_logit = clip_logit[..., :-1] + query_logit = mask_cls_results[..., :-1] + + clip_logit = clip_logit.softmax(-1) + query_logit = query_logit.softmax(-1) + overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device) + + valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to( + torch.float32)[..., None] + alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking + beta = torch.ones_like(clip_logit) * self.beta * valid_masking + + cls_logits_seen = ( + (query_logit ** (1 - alpha) * clip_logit ** alpha).log() + * overlapping_mask + ) + cls_logits_unseen = ( + (query_logit ** (1 - beta) * clip_logit ** beta).log() + * (1 - overlapping_mask) + ) + cls_results = cls_logits_seen + cls_logits_unseen + is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:] + mask_cls_results = torch.cat([ + cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1) + mask_cls_results = torch.log(mask_cls_results + 1e-8) + return mask_cls_results diff --git a/omg_llava/model/omg_seg/mask2former_vid_semanticsam.py b/omg_llava/model/omg_seg/mask2former_vid_semanticsam.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb9ddc7404d21673da1652bfa7c4da51385192c --- /dev/null +++ b/omg_llava/model/omg_seg/mask2former_vid_semanticsam.py @@ -0,0 +1,1876 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
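+#
+# Overview note on forward_cache (a schematic reading of the function below, with an
+# assumed usage pattern): it replaces a Mask2Former decoder layer's forward so that
+# extra "cached" queries (for example queries derived from LLM [SEG] states) can join
+# the self-attention step as additional context and are then sliced off again before
+# the FFN, leaving the layer's output shape unchanged. Roughly:
+#
+#     layer.forward = MethodType(forward_cache, layer)
+#     out, q_cache, q_pos = layer(q0, key=mem, value=mem, query_pos=qp, key_pos=kp,
+#                                 return_query_cache=True)
+#     out2 = layer(q1, key=mem, value=mem, query_pos=qp1, key_pos=kp,
+#                  query_cache=q_cache, query_cache_pos=q_pos)
+#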
+import copy +import os +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from mmcv.cnn import Conv2d +from mmcv.ops import point_sample +from mmdet.models import Mask2FormerTransformerDecoder, inverse_sigmoid, coordinate_to_encoding +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh +from mmengine.dist import get_dist_info +from mmengine.model import caffe2_xavier_init, ModuleList +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList, TrackDataSample +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptMultiConfig, reduce_mean) +from mmdet.models.layers import SinePositionalEncoding3D +from mmdet.models.utils import multi_apply, preprocess_panoptic_gt, get_uncertain_point_coords_with_randomness +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead +from types import MethodType +from .utils import mask_pool +preprocess_video_panoptic_gt = None +from kornia.contrib import distance_transform +from omg_llava.model.modules.projector.modeling_projector import CrossAttentionLayer, FFNLayer + +def forward_cache(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + return_query_cache=False, + query_cache=None, + query_cache_pos=None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (bs, num_value). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). 
+ """ + if return_query_cache: + assert query_cache is None + + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + + if return_query_cache: + ret_query_cache = query + ret_query_cache_pos = query_pos + if query_cache is not None: + query = torch.cat([query, query_cache], dim=1) + query_pos = torch.cat([query_pos, query_cache_pos], dim=1) + # print(query.shape, query_pos.shape) + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + if query_cache is not None: + query = query[:, :-query_cache.shape[1], :] + query_pos = query_pos[:, :-query_cache.shape[1], :] + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + if return_query_cache: + return query, ret_query_cache, ret_query_cache_pos + return query + + +@MODELS.register_module() +class Mask2FormerVideoSemSamHead(AnchorFreeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`ConfigDict` or dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer decoder position encoding. Defaults to + dict(num_feats=128, normalize=True). + loss_cls (:obj:`ConfigDict` or dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + Mask2Former head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + Mask2Former head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
+ """ + + def __init__(self, + in_channels: List[int], + feat_channels: int, + out_channels: int, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + num_transformer_feat_level: int = 3, + pixel_decoder: ConfigType = ..., + enforce_decoder_input_project: bool = False, + transformer_decoder: ConfigType = ..., + positional_encoding: ConfigType = None, + loss_cls: ConfigType = None, + loss_mask: ConfigType = None, + loss_dice: ConfigType = None, + loss_iou: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + # ov configs + sphere_cls: bool = False, + ov_classifier_name: Optional[str] = None, + logit: Optional[int] = None, + # box sup + matching_whole_map: bool = False, + ov_path=None, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + enable_box_query = True + self.feat_channels = feat_channels + self.out_mask_channel = out_channels + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.layer_cfg.\ + self_attn_cfg.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + + self.pixel_decoder_cfg = pixel_decoder_ + + self.pixel_decoder = MODELS.build(pixel_decoder_) + self.transformer_decoder_cfg = transformer_decoder + self.transformer_decoder = Mask2FormerTransformerDecoder( + **transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = SinePositionalEncoding3D( + **positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + if not sphere_cls: + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + if loss_iou is not None: + self.iou_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, 1)) + else: + self.iou_embed = None + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = 
self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + if loss_iou is not None: + self.loss_iou = MODELS.build(loss_iou) + else: + self.loss_iou = None + + # prepare OV things + # OV cls embed + if sphere_cls: + rank, world_size = get_dist_info() + if ov_classifier_name is None: + _dim = 1024 # temporally hard code + cls_embed = torch.empty(self.num_classes, _dim) + torch.nn.init.orthogonal_(cls_embed) + cls_embed = cls_embed[:, None] + else: + if ov_path is None: + ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth") + else: + ov_path = ov_path + cls_embed = torch.load(ov_path) + cls_embed_norm = cls_embed.norm(p=2, dim=-1) + assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm)) + if self.loss_cls and self.loss_cls.use_sigmoid: + pass + else: + _dim = cls_embed.size(2) + _prototypes = cls_embed.size(1) + + if rank == 0: + back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda') + else: + back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda') + if world_size > 1: + dist.broadcast(back_token, src=0) + back_token = back_token.to(device='cpu') + cls_embed = torch.cat([ + cls_embed, back_token.repeat(_prototypes, 1)[None] + ], dim=0) + self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False) + + # cls embd proj + cls_embed_dim = self.cls_embed.size(0) + self.cls_proj = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, cls_embed_dim) + ) + + # Haobo Yuan: + # For the logit_scale, I refer to this issue. + # https://github.com/openai/CLIP/issues/46#issuecomment-945062212 + # https://github.com/openai/CLIP/issues/46#issuecomment-782558799 + # Based on my understanding, it is a mistake of CLIP. + # Because they mention that they refer to InstDisc (Wu, 2018) paper. + # InstDisc set a non-learnable temperature to np.log(1 / 0.07). 
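+            # For scale: np.log(1 / 0.07) ~= 2.66, i.e. logits multiplied by ~14 after
+            # exponentiation, while the value used below corresponds to multiplying
+            # logits by ~100.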
+ # 4.6052 is np.log(1 / 0.01) + # np.log(1 / 0.07) will be fast converged to np.log(1 / 0.01) + if logit is None: + logit_scale = torch.tensor(4.6052, dtype=torch.float32) + else: + logit_scale = torch.tensor(logit, dtype=torch.float32) + self.register_buffer('logit_scale', logit_scale, persistent=False) + + # Mask Pooling + self.mask_pooling = mask_pool + self.mask_pooling_proj = nn.Sequential( + nn.LayerNorm(feat_channels), + nn.Linear(feat_channels, feat_channels) + ) + + # box inst + self.matching_whole_map = matching_whole_map + + # enable box query + self.enable_box_query = enable_box_query + if self.enable_box_query: + self.num_mask_tokens = 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, feat_channels) + self.pb_embedding = nn.Embedding(2, feat_channels) + self.pos_linear = nn.Linear(2 * feat_channels, feat_channels) + + self.transformer_decoder_llm = None + self.mask_embed_llm = None + self.pixel_decoder_llm = None + + self.additional_cross_attn_layers = None + + def init_new_decoder(self): + if self.transformer_decoder_llm is not None: + return + dtype = self.query_embed.weight.dtype + device = self.query_embed.weight.device + self.transformer_decoder_llm_dtype = dtype + self.transformer_decoder_llm = Mask2FormerTransformerDecoder( + **self.transformer_decoder_cfg).to(dtype).to(device) + self.transformer_decoder_llm.load_state_dict(self.transformer_decoder.state_dict(), strict=True) + for name, param in self.transformer_decoder_llm.named_parameters(): + param.requires_grad_(True) + print("Init transformer_decoder_llm and resume omg seg decoder weight and not frozen !!!") + + self.mask_embed_llm =\ + nn.Sequential( + nn.Linear(self.feat_channels, self.feat_channels), nn.ReLU(inplace=True), + nn.Linear(self.feat_channels, self.feat_channels), nn.ReLU(inplace=True), + nn.Linear(self.feat_channels, self.out_mask_channel)).to(dtype).to(device) + self.mask_embed_llm.load_state_dict(self.mask_embed.state_dict(), strict=True) + for name, param in self.mask_embed_llm.named_parameters(): + param.requires_grad_(True) + print("Init mask_embed_llm and resume omg seg weight and not frozen !!!") + return + + def init_cross_attn_layer(self): + if self.additional_cross_attn_layers is not None: + return + dtype = self.query_embed.weight.dtype + device = self.query_embed.weight.device + self.additional_cross_attn_layers = CrossAttentionLayer( + self.decoder_embed_dims, self.num_heads, dropout=0.0, + activation="relu", normalize_before=False + ).to(dtype).to(device) + self.additional_ffn = FFNLayer(self.decoder_embed_dims) + + for name, param in self.additional_cross_attn_layers.named_parameters(): + param.requires_grad_(True) + for name, param in self.additional_ffn.named_parameters(): + param.requires_grad_(True) + print("Init additional cross attn layer and ffn layer !!!") + return + + def init_weights(self) -> None: + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def preprocess_gt( + self, batch_gt_instances: InstanceList, + batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList: + """Preprocess the ground truth for all images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. 
It usually includes ``labels``, each is + ground truth labels of each bbox, with shape (num_gts, ) + and ``masks``, each is ground truth masks of each instances + of a image, shape (num_gts, h, w). + batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth of + semantic segmentation, each with the shape (1, h, w). + [0, num_thing_class - 1] means things, + [num_thing_class, num_class-1] means stuff, + 255 means VOID. It's None when training instance segmentation. + + Returns: + list[obj:`InstanceData`]: each contains the following keys + + - labels (Tensor): Ground truth class indices\ + for a image, with shape (n, ), n is the sum of\ + number of stuff type and number of instance in a image. + - masks (Tensor): Ground truth mask for a\ + image, with shape (n, h, w). + """ + num_things_list = [self.num_things_classes] * len(batch_gt_instances) + num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances) + if isinstance(batch_gt_instances[0], List): + gt_labels_list = [ + [torch.stack([torch.ones_like(gt_instances['labels']) * frame_id, gt_instances['labels']], dim=1) + for frame_id, gt_instances in enumerate(gt_vid_instances)] + for gt_vid_instances in batch_gt_instances + ] + gt_labels_list = [torch.cat(gt_labels, dim=0) for gt_labels in gt_labels_list] + gt_masks_list = [ + [gt_instances['masks'] for gt_instances in gt_vid_instances] + for gt_vid_instances in batch_gt_instances + ] + gt_semantic_segs = [ + [None if gt_semantic_seg is None else gt_semantic_seg.sem_seg + for gt_semantic_seg in gt_vid_semantic_segs] + for gt_vid_semantic_segs in batch_gt_semantic_segs + ] + if gt_semantic_segs[0][0] is None: + gt_semantic_segs = [None] * len(batch_gt_instances) + else: + gt_semantic_segs = [torch.stack(gt_sem_seg, dim=0) for gt_sem_seg in gt_semantic_segs] + gt_instance_ids_list = [ + [torch.stack([torch.ones_like(gt_instances['instances_ids']) * frame_id, gt_instances['instances_ids']], + dim=1) + for frame_id, gt_instances in enumerate(gt_vid_instances)] + for gt_vid_instances in batch_gt_instances + ] + gt_instance_ids_list = [torch.cat(gt_instance_ids, dim=0) for gt_instance_ids in gt_instance_ids_list] + targets = multi_apply(preprocess_video_panoptic_gt, gt_labels_list, + gt_masks_list, gt_semantic_segs, gt_instance_ids_list, + num_things_list, num_stuff_list) + else: + gt_labels_list = [ + gt_instances['labels'] for gt_instances in batch_gt_instances + ] + gt_masks_list = [ + gt_instances['masks'] for gt_instances in batch_gt_instances + ] + gt_semantic_segs = [ + None if gt_semantic_seg is None else gt_semantic_seg.sem_seg + for gt_semantic_seg in batch_gt_semantic_segs + ] + targets = multi_apply(preprocess_panoptic_gt, gt_labels_list, + gt_masks_list, gt_semantic_segs, num_things_list, + num_stuff_list) + labels, masks = targets + batch_gt_instances = [ + InstanceData(labels=label, masks=mask) + for label, mask in zip(labels, masks) + ] + return batch_gt_instances + + def get_queries(self, batch_data_samples): + img_size = batch_data_samples[0].batch_input_shape + query_feat_list = [] + bp_list = [] + for idx, data_sample in enumerate(batch_data_samples): + is_box = data_sample.gt_instances.bp.eq(0) + is_point = data_sample.gt_instances.bp.eq(1) + assert is_box.any() + sparse_embed, _ = self.pe( + data_sample.gt_instances[is_box], + image_size=img_size, + with_bboxes=True, + with_points=False, + ) + sparse_embed = [sparse_embed] + if is_point.any(): + _sparse_embed, _ = self.pe( + data_sample.gt_instances[is_point], + image_size=img_size, + with_bboxes=False, + 
with_points=True, + ) + sparse_embed.append(_sparse_embed) + sparse_embed = torch.cat(sparse_embed) + assert len(sparse_embed) == len(data_sample.gt_instances) + + query_feat_list.append(self.query_proj(sparse_embed.flatten(1, 2))) + bp_list.append(data_sample.gt_instances.bp) + + query_feat = torch.stack(query_feat_list) + bp_labels = torch.stack(bp_list).to(dtype=torch.long) + bp_embed = self.bp_embedding.weight[bp_labels] + bp_embed = bp_embed.repeat_interleave(self.num_mask_tokens, dim=1) + + query_feat = query_feat + bp_embed + return query_feat, None + + def get_targets( + self, + cls_scores_list: List[Tensor], + mask_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + return_sampling_results: bool = False + ) -> Tuple[List[Union[Tensor, int]]]: + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - label_weights_list (list[Tensor]): Label weights\ + of all images. Each with shape (num_queries, ). + - mask_targets_list (list[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights_list (list[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to average\ + the loss. When using sampling method, avg_factor is + usually the sum of positive and negative priors. When + using `MaskPseudoSampler`, `avg_factor` is usually equal + to the number of positive priors. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end. + """ + results = multi_apply( + self._get_targets_single, cls_scores_list, mask_preds_list, batch_gt_instances, batch_img_metas + ) + labels_list, label_weights_list, mask_targets_list, mask_weights_list, \ + pos_inds_list, neg_inds_list, sampling_results_list = results[:7] + rest_results = list(results[7:]) + + avg_factor = sum([results.avg_factor for results in sampling_results_list]) + res = (labels_list, label_weights_list, mask_targets_list, mask_weights_list, avg_factor) + + if return_sampling_results: + res = res + sampling_results_list + + return res + tuple(rest_results) + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. 
+ img_meta (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + gt_labels = gt_instances.labels + gt_masks = gt_instances.masks + + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + if not self.matching_whole_map: + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), + point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), + point_coords.repeat(num_gts, 1, 1)).squeeze(1) + else: + mask_points_pred = mask_pred + gt_points_masks = gt_masks + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + assign_result = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances, + img_meta=img_meta + ) + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((num_queries,)) + mask_weights[pos_inds] = 1.0 + + return labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds, sampling_result + + def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self._loss_by_feat_single, all_cls_scores, all_mask_preds, batch_gt_instances_list, img_metas_list + ) + + loss_dict = dict() + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + return loss_dict + + def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + batch_size, num_ins = cls_scores.size(0), cls_scores.size(1) + # hack here: + is_sam = num_ins != self.num_queries + + if not is_sam: + cls_scores_list = [cls_scores[i] for i in range(batch_size)] + mask_preds_list = [mask_preds[i] for i in range(batch_size)] + labels_list, label_weights_list, mask_targets_list, mask_weights_list, avg_factor = \ + self.get_targets(cls_scores_list, mask_preds_list, batch_gt_instances, batch_img_metas) + labels = torch.stack(labels_list, dim=0) + label_weights = torch.stack(label_weights_list, dim=0) + mask_targets = torch.cat(mask_targets_list, dim=0) + mask_weights = torch.stack(mask_weights_list, dim=0) + else: + labels = torch.stack([item.labels for item in batch_gt_instances]) + label_weights = labels.new_ones((batch_size, num_ins), dtype=torch.float) + mask_targets = torch.cat([item.masks for item in batch_gt_instances]) + mask_weights = mask_targets.new_ones((batch_size, num_ins), dtype=torch.float) + avg_factor = cls_scores.size(1) + + # classification loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + class_weight = cls_scores.new_tensor(self.class_weight) + ignore_inds = labels.eq(-1.) + # zero will not be involved in the loss cal + labels[ignore_inds] = 0 + label_weights[ignore_inds] = 0. 
+ obj_inds = labels.eq(self.num_classes) + if is_sam: + cls_avg_factor = cls_scores.new_tensor([0]) + else: + cls_avg_factor = class_weight[labels].sum() + cls_avg_factor = reduce_mean(cls_avg_factor) + cls_avg_factor = max(cls_avg_factor, 1) + if self.loss_iou is not None: + loss_cls = self.loss_cls( + cls_scores[..., :-1], + labels, + label_weights, + avg_factor=cls_avg_factor + ) + loss_iou = self.loss_iou( + cls_scores[..., -1:], + obj_inds.to(dtype=torch.long), + avg_factor=cls_avg_factor + ) + if is_sam: + loss_iou = loss_iou * 0 + loss_cls = loss_cls + loss_iou + else: + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=cls_avg_factor + ) + + # loss_mask + num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) + num_total_masks = max(num_total_masks, 1) + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + if not self.matching_whole_map: + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + else: + mask_point_targets = mask_targets + mask_point_preds = mask_preds + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points + ) + + return loss_cls, loss_mask, loss_dice + + def forward_logit(self, cls_embd): + cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed) + cls_pred = cls_pred.max(-1).values + cls_pred = self.logit_scale.exp() * cls_pred + return cls_pred + + def _forward_head_llm(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int], + num_frames: int = 0) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + - num_frames: How many frames are there in video. 
+ """ + if self.transformer_decoder_llm is None: + decoder_out = self.transformer_decoder.post_norm(decoder_out) + else: + self.transformer_decoder_llm.post_norm = self.transformer_decoder_llm.post_norm.to(decoder_out.dtype) + decoder_out = self.transformer_decoder_llm.post_norm(decoder_out) + # shape (num_queries, batch_size, c) + + if self.mask_embed_llm is None: + mask_embed = self.mask_embed(decoder_out) + else: + self.mask_embed_llm = self.mask_embed_llm.to(decoder_out.dtype) + mask_embed = self.mask_embed_llm(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + + if num_frames > 0: + assert len(mask_pred.shape) == 4 + assert mask_pred.shape[2] % num_frames == 0 + frame_h = mask_pred.shape[2] // num_frames + num_q = mask_pred.shape[1] + _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) + attn_mask = F.interpolate( + _mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) + else: + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return mask_pred, attn_mask + + def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int], + num_frames: int = 0) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + - num_frames: How many frames are there in video. 
+ """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + # shape (num_queries, batch_size, c) + if isinstance(self.cls_embed, nn.Module): + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + + if not isinstance(self.cls_embed, nn.Module): + maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach()) + maskpool_embd = self.mask_pooling_proj(maskpool_embd) + cls_embd = self.cls_proj(maskpool_embd + decoder_out) + cls_pred = self.forward_logit(cls_embd) + + if self.iou_embed is not None: + iou_pred = self.iou_embed(decoder_out) + cls_pred = torch.cat([cls_pred, iou_pred], dim=-1) + + if num_frames > 0: + assert len(mask_pred.shape) == 4 + assert mask_pred.shape[2] % num_frames == 0 + frame_h = mask_pred.shape[2] // num_frames + num_q = mask_pred.shape[1] + _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) + attn_mask = F.interpolate( + _mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) + else: + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward_llm_seg(self, hidden_states, batch_idxs): + # hidden_states (N, C) -> (N, 1, C) + hidden_states = hidden_states.to(self.query_feat.weight.dtype) + hidden_states = hidden_states.unsqueeze(1) + C = hidden_states.shape[-1] + num_frames = 0 + + if self.pixel_decoder_llm is not None: + self.pixel_decoder_llm = self.pixel_decoder_llm.to(hidden_states.dtype) + mask_features, multi_scale_memorys = self.pixel_decoder(self.image_feat) + mask_features = mask_features[batch_idxs] + else: + mask_features = self.cur_batch_mask_features[batch_idxs] + multi_scale_memorys = self.cur_batch_multi_scale_memorys + + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + decoder_input = decoder_input[batch_idxs] # (N, hw, c) + batch_size = len(decoder_input) + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + query_feat = hidden_states[:, :, :C//2] + query_embed = hidden_states[:, :, C//2:] + self_attn_mask = None + + mask_pred_list = [] + mask_pred, attn_mask = self._forward_head_llm( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames + ) + if num_frames > 0: 
+ mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + if self.transformer_decoder_llm is not None: + layer = self.transformer_decoder_llm.layers[i] + else: + layer = self.transformer_decoder.layers[i] + layer = layer.to(query_feat.dtype) + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + query_cache=None, + query_cache_pos=None, + ) + mask_pred, attn_mask = self._forward_head_llm( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames + ) + + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + # mask_pred_list [(b, 1, h, w), ...] + return mask_pred_list + + def sample_points(self, mask_pred, gt_masks): + gt_masks = gt_masks.unsqueeze(1) + gt_masks = gt_masks.to(mask_pred) + # (N, 1, h, w) + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_pred, None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + gt_masks.float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_pred, points_coords).squeeze(1) + return mask_point_preds, mask_point_targets + + def llm_seg_loss(self, mask_pred_list, gt_masks): + + if gt_masks is None or gt_masks[0] is None: + ret_loss = 0 + for mask_pred in mask_pred_list: + ret_loss = ret_loss + mask_pred.sum() * 0.0 + return [ret_loss], [ret_loss] + + # dice loss and ce loss + all_loss_dice = [] + all_loss_mask = [] + + for mask_pred in mask_pred_list: + + sampled_mask_pred, sampled_mask_gt = self.sample_points(mask_pred, gt_masks) + loss_dice = self.loss_dice( + sampled_mask_pred, + sampled_mask_gt, avg_factor=(len(gt_masks) + 1e-4)) + loss_mask = self.loss_mask( + sampled_mask_pred.reshape(-1), + sampled_mask_gt.reshape(-1), + avg_factor=(sampled_mask_pred.shape[0] * sampled_mask_pred.shape[1] + 1e-4)) + all_loss_dice.append(loss_dice) + all_loss_mask.append(loss_mask) + return all_loss_dice, all_loss_mask + + def forward(self, x: List[Tensor], batch_data_samples: SampleList, + return_mask_features=False, save_feat=False, + return_query_pos=False) -> Tuple[List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. 
+ - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_img_metas = [] + if isinstance(batch_data_samples[0], TrackDataSample): + for track_sample in batch_data_samples: + cur_list = [] + for det_sample in track_sample: + cur_list.append(det_sample.metainfo) + batch_img_metas.append(cur_list) + num_frames = len(batch_img_metas[0]) + else: + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + num_frames = 0 + batch_size = len(batch_img_metas) + + mask_features, multi_scale_memorys = self.pixel_decoder(x) + + if num_frames > 0: + mask_features = mask_features.unflatten(0, (batch_size, num_frames)) + mask_features = mask_features.transpose(1, 2).flatten(2, 3) + + # save for decode the llm's [SEG] tokens + if save_feat: + self.cur_batch_mask_features = mask_features + self.cur_batch_multi_scale_memorys = multi_scale_memorys + self.image_feat = x + + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + if num_frames > 0: + decoder_input = decoder_input.unflatten(0, (batch_size, num_frames)) + decoder_input = decoder_input.flatten(1, 2) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 if num_frames == 0 else num_frames + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + # only for encode the image for llm, not support sam mode in this process + if False and batch_data_samples[0].data_tag in ['sam_mul', 'sam']: + query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo(batch_data_samples) + query_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) + query_embed = self.pos_linear(query_embed) + else: + # coco style query generation + # shape (num_queries, c) -> (batch_size, num_queries, c) + query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + self_attn_mask = None + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames + ) + cls_pred_list.append(cls_pred) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. 
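+            # (a row whose attention mask is entirely True would leave the query
+            # with no visible key and make the masked softmax produce NaNs, so
+            # such rows are reset to attend to the whole feature map)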
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + ) + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames + ) + + cls_pred_list.append(cls_pred) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + if return_mask_features: + if return_query_pos: + return cls_pred_list, mask_pred_list, query_feat, query_embed, mask_features + return cls_pred_list, mask_pred_list, query_feat, mask_features + if return_query_pos: + return cls_pred_list, mask_pred_list, query_feat, query_embed + return cls_pred_list, mask_pred_list, query_feat + + def loss( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the panoptic + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + batch_gt_semantic_segs = [] + for data_sample in batch_data_samples: + if isinstance(data_sample, TrackDataSample): + clip_meta = [] + clip_instances = [] + clip_sem_seg = [] + for det_sample in data_sample: + clip_meta.append(det_sample.metainfo) + clip_instances.append(det_sample.gt_instances) + if 'gt_sem_seg' in det_sample: + clip_sem_seg.append(det_sample.gt_sem_seg) + else: + clip_sem_seg.append(None) + batch_img_metas.append(clip_meta) + batch_gt_instances.append(clip_instances) + batch_gt_semantic_segs.append(clip_sem_seg) + else: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'gt_sem_seg' in data_sample: + batch_gt_semantic_segs.append(data_sample.gt_sem_seg) + else: + batch_gt_semantic_segs.append(None) + + # forward + all_cls_scores, all_mask_preds, _ = self(x, batch_data_samples) + + # preprocess ground truth + if not self.enable_box_query or batch_data_samples[0].data_tag in ['coco', 'sam']: + batch_gt_instances = self.preprocess_gt(batch_gt_instances, batch_gt_semantic_segs) + + # loss + if isinstance(batch_data_samples[0], TrackDataSample): + num_frames = len(batch_img_metas[0]) + all_mask_preds = [mask.flatten(2, 3) for mask in all_mask_preds] + for instance in batch_gt_instances: + instance['masks'] = instance['masks'].flatten(1, 2) + film_metas = [ + { + 'img_shape': (meta[0]['img_shape'][0] * num_frames, + meta[0]['img_shape'][1]) + } for meta in batch_img_metas + ] + batch_img_metas = film_metas + + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, batch_gt_instances, batch_img_metas) + + if self.enable_box_query: + losses['loss_zero'] = 0 * self.query_feat.weight.sum() + 0 * self.query_embed.weight.sum() + losses['loss_zero'] += 0 * 
self.pb_embedding.weight.sum() + losses['loss_zero'] += 0 * self.mask_tokens.weight.sum() + for name, param in self.pos_linear.named_parameters(): + losses['loss_zero'] += 0 * param.sum() + return losses + + def predict(self, x: Tuple[Tensor], + batch_data_samples: SampleList, + return_query=False, + return_mask_features=False, + save_feat=False, + return_query_pos=False, + ) -> Tuple[Tensor, ...]: + """Test without augmentaton. + + Args: + return_query: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). + """ + data_sample = batch_data_samples[0] + if isinstance(data_sample, TrackDataSample): + img_shape = data_sample[0].metainfo['batch_input_shape'] + num_frames = len(data_sample) + else: + img_shape = data_sample.metainfo['batch_input_shape'] + num_frames = 0 + if return_mask_features: + all_cls_scores, all_mask_preds, query_feat, query_pos, mask_features =\ + self(x, batch_data_samples, return_mask_features, + save_feat=save_feat, return_query_pos=True) + else: + all_cls_scores, all_mask_preds, query_feat, query_pos =\ + self(x, batch_data_samples, save_feat=save_feat, return_query_pos=True) + if self.iou_embed is not None: + _all_cls_scores = [cls_score[..., :-1] for cls_score in all_cls_scores] + iou_results = [cls_score[..., -1:] for cls_score in all_cls_scores] + all_cls_scores = _all_cls_scores + else: + iou_results = None + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + if iou_results is not None: + iou_results = iou_results[-1] + + if num_frames > 0: + mask_pred_results = mask_pred_results.flatten(1, 2) + mask_pred_results = F.interpolate( + mask_pred_results, + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + if num_frames > 0: + num_queries = mask_cls_results.shape[1] + mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames)) + + if iou_results is None: + return mask_cls_results, mask_pred_results + + if return_query: + if return_mask_features: + if return_query_pos: + return mask_cls_results, mask_pred_results, query_feat, query_pos, iou_results, mask_features + return mask_cls_results, mask_pred_results, query_feat, iou_results, mask_features + else: + if return_query_pos: + return mask_cls_results, mask_pred_results, query_feat, query_pos, iou_results + return mask_cls_results, mask_pred_results, query_feat, iou_results + else: + if return_mask_features: + return mask_cls_results, mask_pred_results, iou_results, mask_features + else: + return mask_cls_results, mask_pred_results, iou_results + + def prepare_for_dn_mo(self, batch_data_samples): + scalar, noise_scale = 100, 0.4 + gt_instances = [t.gt_instances for t in batch_data_samples] + + point_coords = torch.stack([inst.point_coords for inst in gt_instances]) + pb_labels = torch.stack([inst['bp'] for inst in gt_instances]) + labels = torch.zeros_like(pb_labels).long() + + boxes = point_coords # + boxes + + factors = [] + for i, data_sample in enumerate(batch_data_samples): + h, w, = data_sample.metainfo['img_shape'] + factor = 
boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1) + factors.append(factor) + factors = torch.stack(factors, 0) + + boxes = bbox_xyxy_to_cxcywh(boxes / factors) + box_start = [len(point) for point in point_coords] + + known_labels = labels + known_pb_labels = pb_labels + known_bboxs = boxes + + known_labels_expaned = known_labels.clone() + known_pb_labels_expaned = known_pb_labels.clone() + known_bbox_expand = known_bboxs.clone() + + if noise_scale > 0 and self.training: + diff = torch.zeros_like(known_bbox_expand) + diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2 + diff[:, :, 2:] = known_bbox_expand[:, :, 2:] + # add very small noise to input points; no box + sc = 0.01 + for i, st in enumerate(box_start): + diff[i, :st] = diff[i, :st] * sc + known_bbox_expand += torch.mul( + (torch.rand_like(known_bbox_expand) * 2 - 1.0), + diff) * noise_scale + + known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) + + input_label_embed = self.pb_embedding(known_pb_labels_expaned) + + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + + input_label_embed = input_label_embed.repeat_interleave( + self.num_mask_tokens, + 1) + self.mask_tokens.weight.unsqueeze(0).repeat( + input_label_embed.shape[0], input_label_embed.shape[1], 1) + input_bbox_embed = input_bbox_embed.repeat_interleave( + self.num_mask_tokens, 1) + + single_pad = self.num_mask_tokens + + # NOTE scalar is modified to 100, each click cannot see each other + scalar = int(input_label_embed.shape[1] / self.num_mask_tokens) + + pad_size = input_label_embed.shape[1] + + if input_label_embed.shape[1] > 0: + input_query_label = input_label_embed + input_query_bbox = input_bbox_embed + + tgt_size = pad_size + attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 + # match query cannot see the reconstruct + attn_mask[pad_size:, :pad_size] = True + # reconstruct cannot see each other + for i in range(scalar): + if i == 0: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + if i == scalar - 1: + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + else: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + mask_dict = { + 'known_lbs_bboxes': (known_labels, known_bboxs), + 'pad_size': pad_size, + 'scalar': scalar, + } + return input_query_label, input_query_bbox, attn_mask, mask_dict + + def prepare_sam_query(self, points, w, h): + # points N, 1, 2 + tl_points = points - 3 + br_points = points + 3 + boxes = torch.cat([tl_points, br_points], dim=-1) + + labels = torch.zeros((points.shape[0], 1, ), dtype=torch.int64).to(points.device) + + factors = torch.Tensor([[[w, h, w, h]]]).to(boxes) + + boxes = bbox_xyxy_to_cxcywh(boxes / factors) # xyxy / factor or xywh / factor ???? 
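+        # NOTE: `boxes` above are absolute xyxy corners (tl/br built around the
+        # clicked points), so normalising by [w, h, w, h] before
+        # bbox_xyxy_to_cxcywh is the consistent interpretation here.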
+ print('rela_coords:', boxes) + known_bboxs = boxes + + known_bbox_expand = known_bboxs.clone() + + input_label_embed = self.pb_embedding(labels) + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + + input_label_embed = input_label_embed.repeat_interleave( + self.num_mask_tokens, + 1) + self.mask_tokens.weight.unsqueeze(0).repeat( + input_label_embed.shape[0], input_label_embed.shape[1], 1) + input_bbox_embed = input_bbox_embed.repeat_interleave( + self.num_mask_tokens, 1) + + if input_label_embed.shape[1] > 0: + input_query_label = input_label_embed + input_query_bbox = input_bbox_embed + + query_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) + query_embed = self.pos_linear(query_embed) + return input_query_label, query_embed # (N, 1, C) + + def forward_visual_prompt(self, regions, batch_idxs): + # points (N, 2) + points = get_center_coords(regions) + points = points.to(torch.float32).to(regions.device) + + query_feat, query_embed = self.prepare_sam_query( + points.unsqueeze(1), w=regions.shape[-1], h=regions.shape[-2], + ) # (N, 1, c) + + # hidden_states (N, C) -> (N, 1, C) + num_frames = 0 + + mask_features = self.cur_batch_mask_features[batch_idxs] + multi_scale_memorys = self.cur_batch_multi_scale_memorys + + query_feat = query_feat.to(mask_features) + query_embed = query_embed.to(mask_features) + + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + decoder_input = decoder_input[batch_idxs] # (N, hw, c) + batch_size = len(decoder_input) + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + self_attn_mask = None + + mask_pred_list = [] + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames, regions=regions, + ) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. 
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + + layer = self.transformer_decoder.layers[i] + layer = layer.to(query_feat.dtype) + layer.forward = MethodType(forward_cache, layer) + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + query_cache=None, + query_cache_pos=None, + ) + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames, regions=regions, + ) + + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + return query_feat, query_embed + + + def _forward_head_sam(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int], + num_frames: int = 0, regions=None, bboxes=None) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + - num_frames: How many frames are there in video. 
+ """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + mask_embed = self.mask_embed(decoder_out) + + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + + if regions is not None: + attn_mask = self._get_attn_mask_from_gt_mask(regions, + attn_mask_target_size) + else: + if num_frames > 0: + assert len(mask_pred.shape) == 4 + assert mask_pred.shape[2] % num_frames == 0 + frame_h = mask_pred.shape[2] // num_frames + num_q = mask_pred.shape[1] + _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) + attn_mask = F.interpolate( + _mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) + else: + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + + # set attn maps + if bboxes is not None: + cur_scale_bboxes = copy.deepcopy(bboxes) + bs, _, h, w = attn_mask.shape + assert len(bboxes) == bs + cur_scale_bboxes = np.clip(cur_scale_bboxes, a_min=0, a_max=1) + cur_scale_bboxes[:, [0, 2]] *= w + cur_scale_bboxes[:, [1, 3]] *= h + cur_scale_bboxes[:, 2:] += 1 + cur_scale_bboxes = torch.Tensor(np.floor(cur_scale_bboxes)) + cur_scale_bboxes = cur_scale_bboxes.to(torch.int64) + for i in range(bs): + sx, sy = cur_scale_bboxes[i][0], cur_scale_bboxes[i][1] + ex, ey = cur_scale_bboxes[i][2], cur_scale_bboxes[i][3] + attn_mask[i, :, :sy, :] = True + attn_mask[i, :, ey:, :] = True + attn_mask[i, :, :, :sx] = True + attn_mask[i, :, :, ex:] = True + + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return mask_pred, attn_mask + + def forward_point_prompt(self, points, batch_idxs, width, height): + # regions (N, H, W) + + query_feat, query_embed = self.prepare_sam_query( + points.unsqueeze(1), w=width, h=height, + ) # (N, 1, c) + + # hidden_states (N, C) -> (N, 1, C) + num_frames = 0 + + mask_features = self.cur_batch_mask_features[batch_idxs] + multi_scale_memorys = self.cur_batch_multi_scale_memorys + + query_feat = query_feat.to(mask_features) + query_embed = query_embed.to(mask_features) + + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + decoder_input = decoder_input[batch_idxs] # (N, hw, c) + batch_size = len(decoder_input) + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + self_attn_mask = None + + mask_pred_list = [] + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames + ) + # 
attn_mask = self._get_attn_mask_from_gt_mask(regions, multi_scale_memorys[0].shape[-2:]) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + + layer = self.transformer_decoder.layers[i] + layer = layer.to(query_feat.dtype) + layer.forward = MethodType(forward_cache, layer) + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + query_cache=None, + query_cache_pos=None, + ) + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames + ) + + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + return query_feat, query_embed + + def forward_box_prompt(self, boxes, batch_idxs, width, height): + + # regions (N, H, W) + points = (boxes[:, :2] + boxes[:, 2:]) / 2.0 + boxes_rela_coords = copy.deepcopy(boxes) + boxes_rela_coords[:, [0, 2]] /= width + boxes_rela_coords[:, [1, 3]] /= height + + query_feat, query_embed = self.prepare_sam_query( + points.unsqueeze(1), w=width, h=height, + ) # (N, 1, c) + + # hidden_states (N, C) -> (N, 1, C) + num_frames = 0 + + mask_features = self.cur_batch_mask_features[batch_idxs] + multi_scale_memorys = self.cur_batch_multi_scale_memorys + + query_feat = query_feat.to(mask_features) + query_embed = query_embed.to(mask_features) + + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + decoder_input = decoder_input[batch_idxs] # (N, hw, c) + batch_size = len(decoder_input) + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + self_attn_mask = None + + mask_pred_list = [] + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames, bboxes=boxes_rela_coords, + ) + # attn_mask = self._get_attn_mask_from_gt_mask(regions, multi_scale_memorys[0].shape[-2:]) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it 
all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + + layer = self.transformer_decoder.layers[i] + layer = layer.to(query_feat.dtype) + layer.forward = MethodType(forward_cache, layer) + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + query_cache=None, + query_cache_pos=None, + ) + mask_pred, attn_mask = self._forward_head_sam( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames, bboxes=boxes_rela_coords, + ) + + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + return query_feat, query_embed + + def _get_attn_mask_from_gt_mask(self, regions, attn_mask_target_size): + regions = regions.unsqueeze(1) # (N, 1, H, W) + attn_mask = F.interpolate( + regions, + attn_mask_target_size, + mode='nearest') + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.to(torch.bool) + attn_mask = attn_mask.detach() + return ~attn_mask + + +def get_center_coords(masks): + point_coords = [] + for mask in masks: + mask = mask[None, None] + mask = mask.to(torch.bool) + n, _, h, w = mask.shape + mask_dt = ( + distance_transform( + (~F.pad(mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float() + )[:, :, 1:-1, 1:-1] + ) + selected_point = torch.tensor([mask_dt.argmax() / w, mask_dt.argmax() % w]).long().flip(0).to( + mask.device) + point_coords.append(selected_point) + if len(point_coords) > 0: + point_coords = torch.stack(point_coords)[:, None] + else: + point_coords = torch.empty((0, 1, 2), dtype=torch.int32).to(device=mask.device) + return point_coords[:, 0] diff --git a/omg_llava/model/omg_seg/omg_seg_visual_encoder.py b/omg_llava/model/omg_seg/omg_seg_visual_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1df7cc9eb785a61b006416999b941224ba3cd627 --- /dev/null +++ b/omg_llava/model/omg_seg/omg_seg_visual_encoder.py @@ -0,0 +1,256 @@ +from .mask2former_vid import Mask2formerVideo +from mmdet.structures import DetDataSample +import torch +import torch.nn.functional as F +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + +class OMGSegVisualEncoder(Mask2formerVideo): + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + inference_sam: bool = False, + init_cfg: OptMultiConfig = None, + dtype=torch.float32, + pixel_shuffle_down_ratio=2, + **kwargs, + ): + super().__init__(backbone=backbone, neck=neck, panoptic_head=panoptic_head, + panoptic_fusion_head=panoptic_fusion_head, train_cfg=train_cfg, + test_cfg=test_cfg, data_preprocessor=data_preprocessor, + inference_sam=inference_sam, init_cfg=init_cfg, ) + self.dtype = dtype + self.enable_output_gradient = False + self.backbone_type = None + self.pixel_shuffle_down_ratio = pixel_shuffle_down_ratio + + weight_path = init_cfg['checkpoint'] + 
state_dict = torch.load(weight_path)["state_dict"] + self.load_state_dict(state_dict, strict=False) + print("Loaded omg weight from {} !!!".format(weight_path)) + + def init_new_decoder(self): + self.panoptic_head.init_new_decoder() + return + + def init_cross_attn_layer(self): + self.panoptic_head.init_cross_attn_layer() + return + + def prepare_input(self, image): + # image (b, 3, h, w) + h, w = image.shape[-2:] + datasamples = DetDataSample() + metainfo = {'batch_input_shape': (h, w), + 'ori_shape': (h, w), + 'img_shape': (h, w)} + datasamples.set_metainfo(metainfo) + return image, [datasamples for i in range(image.shape[0])] + + def pixel_shuffle_feat(self, feat): + # feat (b, c, h, w) + if self.pixel_shuffle_down_ratio is None: + return feat + # pixel shuffle + b, c, h, w = feat.shape + feat = feat.reshape( + b, c, h // self.pixel_shuffle_down_ratio, self.pixel_shuffle_down_ratio, + w // self.pixel_shuffle_down_ratio, self.pixel_shuffle_down_ratio, + ) + feat = feat.permute(0, 3, 5, 1, 2, 4) # (bs, rh, rw, c, h_down, w_down) + feat = feat.flatten(1, 3) # (bs, rh * rw * c, h_down, w_down) + return feat + + def llava_visual_feat(self, backbone_feat): + # get clip feature + ret = [] + last_outs = backbone_feat[-1] + # more downsample ratio by pixel shuffle + last_outs = self.pixel_shuffle_feat(last_outs) + ret.append(last_outs.flatten(2).permute(0, 2, 1)) + return ret + + def forward_llm_seg(self, hidden_states, batch_idxs): + # hidden_states (N, 256) batch_idxs (N, ) + hidden_states = hidden_states.to(self.dtype) + + mask_pred_results = self.panoptic_head.forward_llm_seg( + hidden_states, batch_idxs, + ) + return mask_pred_results + + def forward_region_sam(self, regions, batch_idxs): + query_feat, query_embed = self.panoptic_head.forward_visual_prompt( + regions, batch_idxs + ) + return torch.cat([query_feat, query_embed], dim=-1) + + def forward_point_sam(self, points, batch_idxs, width, height): + query_feat, query_embed = self.panoptic_head.forward_point_prompt( + points, batch_idxs, width=width, height=height + ) + return torch.cat([query_feat, query_embed], dim=-1) + + def forward_box_sam(self, bboxes, batch_idxs, width, height): + query_feat, query_embed = self.panoptic_head.forward_box_prompt( + bboxes, batch_idxs, width=width, height=height + ) + return torch.cat([query_feat, query_embed], dim=-1) + + def loss_llm_seg(self, mask_pred_results, gt_masks): + all_loss_dice, all_loss_mask = self.panoptic_head.llm_seg_loss( + mask_pred_results, gt_masks, + ) + return sum(all_loss_dice), sum(all_loss_mask) + + def forward(self, images, output_hidden_states=True): + if self.backbone_type is None: + self.backbone_type = [p.dtype for p in self.parameters()][0] + self.to(self.dtype) + images = images.to(self.dtype) + + img_shape = images.shape[-2:] + # last scale for ConvNext-L + feat_shape = [item // 32 for item in img_shape] + if self.pixel_shuffle_down_ratio is not None: + feat_shape = [item // self.pixel_shuffle_down_ratio for item in feat_shape] + batch_inputs, batch_data_samples = self.prepare_input(images) + + # directly for image perception + num_frames = 0 # only consider image + bs = batch_inputs.shape[0] + feats = self.extract_feat(batch_inputs) + llava_clip_feat = self.llava_visual_feat(feats) + + # directly do panoptic segmentation + mask_cls_results, mask_pred_results, query_feat, query_pos, iou_results, mask_features =\ + self.panoptic_head.predict( + feats, batch_data_samples, return_query=True, return_mask_features=True, save_feat=True, + return_query_pos=True,) + + if 
self.OVERLAPPING is not None: + assert len(self.OVERLAPPING) == self.num_classes + mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results) + + # llava_clip_feat [(bs, hw, c), ] + # query_feat (bs, q, c), query_pos (bs, q, c) + # mask_pred (b, q, h, w) + query_pos_feat = torch.cat([query_feat, query_pos], dim=-1) # (bs, q, 2c) + + ret_pixel_query = [] + ret_attn_mask = [] + for i in range(bs): + pixel_query, attn_mask = self.panoptic_postprocess( + mask_cls_results[i], mask_pred_results[i], query_pos_feat[i], + feat_size=feat_shape, + ) + ret_pixel_query.append(pixel_query) # (h, w, c) + ret_attn_mask.append(attn_mask) # (q, hw) + ret_pixel_query = torch.stack(ret_pixel_query, dim=0) # (bs, h, w, c) + ret_attn_mask = torch.stack(ret_attn_mask, dim=0) # (bs, q, hw) + + ret_pixel_query = ret_pixel_query.flatten(1, 2) # (bs, hw, c) + llava_clip_feat[0] = torch.cat([llava_clip_feat[0], ret_pixel_query], dim=-1) + ret = llava_clip_feat + [query_pos_feat, ret_attn_mask] + + for i in range(len(ret) - 1): + ret[i] = ret[i].to(self.backbone_type) + for i in range(len(ret) - 1): + ret[i] = self.set_output_gradient(ret[i]) + return ret + + def panoptic_postprocess(self, mask_cls, mask_pred, query_feat, feat_size=[320, 320]): + """assign queries for per pixel. + mask_cls (q, c) + mask_pred (q, h, w) + query_feat (q, c) + """ + + scores_foreground, _ = F.softmax(mask_cls, dim=-1)[..., :-1].max(-1) + + mask_pred = mask_pred.sigmoid() + cur_scores = scores_foreground + cur_masks = mask_pred + + cur_query_feat = query_feat + # smooth_ratio = 0.05 + smooth_ratio = 0.5 + # use 0.5 as the smooth score + cur_prob_masks = (cur_scores.view(-1, 1, 1) * (1 - smooth_ratio) + smooth_ratio) * cur_masks + + # for visualization + ori_prob_masks = cur_prob_masks + + cur_prob_masks = F.interpolate(cur_prob_masks.unsqueeze(0), + size=feat_size, + mode='bilinear', + align_corners=False + )[0] + + cur_mask_ids = cur_prob_masks.argmax(0).unsqueeze(0) # (1, h, w) + + # maybe need add the low threshold and re softmax, i.e. 
0.1 + # pixel_query = cur_prob_masks.softmax(dim=0) + # pixel_query = (pixel_query > 0.1).to(pixel_query.dtype) * pixel_query + # pixel_query = pixel_query.softmax(dim=0).permute(1, 2, 0) @ \ + # cur_query_feat # (h, w, c) + + # pixel_query = cur_query_feat[cur_mask_ids[0]] + + pixel_query = cur_prob_masks.softmax(dim=0).permute(1, 2, 0) @ \ + cur_query_feat # (h, w, c) + + # need attn mask to filter none mask + attn_mask = cur_mask_ids != torch.arange(0, cur_query_feat.shape[0]).unsqueeze(1).unsqueeze(2).to(cur_mask_ids.device) + attn_mask = attn_mask.flatten(1) + + # for visualization + if not self.training: + valid_mask = attn_mask.sum(dim=-1) < attn_mask.shape[-1] # (bs, q) + keep = valid_mask + ori_prob_masks = ori_prob_masks[keep] + # self.vis_binary_masks = ori_prob_masks.unsqueeze(1) + + vis_mask_ids = ori_prob_masks.argmax(0).unsqueeze(0) + self.vis_binary_masks = (vis_mask_ids == torch.arange(0, ori_prob_masks.shape[0]).unsqueeze(1).unsqueeze(2).to( + cur_mask_ids.device)).unsqueeze(1).to(torch.float32) + # vis_mask_ids = ori_prob_masks.softmax(0) - 0.1 + # self.vis_binary_masks = vis_mask_ids.unsqueeze(1) + + # vis_prob_masks = F.interpolate(ori_prob_masks.unsqueeze(0), + # size=(1024, 1024), + # mode='bilinear', + # align_corners=False + # )[0] + # vis_mask_ids = vis_prob_masks.argmax(0).unsqueeze(0) + # self.vis_binary_masks = vis_mask_ids == torch.arange(0, ori_prob_masks.shape[0]).unsqueeze(1).unsqueeze(2).to( + # cur_mask_ids.device) + # self.vis_binary_masks = vis_prob_masks > 0.5 + + return pixel_query, attn_mask + + def enable_input_require_grads(self): + self.enable_output_gradient = True + return + + def set_output_gradient(self, output): + if not self.training: + return output + output.requires_grad_(self.enable_output_gradient) + return output + + def requires_grad_(self, state): + if not self.training: + return + if state: + print("Not Frozen the Visual Encoder !") + else: + print("Frozen the Visual Encoder !") + for p in self.parameters(): + p.requires_grad_(state) + return + diff --git a/omg_llava/model/omg_seg/utils.py b/omg_llava/model/omg_seg/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..66dcfc1b9ee47461689f4b2f938c86aa5c328abd --- /dev/null +++ b/omg_llava/model/omg_seg/utils.py @@ -0,0 +1,27 @@ +import torch +import torch.nn.functional as F + +def mask_pool(x, mask): + """ + Args: + x: [B, C, H, W] + mask: [B, Q, H, W] + """ + if not x.shape[-2:] == mask.shape[-2:]: + # reshape mask to x + mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False) + with torch.no_grad(): + mask = mask.detach() + mask = (mask > 0).to(mask.dtype) + denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 + + mask_pooled_x = torch.einsum( + "bchw,bqhw->bqc", + x, + mask / denorm, + ) + + return mask_pooled_x + + + diff --git a/omg_llava/model/utils.py b/omg_llava/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61e0f4445577c183580d8ed73b56d5c75ee76fcd --- /dev/null +++ b/omg_llava/model/utils.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
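# Illustrative usage of the `mask_pool` helper defined above in
# omg_llava/model/omg_seg/utils.py (a standalone sketch, not part of the patch;
# the import path assumes the package layout of these files). mask_pool
# binarises each query mask at 0 and averages the pixel features inside it,
# yielding one feature vector per query.
import torch
from omg_llava.model.omg_seg.utils import mask_pool

feats = torch.randn(2, 256, 64, 64)   # (B, C, H, W) pixel features
masks = torch.randn(2, 100, 32, 32)   # (B, Q, H, W) mask logits, any spatial size
pooled = mask_pool(feats, masks)      # masks are resized to 64x64 internally
assert pooled.shape == (2, 100, 256)  # (B, Q, C): one pooled vector per query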
+from xtuner.model.utils import * +from typing import List, Optional +import torch +from transformers import PreTrainedModel +from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX + +def prepare_inputs_labels_for_multimodal_with_visual_prompts( + llm: PreTrainedModel, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + region_id=None, + regions_feats=None, + mark_id=None, + mark_feats=None, + **kwargs, +): + if pixel_values is None: + return { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'inputs_embeds': None, + 'labels': labels + } + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange( + 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- TODO: double check + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [ + cur_labels[cur_attention_mask] + for cur_labels, cur_attention_mask in zip(labels, attention_mask) + ] + + new_inputs_embeds = [] + new_labels = [] + cur_image_idx = 0 + cur_region_idx = 0 + cur_mark_id = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_pixel_values = pixel_values[cur_image_idx] + cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids) + cur_inputs_embeds = torch.cat( + [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0) + new_inputs_embeds.append(cur_inputs_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + need_replace = cur_input_ids == IMAGE_TOKEN_INDEX + need_replace = torch.logical_or(need_replace, cur_input_ids == region_id) + need_replace = torch.logical_or(need_replace, cur_input_ids == mark_id) + num_replace = need_replace.sum() + replace_type = cur_input_ids[need_replace] + + image_token_indices = [-1] + torch.where( + need_replace)[0].tolist() + [ + cur_input_ids.shape[0] + ] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + + 1:image_token_indices[i + + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + + 1:image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_inputs_embeds = llm.get_input_embeddings()( + torch.cat(cur_input_ids_noim)) + cur_inputs_embeds_no_im = torch.split( + cur_inputs_embeds, split_sizes, dim=0) + cur_new_inputs_embeds = [] + cur_new_labels = [] + + for i in range(num_replace + 1): + cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_replace: + # image + if replace_type[i] == IMAGE_TOKEN_INDEX: + cur_pixel_values = pixel_values[cur_image_idx] + cur_image_idx += 1 + cur_new_inputs_embeds.append(cur_pixel_values) + cur_new_labels.append( + 
torch.full((cur_pixel_values.shape[0], ), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype)) + elif replace_type[i] == region_id: + cur_pixel_values = regions_feats[cur_region_idx:cur_region_idx+1] + cur_region_idx += 1 + cur_new_inputs_embeds.append(cur_pixel_values) + cur_new_labels.append( + torch.full((cur_pixel_values.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype)) + elif replace_type[i] == mark_id: + cur_pixel_values = mark_feats[cur_mark_id:cur_mark_id + 1] + cur_mark_id += 1 + cur_new_inputs_embeds.append(cur_pixel_values) + cur_new_labels.append( + torch.full((cur_pixel_values.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype)) + + cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_inputs_embeds.append(cur_new_inputs_embeds) + new_labels.append(cur_new_labels) + + # Combine them + max_len = max(x.shape[0] for x in new_inputs_embeds) + batch_size = len(new_inputs_embeds) + + new_inputs_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), + IGNORE_INDEX, + dtype=new_labels[0].dtype, + device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), + dtype=attention_mask.dtype, + device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), + dtype=position_ids.dtype, + device=position_ids.device) + + for i, (cur_new_embed, + cur_new_labels) in enumerate(zip(new_inputs_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + new_inputs_embeds_padded.append( + torch.cat((cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device)), + dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, + cur_len, + dtype=position_ids.dtype, + device=position_ids.device) + + new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return { + 'input_ids': None, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'inputs_embeds': new_inputs_embeds, + 'labels': new_labels, + } diff --git a/omg_llava/tools/__init__.py b/omg_llava/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/tools/__pycache__/__init__.cpython-310.pyc b/omg_llava/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3a9d4f4c84e245ac64a5d89f4fad6bd6be805e Binary files /dev/null and b/omg_llava/tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/tools/__pycache__/app_utils.cpython-310.pyc b/omg_llava/tools/__pycache__/app_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eec6e36d867da86f92d62209f654a5a2cc81307 Binary files /dev/null and b/omg_llava/tools/__pycache__/app_utils.cpython-310.pyc differ diff --git a/omg_llava/tools/app.py b/omg_llava/tools/app.py new file mode 100644 index 0000000000000000000000000000000000000000..465e052d208d663d78e271867125e3e2d7e47a79 --- /dev/null +++ 
b/omg_llava/tools/app.py @@ -0,0 +1,530 @@ +import cv2 +import random +import gradio as gr +import numpy as np +from PIL import Image +import torch.nn.functional as F +import sys +from omg_llava.tools.app_utils import process_markdown, show_mask_pred, parse_visual_prompts + +import torch +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image +from omg_llava.dataset.utils import expand2square_bbox, expand2square_mask, expand2square_points +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER + +from gradio_image_prompter import ImagePrompter + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +def parse_args(args): + parser = argparse.ArgumentParser(description="OMG-LLaVA Demo") + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + 
'--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + return parser.parse_args(args) + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, mode='point'): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + + if mode == 'point': + marks_embeddings = visual_encoder.forward_point_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + elif mode == 'box': + marks_embeddings = visual_encoder.forward_box_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + else: + raise NotImplementedError + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + print('marks_embeddings shape: ', marks_embeddings.shape) + return marks_embeddings # (N, C) + +def get_visual_prompts_embeddings( + height, width, input_ids, +): + points_prompts = global_infos.point_prompts + boxes_prompts = global_infos.box_prompts + + if len(points_prompts) == 0: + points_mark_embedding = [] + else: + points = np.array(points_prompts) + points = expand2square_points(points, height=height, width=width) + points[:, 0] = points[:, 0] / max(height, width) * 1024 + points[:, 1] = points[:, 1] / max(height, width) * 1024 + points = torch.from_numpy(points) + points = points.cuda() + mark_token_id = omg_llava.mark_token_idx + + points_mark_embedding = get_points_embeddings( + points, input_ids, + 1024, 1024, + mark_token_id) + + + if len(boxes_prompts) == 0: + boxes_mark_embedding = [] + else: + boxes_prompts = np.array(boxes_prompts) + + boxes_prompts = expand2square_bbox(boxes_prompts, height=height, width=width) + boxes_prompts[:, [0, 2]] = boxes_prompts[:, [0, 2]] / max(height, width) * 1024 + boxes_prompts[:, [1, 3]] = boxes_prompts[:, [1, 3]] / max(height, width) * 1024 + boxes_prompts = torch.from_numpy(boxes_prompts) + boxes_prompts = torch.from_numpy(boxes_prompts) + boxes_prompts = boxes_prompts.cuda() + # using token + region_token_id = omg_llava.region_token_idx + + boxes_mark_embedding = get_points_embeddings( + boxes_prompts, input_ids, + 1024, 1024, + region_token_id) + return points_mark_embedding, boxes_mark_embedding + +def inference(input_str, all_inputs, follow_up): + input_str = input_str.replace('', '')\ + .replace('', '') + print("Get Recieved Infos !!!") + prompts = all_inputs['points'] + visual_prompts = parse_visual_prompts(prompts) + input_image = all_inputs['image'] + + print("follow_up: ", follow_up) + print(prompts) + print("input_str: ", input_str, "input_image: ", input_image) + + # + if not follow_up: + # reset + print('Log: History responses have been removed!') + global_infos.n_turn = 0 + global_infos.inputs = '' + # reset prompts + global_infos.point_prompts = [] + 
global_infos.box_prompts = [] + global_infos.mask_prompts = [] + + # first conversation, add image tokens + text = DEFAULT_IMAGE_TOKEN + '\n' + input_str + + # prepare image + image = load_image(input_image) + width, height = image.size + global_infos.image_width = width + global_infos.image_height = height + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + global_infos.image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + pixel_values = projector(visual_outputs) + global_infos.panoptic_masks = omg_llava.visual_encoder.vis_binary_masks + global_infos.pixel_values = pixel_values + + # for remove padding + if width == height: + sx, ex, sy, ey = 0, width, 0, height + elif width > height: + sy = int((width - height) / 2.0) + ey = width - sy + sx, ex = 0, width + else: + sx = int((height - width) / 2.0) + ex = height - sx + sy, ey = 0, height + + global_infos.sx = sx + global_infos.sy = sy + global_infos.ex = ex + global_infos.ey = ey + + else: + text = input_str + pixel_values = global_infos.pixel_values + + # add cur prompts into global prompts + global_infos.point_prompts += visual_prompts['points'] + global_infos.box_prompts += visual_prompts['boxes'] + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and global_infos.n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=global_infos.n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=global_infos.n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=global_infos.n_turn + 1, bot_name=args.bot_name) + else: + prompt_text = text + + print("prompt_text: ", prompt_text) + global_infos.inputs += prompt_text + + # encode prompt text + chunk_encode = [] + for idx, chunk in enumerate(global_infos.inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and global_infos.n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + + points_mark_embeddings, boxes_mark_embeddings = get_visual_prompts_embeddings( + height=global_infos.image_height, + width=global_infos.image_width, input_ids=ids + ) + + mark_embeddings = points_mark_embeddings + + mark_token_id = omg_llava.mark_token_idx + mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts( + llm=llm, input_ids=ids, pixel_values=pixel_values, + mark_id=mark_token_id, + mark_feats=mark_embeddings, region_id=-9999) + + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + predict = tokenizer.decode( + 
generate_output.sequences[0]) + + global_infos.inputs += predict + predict = predict.strip() + global_infos.n_turn += 1 + global_infos.inputs += sep + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + global_infos.n_turn = 0 + global_infos.inputs = '' + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=omg_llava.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + if len(seg_hidden_states) != 0: + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0],), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = omg_llava.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + print(pred_masks_list[-1].shape) + image_mask_show, selected_colors = show_mask_pred( + global_infos.image_for_show, pred_masks_list[-1], + crop_range = (global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey) + ) + else: + image_mask_show = global_infos.image_for_show.crop( + (global_infos.sx, global_infos.sy, global_infos.ex, global_infos.ey)) + selected_colors = [] + + panoptic_show, _ = show_mask_pred( + global_infos.image_for_show, global_infos.panoptic_masks, + crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey) + ) + + predict = process_markdown(predict, selected_colors) + # return panoptic_show, image_mask_show, predict + return image_mask_show, predict + +def init_models(args): + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + + inner_thoughts_open = False + 
calculate_open = False + solve_open = False + search_open = False + + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + visual_encoder.eval() + projector.eval() + projector_text2vision.eval() + + return model, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +class global_infos: + inputs = '' + n_turn = 0 + image_width = 0 + image_height = 0 + + image_for_show = None + pixel_values = None + panoptic_masks = None + + sx, sy, ex, ey = 0, 0 ,1024, 1024 + + point_prompts = [] + box_prompts = [] + mask_prompts = [] + +if __name__ == "__main__": + # get parse args and set models + args = parse_args(sys.argv[1:]) + + omg_llava, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision = \ + init_models(args) + + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + demo = gr.Interface( + inference, inputs=[gr.Textbox(lines=1, placeholder=None, label="Text Instruction"), ImagePrompter( + type='filepath', label='Input Image (Please click points or draw bboxes)', interactive=True, + elem_id='image_upload', height=360, visible=True, render=True + ), + gr.Checkbox(label="Follow up Question")], + outputs=[ + # gr.Image(type="pil", label="Panoptic Segmentation", height=360), + gr.Image(type="pil", label="Output Image"), + gr.Markdown()], + theme=gr.themes.Soft(), allow_flagging="auto", ) + + demo.queue() + demo.launch(share=True) + + # gr.Image( + # type='filepath', label='Input Image (Please draw bounding boxes)', interactive=True, + # elem_id='image_upload', height=360, + # ) \ No newline at end of file diff --git a/omg_llava/tools/app_utils.py b/omg_llava/tools/app_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80750479423469b0a2c0a53cd2ac3a258ebe3753 --- /dev/null +++ b/omg_llava/tools/app_utils.py @@ -0,0 +1,170 @@ +import torch.nn.functional as F +import numpy as np +import torch + +markdown_default = """ + + +OMG-LLaVA +""" + +ONE_THIRD = 1.0/3.0 +ONE_SIXTH = 1.0/6.0 +TWO_THIRD = 2.0/3.0 + +def desaturate(rgb, factor=0.65): + """ + Desaturate an RGB color by a given factor. + + :param rgb: A tuple of (r, g, b) where each value is in [0, 255]. + :param factor: The factor by which to reduce the saturation. + 0 means completely desaturated, 1 means original color. + :return: A tuple of desaturated (r, g, b) values in [0, 255]. 
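    Worked example with the default factor of 0.65: desaturate((255, 0, 0))
    keeps the hue and saturation returned by rgb_to_hls but fixes the lightness
    channel at 0.65, giving (255, 76, 76), a lighter, washed-out red. Note that
    the implementation adjusts lightness (`l`) rather than saturation (`s`).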
+ """ + r, g, b = [x / 255.0 for x in rgb] + h, l, s = rgb_to_hls(r, g, b) + l = factor + new_r, new_g, new_b = hls_to_rgb(h, l, s) + return (int(new_r * 255), int(new_g * 255), int(new_b * 255)) + +def rgb_to_hls(r, g, b): + maxc = max(r, g, b) + minc = min(r, g, b) + sumc = (maxc+minc) + rangec = (maxc-minc) + l = sumc/2.0 + if minc == maxc: + return 0.0, l, 0.0 + if l <= 0.5: + s = rangec / sumc + else: + s = rangec / (2.0-sumc) + rc = (maxc-r) / rangec + gc = (maxc-g) / rangec + bc = (maxc-b) / rangec + if r == maxc: + h = bc-gc + elif g == maxc: + h = 2.0+rc-bc + else: + h = 4.0+gc-rc + h = (h/6.0) % 1.0 + return h, l, s + +def hls_to_rgb(h, l, s): + if s == 0.0: + return l, l, l + if l <= 0.5: + m2 = l * (1.0+s) + else: + m2 = l+s-(l*s) + m1 = 2.0*l - m2 + return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD)) + +def _v(m1, m2, hue): + hue = hue % 1.0 + if hue < ONE_SIXTH: + return m1 + (m2-m1)*hue*6.0 + if hue < 0.5: + return m2 + if hue < TWO_THIRD: + return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0 + return m1 + +def process_markdown(output_str, colors): + output_str = output_str.replace("\n", "").replace(" ", " ").replace("", "")\ + .replace("<|im_end|>", '') + output_str = output_str.split("ASSISTANT: ")[-1] + + markdown_out = output_str.replace('[SEG]', '') + markdown_out = markdown_out.replace( + "

", "" + ) + markdown_out = markdown_out.replace("

", "") + + for color in colors: + markdown_out = markdown_out.replace("[COLOR]", str(desaturate(tuple(color))), 1) + + markdown_out = f""" + {markdown_out} + """ + markdown_out = markdown_default + "

" + markdown_out + return markdown_out + +def show_mask_pred(image, masks, crop_range=(0, 1024, 0, 1024)): + print(crop_range) + + selected_colors = [] + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255), [255, 192, 203], # Pink + [165, 42, 42], # Brown + [255, 165, 0], # Orange + [128, 0, 128], # Purple + [0, 0, 128], # Navy + [128, 0, 0], # Maroon + [128, 128, 0], # Olive + [70, 130, 180], # Steel Blue + [173, 216, 230], # Light Blue + [255, 192, 0], # Gold + [255, 165, 165], # Light Salmon + [255, 20, 147], # Deep Pink + ] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + selected_colors.append(color) + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = image[crop_range[2]: crop_range[3], crop_range[0]: crop_range[1], :] + # image = Image.fromarray(image) + # image.save(save_dir) + return image, selected_colors + +def parse_visual_prompts(points): + ret = {'points': [], 'boxes': []} + for item in points: + if item[2] == 1.0: + ret['points'].append([item[0], item[1]]) + elif item[2] == 2.0 or item[2] == 3.0: + ret['boxes'].append(item[[0, 1, 3, 4]]) + else: + raise NotImplementedError + return ret \ No newline at end of file diff --git a/omg_llava/tools/chat_omg_llava.py b/omg_llava/tools/chat_omg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0cde989e8ca42055bd9e90b38b8a46ab25fc65 --- /dev/null +++ b/omg_llava/tools/chat_omg_llava.py @@ -0,0 +1,518 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import copy +import os +import os.path as osp +import re +import sys + +import torch +from huggingface_hub import snapshot_download +from peft import PeftModel +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook + +def remove_prefix(state_dict, prefix): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + 
type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + + +def get_input(): + """Helper function for getting input from users.""" + sentinel = '' # ends when this string is seen + result = None + while result is None: + print(('\ndouble enter to end input (EXIT: exit chat, ' + 'RESET: reset history) >>> '), + end='') + try: + result = '\n'.join(iter(input, sentinel)) + except UnicodeDecodeError: + print('Invalid characters detected. Please enter again.') + return result + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + print(model.state_dict().keys()) + + # pre_state_dict = torch.load("/root/omg-llava.pth") + # model.load_state_dict(pre_state_dict) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # chat_hook = EvaluateChatHook( + # tokenizer=cfg.tokenizer, + # image_processor=image_processor_cfg, + # every_n_iters=100, + # evaluation_inputs=cfg.evaluation_inputs, + # evaluation_images=cfg.evaluation_images, + # system='', + # prompt_template=PROMPT_TEMPLATE.internlm2_chat + # ) + # model.cuda() + # model.eval() + # chat_hook._eval_images_(model, model.device, max_new_tokens=200) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + if False: + pass + else: + if args.with_plugins is None: + inner_thoughts_open = False + calculate_open = False + solve_open = False + search_open = False + else: + 
assert args.prompt_template == args.system_template == 'moss_sft' + from plugins import plugins_api + inner_thoughts_open = True + calculate_open = 'calculate' in args.with_plugins + solve_open = 'solve' in args.with_plugins + search_open = 'search' in args.with_plugins + # pre-import for api and model preparation + if calculate_open: + from plugins import calculate # noqa: F401 + if solve_open: + from plugins import solve # noqa: F401 + if search_open: + from plugins import search # noqa: F401 + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + if args.image is not None: + image = load_image(args.image) + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + print([item.shape for item in visual_outputs]) + pixel_values = projector(visual_outputs) + + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + n_turn = 0 + inputs = '' + while True: + text = get_input() + while text.strip() == 'RESET': + print('Log: History responses have been removed!') + n_turn = 0 + inputs = '' + text = get_input() + if text.strip() == 'EXIT': + print('Log: Exit!') + exit(0) + + if args.image is not None and n_turn == 0: + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=n_turn + 1, bot_name=args.bot_name) + if args.prompt_template == args.system_template == 'moss_sft': + if not inner_thoughts_open: + prompt_text.replace('- Inner thoughts: enabled.', + '- Inner thoughts: disabled.') + if not calculate_open: + prompt_text.replace(('- Calculator: enabled. API: ' + 'Calculate(expression)'), + '- Calculator: disabled.') + if not solve_open: + prompt_text.replace( + '- Equation solver: enabled. API: Solve(equation)', + '- Equation solver: disabled.') + if not search_open: + prompt_text.replace( + '- Web search: enabled. 
API: Search(query)', + '- Web search: disabled.') + else: + prompt_text = text + print("prompt_text: ", prompt_text) + inputs += prompt_text + if args.image is None: + if n_turn == 0: + ids = tokenizer.encode(inputs, return_tensors='pt') + else: + ids = tokenizer.encode( + inputs, return_tensors='pt', add_special_tokens=False) + + if args.with_plugins is not None: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria).cpu() + generate_output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + if streamer is None: + end = '' if generate_output_text[-1] == '\n' else '\n' + print(generate_output_text, end=end) + pattern = r'<\|Commands\|>:(.*?)' + command_text = ', '.join( + re.findall(pattern, generate_output_text)) + extent_text = plugins_api( + command_text, + calculate_open=calculate_open, + solve_open=solve_open, + search_open=search_open) + end = '' if extent_text[-1] == '\n' else '\n' + print(extent_text, end=end) + extent_text_ids = tokenizer.encode( + extent_text, + return_tensors='pt', + add_special_tokens=False) + new_ids = torch.cat((generate_output, extent_text_ids), + dim=1) + + generate_output = llm.generate( + inputs=new_ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(new_ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + else: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + inputs = tokenizer.decode(generate_output[0]) + else: + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + print(mm_inputs['inputs_embeds'].shape) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0][:-1], + # last_hidden_states, generate_output.sequences[0], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + if len(seg_hidden_states) != 0: + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0], ), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = 
model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + print(pred_masks_list[-1].shape) + show_mask_pred(image_for_show, pred_masks_list[-1], save_dir='./output.png') + + + if streamer is None: + # output_text = tokenizer.decode(generate_output[0]) + output_text = tokenizer.decode(generate_output.sequences[0]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + # inputs += tokenizer.decode(generate_output[0]) + inputs += tokenizer.decode(generate_output.sequences[0]) + n_turn += 1 + inputs += sep + # if len(generate_output[0]) >= args.max_new_tokens: + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + n_turn = 0 + inputs = '' + + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +def show_mask_pred(image, masks, save_dir='./output.png'): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/chat_omg_llava_msseg.py b/omg_llava/tools/chat_omg_llava_msseg.py new file mode 100644 index 0000000000000000000000000000000000000000..6e968e188476c82ec061bf509430fa4198618947 --- /dev/null +++ b/omg_llava/tools/chat_omg_llava_msseg.py @@ -0,0 +1,549 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
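# Structural sketch (illustration only, sizes invented) of the
# `generate_output.hidden_states` object that both chat scripts index with
# `[item[-1][0] for item in hidden_states]` / `[item[-1][-1] ...]`:
# transformers returns one tuple per generated step and, per step, a tuple of
# per-layer tensors (embedding output first) of shape
# (batch, step_seq_len, hidden), where step_seq_len is the prompt length at the
# first step and 1 afterwards with the KV cache. The [SEG] id below is a
# made-up placeholder, not the real tokenizer id.
import torch

num_layers, hidden_dim = 4, 8
prompt_len, new_tokens = 5, 3
hidden_states = tuple(
    tuple(torch.randn(1, prompt_len if step == 0 else 1, hidden_dim)
          for _ in range(num_layers))
    for step in range(new_tokens)
)

# keep the last layer at every step (batch element 0) and stitch the steps together
last_layer = torch.cat([step_states[-1][0] for step_states in hidden_states], dim=0)
assert last_layer.shape == (prompt_len + new_tokens - 1, hidden_dim)

# rows whose token id equals the [SEG] id are then selected (as in
# get_seg_hidden_states) and passed through projector_text2vision and
# visual_encoder.forward_llm_seg to decode masks
SEG_ID = 92546
token_ids = torch.tensor([7, SEG_ID, 9, 11, SEG_ID, 13, 2])  # one id per row above
seg_rows = last_layer[token_ids == SEG_ID]
assert seg_rows.shape == (2, hidden_dim)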
+import argparse +import copy +import os +import os.path as osp +import re +import sys + +import torch +from huggingface_hub import snapshot_download +from peft import PeftModel +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook + +def remove_prefix(state_dict, prefix): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + parser.add_argument( + '--mode', + default='baseline', # baseline, mean, linear_cat + help='Specify a mode') + + parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest 
probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + + +def get_input(): + """Helper function for getting input from users.""" + sentinel = '' # ends when this string is seen + result = None + while result is None: + print(('\ndouble enter to end input (EXIT: exit chat, ' + 'RESET: reset history) >>> '), + end='') + try: + result = '\n'.join(iter(input, sentinel)) + except UnicodeDecodeError: + print('Invalid characters detected. Please enter again.') + return result + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + print(model.state_dict().keys()) + + # pre_state_dict = torch.load("/root/omg-llava.pth") + # model.load_state_dict(pre_state_dict) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # chat_hook = EvaluateChatHook( + # tokenizer=cfg.tokenizer, + # image_processor=image_processor_cfg, + # every_n_iters=100, + # evaluation_inputs=cfg.evaluation_inputs, + # evaluation_images=cfg.evaluation_images, + # system='', + # prompt_template=PROMPT_TEMPLATE.internlm2_chat + # ) + # model.cuda() + # model.eval() + # chat_hook._eval_images_(model, model.device, max_new_tokens=200) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + if False: + pass + else: + if args.with_plugins is None: + 
inner_thoughts_open = False + calculate_open = False + solve_open = False + search_open = False + else: + assert args.prompt_template == args.system_template == 'moss_sft' + from plugins import plugins_api + inner_thoughts_open = True + calculate_open = 'calculate' in args.with_plugins + solve_open = 'solve' in args.with_plugins + search_open = 'search' in args.with_plugins + # pre-import for api and model preparation + if calculate_open: + from plugins import calculate # noqa: F401 + if solve_open: + from plugins import solve # noqa: F401 + if search_open: + from plugins import search # noqa: F401 + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + if args.image is not None: + image = load_image(args.image) + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + print([item.shape for item in visual_outputs]) + pixel_values = projector(visual_outputs) + + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + n_turn = 0 + inputs = '' + while True: + text = get_input() + while text.strip() == 'RESET': + print('Log: History responses have been removed!') + n_turn = 0 + inputs = '' + text = get_input() + if text.strip() == 'EXIT': + print('Log: Exit!') + exit(0) + + if args.image is not None and n_turn == 0: + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=n_turn + 1, bot_name=args.bot_name) + if args.prompt_template == args.system_template == 'moss_sft': + if not inner_thoughts_open: + prompt_text.replace('- Inner thoughts: enabled.', + '- Inner thoughts: disabled.') + if not calculate_open: + prompt_text.replace(('- Calculator: enabled. API: ' + 'Calculate(expression)'), + '- Calculator: disabled.') + if not solve_open: + prompt_text.replace( + '- Equation solver: enabled. 
API: Solve(equation)', + '- Equation solver: disabled.') + if not search_open: + prompt_text.replace( + '- Web search: enabled. API: Search(query)', + '- Web search: disabled.') + else: + prompt_text = text + print("prompt_text: ", prompt_text) + inputs += prompt_text + if args.image is None: + if n_turn == 0: + ids = tokenizer.encode(inputs, return_tensors='pt') + else: + ids = tokenizer.encode( + inputs, return_tensors='pt', add_special_tokens=False) + + if args.with_plugins is not None: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria).cpu() + generate_output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + if streamer is None: + end = '' if generate_output_text[-1] == '\n' else '\n' + print(generate_output_text, end=end) + pattern = r'<\|Commands\|>:(.*?)' + command_text = ', '.join( + re.findall(pattern, generate_output_text)) + extent_text = plugins_api( + command_text, + calculate_open=calculate_open, + solve_open=solve_open, + search_open=search_open) + end = '' if extent_text[-1] == '\n' else '\n' + print(extent_text, end=end) + extent_text_ids = tokenizer.encode( + extent_text, + return_tensors='pt', + add_special_tokens=False) + new_ids = torch.cat((generate_output, extent_text_ids), + dim=1) + + generate_output = llm.generate( + inputs=new_ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(new_ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + else: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + inputs = tokenizer.decode(generate_output[0]) + else: + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + print(mm_inputs['inputs_embeds'].shape) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + hidden_states = generate_output.hidden_states + if args.mode == 'baseline': + last_hidden_states = [item[-1][-1] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + # last_hidden_states, generate_output.sequences[0], + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx + ) + else: + hidden_states = [torch.cat(item, dim=0)[:, -1] for item in hidden_states] + last_hidden_states = torch.stack(hidden_states, dim=0) # (N, 
n_layer, c) + seg_hidden_states = get_seg_hidden_states_multistates( + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx, mode=args.mode, + model=model, + ) + + # seg_hidden_states = seg_hidden_states.to(torch.float32) + if len(seg_hidden_states) != 0: + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0], ), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + print(pred_masks_list[-1].shape) + show_mask_pred(image_for_show, pred_masks_list[-1], save_dir='./output.png') + + + if streamer is None: + # output_text = tokenizer.decode(generate_output[0]) + output_text = tokenizer.decode(generate_output.sequences[0]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + # inputs += tokenizer.decode(generate_output[0]) + inputs += tokenizer.decode(generate_output.sequences[0]) + n_turn += 1 + inputs += sep + # if len(generate_output[0]) >= args.max_new_tokens: + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + n_turn = 0 + inputs = '' + + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +def get_seg_hidden_states_multistates(hidden_states, output_ids, seg_id, mode, model): + # (N, n_layer, c) + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + hidden_states = hidden_states[-n_out:][seg_mask] # (N, n_layers, c) + if mode == 'mean': + hidden_states = torch.mean(hidden_states, dim=1) + elif mode == 'linear_cat': + hidden_states = hidden_states[:, -model.selected_layers:] + hidden_states = model.seg_token_proj_linear_cat[0](hidden_states) + hidden_states = hidden_states.flatten(1) + hidden_states = model.seg_token_proj_linear_cat[1](hidden_states) + else: + raise NotImplementedError + return hidden_states + +def show_mask_pred(image, masks, save_dir='./output.png'): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/chat_omg_llava_visual_prompts.py b/omg_llava/tools/chat_omg_llava_visual_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..1e036143cef1d6d9c09af0045289733dca832d83 --- /dev/null +++ b/omg_llava/tools/chat_omg_llava_visual_prompts.py @@ -0,0 +1,567 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import argparse +import copy +import os +import os.path as osp +import re +import sys + +import torch +from huggingface_hub import snapshot_download +from peft import PeftModel +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image +from omg_llava.dataset.utils import expand2square_points +from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER +import numpy as np +from PIL import Image + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook + +prompts_points = [[2527, 3215], ] + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, visual_encoder, + projector): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + marks_embeddings = visual_encoder.forward_point_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + return marks_embeddings # (N, C) + +def remove_prefix(state_dict, prefix): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + 
'--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + + +def get_input(): + """Helper function for getting input from users.""" + sentinel = '' # ends when this string is seen + result = None + while result is None: + print(('\ndouble enter to end input (EXIT: exit chat, ' + 'RESET: reset history) >>> '), + end='') + try: + result = '\n'.join(iter(input, sentinel)) + except UnicodeDecodeError: + print('Invalid characters detected. 
Please enter again.') + return result + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + print(model.state_dict().keys()) + + # pre_state_dict = torch.load("/root/omg-llava.pth") + # model.load_state_dict(pre_state_dict) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # chat_hook = EvaluateChatHook( + # tokenizer=cfg.tokenizer, + # image_processor=image_processor_cfg, + # every_n_iters=100, + # evaluation_inputs=cfg.evaluation_inputs, + # evaluation_images=cfg.evaluation_images, + # system='', + # prompt_template=PROMPT_TEMPLATE.internlm2_chat + # ) + # model.cuda() + # model.eval() + # chat_hook._eval_images_(model, model.device, max_new_tokens=200) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + if False: + pass + else: + if args.with_plugins is None: + inner_thoughts_open = False + calculate_open = False + solve_open = False + search_open = False + else: + assert args.prompt_template == args.system_template == 'moss_sft' + from plugins import plugins_api + inner_thoughts_open = True + calculate_open = 'calculate' in args.with_plugins + solve_open = 'solve' in args.with_plugins + search_open = 'search' in args.with_plugins + # pre-import for api and model preparation + if calculate_open: + from plugins import calculate # noqa: F401 + if solve_open: + from plugins import solve # noqa: F401 + if search_open: + from plugins import search # noqa: F401 + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + if args.image is not None: + image = load_image(args.image) + ori_width, ori_height = image.size + image = expand2square( + image, tuple(int(x * 255) for x in 
image_processor.image_mean)) + image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + print([item.shape for item in visual_outputs]) + pixel_values = projector(visual_outputs) + + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + n_turn = 0 + inputs = '' + while True: + text = get_input() + while text.strip() == 'RESET': + print('Log: History responses have been removed!') + n_turn = 0 + inputs = '' + text = get_input() + if text.strip() == 'EXIT': + print('Log: Exit!') + exit(0) + + if args.image is not None and n_turn == 0: + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=n_turn + 1, bot_name=args.bot_name) + if args.prompt_template == args.system_template == 'moss_sft': + if not inner_thoughts_open: + prompt_text.replace('- Inner thoughts: enabled.', + '- Inner thoughts: disabled.') + if not calculate_open: + prompt_text.replace(('- Calculator: enabled. API: ' + 'Calculate(expression)'), + '- Calculator: disabled.') + if not solve_open: + prompt_text.replace( + '- Equation solver: enabled. API: Solve(equation)', + '- Equation solver: disabled.') + if not search_open: + prompt_text.replace( + '- Web search: enabled. 
API: Search(query)', + '- Web search: disabled.') + else: + prompt_text = text + print("prompt_text: ", prompt_text) + inputs += prompt_text + if args.image is None: + if n_turn == 0: + ids = tokenizer.encode(inputs, return_tensors='pt') + else: + ids = tokenizer.encode( + inputs, return_tensors='pt', add_special_tokens=False) + + if args.with_plugins is not None: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria).cpu() + generate_output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + if streamer is None: + end = '' if generate_output_text[-1] == '\n' else '\n' + print(generate_output_text, end=end) + pattern = r'<\|Commands\|>:(.*?)' + command_text = ', '.join( + re.findall(pattern, generate_output_text)) + extent_text = plugins_api( + command_text, + calculate_open=calculate_open, + solve_open=solve_open, + search_open=search_open) + end = '' if extent_text[-1] == '\n' else '\n' + print(extent_text, end=end) + extent_text_ids = tokenizer.encode( + extent_text, + return_tensors='pt', + add_special_tokens=False) + new_ids = torch.cat((generate_output, extent_text_ids), + dim=1) + + generate_output = llm.generate( + inputs=new_ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(new_ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + else: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + inputs = tokenizer.decode(generate_output[0]) + else: + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + + # add region token on firset conversation + if n_turn == 0: + points = np.array(prompts_points) + points = expand2square_points(points, height=ori_height, width=ori_width) + points[:, 0] = points[:, 0] / max(ori_height, ori_width) * 1024 + points[:, 1] = points[:, 1] / max(ori_height, ori_width) * 1024 + points = torch.from_numpy(points) + points = points.cuda() + mark_token_id = model.mark_token_idx + + points_mark_embedding = get_points_embeddings( + points, ids, 1024, 1024, + mark_token_id, visual_encoder, + projector) + + # mm_inputs = prepare_inputs_labels_for_multimodal( + # llm=llm, input_ids=ids, pixel_values=pixel_values) + + mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts( + llm=llm, input_ids=ids, pixel_values=pixel_values, + mark_id=mark_token_id, + mark_feats=points_mark_embedding, region_id=-9999) + + print(mm_inputs['inputs_embeds'].shape) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + 
stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + if len(seg_hidden_states) != 0: + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0], ), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + print(pred_masks_list[-1].shape) + show_mask_pred(image_for_show, pred_masks_list[-1], save_dir='./output.png') + + + if streamer is None: + # output_text = tokenizer.decode(generate_output[0]) + output_text = tokenizer.decode(generate_output.sequences[0]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + # inputs += tokenizer.decode(generate_output[0]) + inputs += tokenizer.decode(generate_output.sequences[0]) + n_turn += 1 + inputs += sep + # if len(generate_output[0]) >= args.max_new_tokens: + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + n_turn = 0 + inputs = '' + + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +def show_mask_pred(image, masks, save_dir='./output.png'): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/convert_deepspeed2pth.py b/omg_llava/tools/convert_deepspeed2pth.py new file mode 100644 index 0000000000000000000000000000000000000000..e6bc4d6c5cc1fa695aa278862445ea98dcbce12c --- /dev/null +++ b/omg_llava/tools/convert_deepspeed2pth.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
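+# Convert a DeepSpeed / xtuner training checkpoint into a plain PyTorch .pth:
+# the model is rebuilt from the given config, the checkpoint is loaded with
+# `guess_load_checkpoint`, and `model.state_dict()` is written to `--save-path`
+# (default: ./work_dirs/converted.pth).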
+import re + +import torch +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + parser.add_argument('--save-path', default='./work_dirs/converted.pth', help='save path of converted pth') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 
1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + state_dict = model.state_dict() + torch.save(state_dict, args.save_path) + print('Save the converted pth to {}'.format(args.save_path)) + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/evaluate_gcd.py b/omg_llava/tools/evaluate_gcd.py new file mode 100644 index 0000000000000000000000000000000000000000..a255e07446615fcfba990eb7a63c2dfef119bba2 --- /dev/null +++ b/omg_llava/tools/evaluate_gcd.py @@ -0,0 +1,291 @@ +import os +import json +import argparse +from tqdm import tqdm +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools import mask as maskUtils +from pycocoevalcap.eval import COCOEvalCap +from transformers import AutoTokenizer, AutoModel +from sklearn.metrics.pairwise import cosine_similarity +import torch +import numpy as np + + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Training") + + parser.add_argument("--split", required=True, help="Evaluation split, options are 'val', 'test'") + parser.add_argument("--prediction_dir_path", required=True, help="The path where the inference results are stored.") + parser.add_argument("--gt_dir_path", required=False, default="./data/glamm_data/annotations/gcg_val_test/", + help="The path containing GranD-f evaluation annotations.") + + args = parser.parse_args() + + return args + + +# Load pre-trained model tokenizer and model for evaluation +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") +model = AutoModel.from_pretrained("bert-base-uncased") + + +def get_bert_embedding(text): + inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True) + outputs = model(**inputs) + # Use the mean of the last hidden states as sentence embedding + sentence_embedding = torch.mean(outputs.last_hidden_state[0], dim=0).detach().numpy() + + return sentence_embedding + +def compute_iou(mask1, mask2): + intersection = np.logical_and(mask1, mask2) + union = np.logical_or(mask1, mask2) + iou = np.sum(intersection) / np.sum(union) + + return iou + +def bbox_to_x1y1x2y2(bbox): + x1, y1, w, h = bbox + bbox = [x1, y1, x1 + w, y1 + h] + + return bbox + +def compute_miou(pred_masks, gt_masks): + # Computing mIoU between predicted masks and ground truth masks + iou_matrix = np.zeros((len(pred_masks), len(gt_masks))) + for i, pred_mask in enumerate(pred_masks): + for j, gt_mask in enumerate(gt_masks): + iou_matrix[i, j] 
= compute_iou(pred_mask, gt_mask) + + # One-to-one pairing and mean IoU calculation + paired_iou = [] + while iou_matrix.size > 0 and np.max(iou_matrix) > 0: + max_iou_idx = np.unravel_index(np.argmax(iou_matrix, axis=None), iou_matrix.shape) + paired_iou.append(iou_matrix[max_iou_idx]) + iou_matrix = np.delete(iou_matrix, max_iou_idx[0], axis=0) + iou_matrix = np.delete(iou_matrix, max_iou_idx[1], axis=1) + + return np.mean(paired_iou) if paired_iou else 0.0 + + +def evaluate_mask_miou(coco_gt, image_ids, pred_save_path): + # Load predictions + coco_dt = coco_gt.loadRes(pred_save_path) + + mious = [] + for image_id in tqdm(image_ids): + # Getting ground truth masks + matching_anns = [ann for ann in coco_gt.anns.values() if ann['image_id'] == image_id] + ann_ids = [ann['id'] for ann in matching_anns] + + gt_anns = coco_gt.loadAnns(ann_ids) + gt_masks = [maskUtils.decode(ann['segmentation']) for ann in gt_anns if 'segmentation' in ann] + + # Getting predicted masks + matching_anns = [ann for ann in coco_dt.anns.values() if ann['image_id'] == image_id] + dt_ann_ids = [ann['id'] for ann in matching_anns] + pred_anns = coco_dt.loadAnns(dt_ann_ids) + pred_masks = [maskUtils.decode(ann['segmentation']) for ann in pred_anns if 'segmentation' in ann] + + # Compute and save the mIoU for the current image + mious.append(compute_miou(pred_masks, gt_masks)) + + # Report mean IoU across all images + mean_miou = np.mean(mious) if mious else 0.0 # If list is empty, return 0.0 + + print(f"Mean IoU (mIoU) across all images: {mean_miou:.3f}") + + +def compute_iou_matrix(pred_masks, gt_masks): + iou_matrix = np.zeros((len(pred_masks), len(gt_masks))) + for i, pred_mask in enumerate(pred_masks): + for j, gt_mask in enumerate(gt_masks): + iou_matrix[i, j] = compute_iou(pred_mask, gt_mask) + + return iou_matrix + + +def text_similarity_bert(str1, str2): + emb1 = get_bert_embedding(str1) + emb2 = get_bert_embedding(str2) + + return cosine_similarity([emb1], [emb2])[0, 0] + + +def find_best_matches(gt_anns, gt_labels, dt_anns, dt_labels, iou_threshold, text_sim_threshold, vectorizer=None): + best_matches = [] + + # Compute pair - wise IoU + pred_masks = [maskUtils.decode(ann['segmentation']) for ann in dt_anns] + gt_masks = [maskUtils.decode(ann['segmentation']) for ann in gt_anns] + ious = compute_iou_matrix(gt_masks, pred_masks) + + text_sims = np.zeros((len(gt_labels), len(dt_labels))) + + for i, gt_label in enumerate(gt_labels): + for j, dt_label in enumerate(dt_labels): + text_sims[i, j] = text_similarity_bert(gt_label, dt_label) + + # Find one-to-one matches satisfying both IoU and text similarity thresholds + while ious.size > 0: + max_iou_idx = np.unravel_index(np.argmax(ious), ious.shape) + if ious[max_iou_idx] < iou_threshold or text_sims[max_iou_idx] < text_sim_threshold: + break # No admissible pair found + + best_matches.append(max_iou_idx) + + # Remove selected annotations from consideration + ious[max_iou_idx[0], :] = 0 + ious[:, max_iou_idx[1]] = 0 + text_sims[max_iou_idx[0], :] = 0 + text_sims[:, max_iou_idx[1]] = 0 + + return best_matches # List of index pairs [(gt_idx, dt_idx), ...] 
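+# find_best_matches() greedily pairs GT and predicted phrase/mask annotations:
+# a pair is accepted only if the mask IoU reaches `iou_threshold` and the BERT
+# cosine similarity of the two phrases reaches `text_sim_threshold`. The recall
+# reported below is (number of matched GT phrases) / (total GT phrases).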
+ + +def evaluate_recall_with_mapping(coco_gt, coco_cap_gt, image_ids, pred_save_path, cap_pred_save_path, iou_threshold=0.5, + text_sim_threshold=0.5): + coco_dt = coco_gt.loadRes(pred_save_path) + coco_cap_dt = coco_cap_gt.loadRes(cap_pred_save_path) + + true_positives = 0 + actual_positives = 0 + + for image_id in tqdm(image_ids): + try: + # gt_ann_ids = coco_gt.getAnnIds(imgIds=image_id, iscrowd=None) + matching_anns = [ann for ann in coco_gt.anns.values() if ann['image_id'] == image_id] + gt_ann_ids = [ann['id'] for ann in matching_anns] + gt_anns = coco_gt.loadAnns(gt_ann_ids) + + # dt_ann_ids = coco_dt.getAnnIds(imgIds=image_id, iscrowd=None) + matching_anns = [ann for ann in coco_dt.anns.values() if ann['image_id'] == image_id] + dt_ann_ids = [ann['id'] for ann in matching_anns] + dt_anns = coco_dt.loadAnns(dt_ann_ids) + + # gt_cap_ann_ids = coco_cap_gt.getAnnIds(imgIds=image_id) + matching_anns = [ann for ann in coco_cap_gt.anns.values() if ann['image_id'] == image_id] + gt_cap_ann_ids = [ann['id'] for ann in matching_anns] + gt_cap_ann = coco_cap_gt.loadAnns(gt_cap_ann_ids)[0] + + # dt_cap_ann_ids = coco_cap_dt.getAnnIds(imgIds=image_id) + matching_anns = [ann for ann in coco_cap_dt.anns.values() if ann['image_id'] == image_id] + dt_cap_ann_ids = [ann['id'] for ann in matching_anns] + dt_cap_ann = coco_cap_dt.loadAnns(dt_cap_ann_ids)[0] + + gt_labels = gt_cap_ann['labels'] + dt_labels = dt_cap_ann['labels'] + + actual_positives += len(gt_labels) + + # Find best matching pairs + best_matches = find_best_matches(gt_anns, gt_labels, dt_anns, dt_labels, iou_threshold, text_sim_threshold) + + true_positives += len(best_matches) + except Exception as e: + print(e) + + recall = true_positives / actual_positives if actual_positives > 0 else 0 + + print(f"Recall: {recall:.3f}") + + +def main(): + args = parse_args() + + # Set the correct split + split = args.split + assert split == "val" or split == "test" # GCG Evaluation has only val and test splits + gt_mask_path = f"{args.gt_dir_path}/{split}_gcg_coco_mask_gt.json" + gt_cap_path = f"{args.gt_dir_path}/{split}_gcg_coco_caption_gt.json" + + print(f"Starting evalution on {split} split.") + + # Get the image names of the split + all_images_ids = [] + with open(gt_cap_path, 'r') as f: + contents = json.load(f) + for image in contents['images']: + all_images_ids.append(image['id']) + + # The directory is used to store intermediate files + tmp_dir_path = f"tmp/{os.path.basename(args.prediction_dir_path)}_{split}" + os.makedirs(tmp_dir_path, exist_ok=True) # Create directory if not exists already + + # Create predictions + pred_save_path = f"{tmp_dir_path}/mask_pred_tmp_save.json" + cap_pred_save_path = f"{tmp_dir_path}/cap_pred_tmp_save.json" + coco_pred_file = [] + caption_pred_dict = {} + for image_id in all_images_ids: + prediction_path = f"{args.prediction_dir_path}/{image_id}.json" + with open(prediction_path, 'r') as f: + pred = json.load(f) + bu = pred + key = list(pred.keys())[0] + pred = pred[key] + try: + caption_pred_dict[image_id] = {'caption': pred['caption'], 'labels': pred['phrases']} + except Exception as e: + pred = bu + caption_pred_dict[image_id] = {'caption': pred['caption'], 'labels': pred['phrases']} + for rle_mask in pred['pred_masks']: + coco_pred_file.append({"image_id": image_id, "category_id": 1, "segmentation": rle_mask, "score": 1.0}) + + # Save gcg_coco_predictions + with open(pred_save_path, 'w') as f: + json.dump(coco_pred_file, f) + + # Prepare the CAPTION predictions in COCO format + cap_image_ids = [] + 
coco_cap_pred_file = [] + for image_id, values in caption_pred_dict.items(): + cap_image_ids.append(image_id) + coco_cap_pred_file.append({"image_id": image_id, "caption": values['caption'], "labels": values['labels']}) + + # Save gcg_caption_coco_predictions + with open(cap_pred_save_path, 'w') as f: + json.dump(coco_cap_pred_file, f) + + # # -------------------------------# + # 1. Evaluate AP + # Calculate mask mAP + # Load the ground truth and predictions in COCO format + coco_gt = COCO(gt_mask_path) + coco_dt = coco_gt.loadRes(pred_save_path) # load predictions + # Initialize COCOEval and specify the metric you want to use + coco_eval = COCOeval(coco_gt, coco_dt, "segm") # "segm" for segmentation + # Evaluate on a specific category + coco_eval.params.catIds = [1] # your category ID + # Evaluate + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + # # -------------------------------# + # # 2. Evaluate Caption Quality + coco_cap_gt = COCO(gt_cap_path) + coco_cap_result = coco_cap_gt.loadRes(cap_pred_save_path) + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco_cap_gt, coco_cap_result) + coco_eval.params['image_id'] = coco_cap_result.getImgIds() + coco_eval.evaluate() + for metric, score in coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + + # # -------------------------------# + # 3. Evaluate Mask Mean MIoU + coco_gt = COCO(gt_mask_path) # Load ground truth annotations + evaluate_mask_miou(coco_gt, all_images_ids, pred_save_path) + + # # -------------------------------# + # 4. Evaluate Recall + evaluate_recall_with_mapping(coco_gt, coco_cap_gt, all_images_ids, pred_save_path, cap_pred_save_path, + iou_threshold=0.5, text_sim_threshold=0.5) + + +if __name__ == "__main__": + main() diff --git a/omg_llava/tools/evaluate_region_cap.py b/omg_llava/tools/evaluate_region_cap.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f401164efd282133fc4b42f5e19174f5d02e9a --- /dev/null +++ b/omg_llava/tools/evaluate_region_cap.py @@ -0,0 +1,52 @@ +import os +import json +import argparse +from pycocotools.coco import COCO +from pycocoevalcap.eval import COCOEvalCap + + +def parse_args(): + parser = argparse.ArgumentParser(description="GLaMM Inference - Region Captioning") + + parser.add_argument("--annotation_file", + default="data/RefCoco_Reg/mdetr_annotations/finetune_refcocog_val_captions.json", type=str, + help="Replace with 'data/visual_genome/test_caption.json' for VG.") + parser.add_argument("--results_dir", default="results", type=str, help="The path to save the results.") + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Load the annotation file + coco = COCO(args.annotation_file) + + # Merge and load the results files + # all_results = [] + # for result_file in os.listdir(args.results_dir): + # all_results += json.load(open(f"{args.results_dir}/{result_file}", "r")) + # merged_file_path = f"{args.results_dir}/merged.json" + # with open(merged_file_path, 'w') as f: + # json.dump(all_results, f) + # coco_result = coco.loadRes(merged_file_path) + coco_result = coco.loadRes(args.results_dir) + + # Create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # Evaluate results + coco_eval.params['image_id'] = coco_result.getImgIds() + coco_eval.evaluate() + + # Print and save the output evaluation scores + output_file_path = f"./work_dirs/region_cap_metrics.txt" + f = open(output_file_path, 'w') + for metric, score in 
coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + f.write(f"{metric}: {score:.3f}\n") + f.close() + + +if __name__ == "__main__": + main() diff --git a/omg_llava/tools/gcd_omg_seg_llava.py b/omg_llava/tools/gcd_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..86161659fe9ac45e5d11f979a2cc5ce4fc523b28 --- /dev/null +++ b/omg_llava/tools/gcd_omg_seg_llava.py @@ -0,0 +1,484 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import math +import os +import os.path as osp +import re +import torch +import tqdm + +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria, is_cn_string +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) + +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +from PIL import Image +import torch.nn.functional as F +from xtuner.dataset.utils import expand2square +from pycocotools import mask as mask_utils + + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + + +GCG_QUESTIONS = [ + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + 'Can you provide a thorough description of the this image? Please output with interleaved segmentation masks for the corresponding phrases.', + 'Please describe in detail the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + 'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.', + 'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.', + 'Could you provide me with a detailed analysis of this photo? 
Please output with interleaved segmentation masks for the corresponding parts of the answer.', +] + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--output-name', type=str, default='gcg', help='save folder name') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + + +@master_only +def master_print(msg): + print(msg) + +class GCD_Inference_Dataset(Dataset): + def __init__(self, + image_folder, + image_processor, + debug=False, + pad_image_to_square=True, + ): + self.debug = debug + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + + self.images = os.listdir(image_folder) + if debug: + self.images = self.images[:20] + + def __len__(self): + return len(self.images) + + def get_questions(self): + question = "Could you please give me a detailed description of the image? Please respond with interleaved \ + segmentation masks for the corresponding parts of the answer." 
+ return question + + def __getitem__(self, index): + + data_dict = {} + + questions = self.get_questions() + image_file = self.images[index] + data_dict['image_file'] = image_file + image_file = os.path.join(self.image_folder, image_file) + print(image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = questions + return data_dict + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + dataset = GCD_Inference_Dataset( + image_folder='./data/glamm_data/images/grandf/val_test/', + image_processor=image_processor, + pad_image_to_square=True, + debug=False, + # debug=True, + ) + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + # pixel feature + data_sample = dataset[i] + image = data_sample['pixel_values'] # () + image = 
image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple) \ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + + # questions = data_sample['questions'] + questions = GCG_QUESTIONS + for question in questions: + # print(question) + texts = DEFAULT_IMAGE_TOKEN + '\n' + question + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=texts, round=1, bot_name=args.bot_name) + else: + prompt_text = texts + + batch_inputs = prompt_text + + predict, seg_hidden_states = forward_model( + batch_inputs, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria) + if len(seg_hidden_states) != 0: + break + + + ori_size = data_sample['ori_size'] + # print("Answer:", predict) + # print("Mask num: ", len(seg_hidden_states)) + + if len(seg_hidden_states) == 0: + print("Warnning !!! No mask Pred !!!") + w, h = ori_size + masks = torch.zeros((0, h, w), dtype=torch.bool) + else: + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0],), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + pred_masks = pred_masks_list[-1] + w, h = ori_size + masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)), + mode='bilinear', align_corners=False) + masks = masks[:, 0] + # remove padding + if w == h: + pass + elif w > h: + n_pad = w - h + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, n_pad_1: w - n_pad_2] + else: + n_pad = h - w + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, :, n_pad_1: h - n_pad_2] + # binary + masks = masks.sigmoid() > 0.5 + process_and_save_output( + "./work_dirs/{}/".format(args.output_name), + data_sample['image_file'], + predict, + masks + ) + +def forward_model(question, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria): + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + inputs = question + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0]).strip() + print("Answer:", predict) + + hidden_states = 
generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + # print("Mask num: ", len(seg_hidden_states)) + + # seg_hidden_states = projector_text2vision(seg_hidden_states) + return predict, seg_hidden_states + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + return hidden_states[-n_out:][seg_mask] + +def process_and_save_output(output_dir, image_name, text_output, pred_masks): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + text_output = text_output.replace("", "").replace("\n", "").replace(" ", " ") + text_output = text_output.split("ASSISTANT: ")[-1] + + cleaned_str = re.sub(r'<.*?>', '', text_output) + + pattern = re.compile(r'
<p>
(.*?)<\/p>') + phrases = pattern.findall(text_output) + phrases = [p.strip() for p in phrases] + + # Remove the [SEG] token + cleaned_str = cleaned_str.replace('[SEG]', '') + + # Strip unnecessary spaces + cleaned_str = ' '.join(cleaned_str.split()).strip("'") + cleaned_str = cleaned_str.strip() + + # Convert the predicted masks into RLE format + pred_masks_tensor = pred_masks.cpu() + uncompressed_mask_rles = mask_to_rle_pytorch(pred_masks_tensor) + rle_masks = [] + for m in uncompressed_mask_rles: + rle_masks.append(coco_encode_rle(m)) + + # Create results dictionary + result_dict = { + "image_id": image_name[:-4], + "caption": cleaned_str, + "phrases": phrases, + "pred_masks": rle_masks + } + + # print(cleaned_str) + # print(phrases) + + output_path = f"{output_dir}/{image_name[:-4]}.json" + + with open(output_path, 'w') as f: + json.dump(result_dict, f) + + return + +def mask_to_rle_pytorch(tensor: torch.Tensor): + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + + return out + +def coco_encode_rle(uncompressed_rle): + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + + return rle + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/mmbench_omg_seg_llava.py b/omg_llava/tools/mmbench_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..23de5842ec68e2bc38e835f4a782047cc523cb87 --- /dev/null +++ b/omg_llava/tools/mmbench_omg_seg_llava.py @@ -0,0 +1,498 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
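+# Distributed MMBench evaluation for OMG-LLaVA: each rank answers its shard of
+# the multiple-choice TSV, predictions are gathered with collect_results(), and
+# per-category accuracy is reported when the split is 'dev'.
+# Illustrative launch (CONFIG / CHECKPOINT / TSV paths are placeholders, not repo files):
+#   python omg_llava/tools/mmbench_omg_seg_llava.py ${CONFIG} ${CHECKPOINT} \
+#       --data-path ${MMBENCH_TSV} --prompt-template internlm2_chat --launcher none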
+import argparse +import json +import math +import os.path as osp +import re +import string +import time + +import numpy as np +import pandas as pd +import torch +import tqdm +from huggingface_hub import snapshot_download +from mmengine import mkdir_or_exist +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from peft import PeftModel +from rich.console import Console +from rich.table import Table +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.dataset.utils import decode_base64_to_image, expand2square +from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria, is_cn_string +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) +from importlib import import_module +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMBench') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument('--data-path', default=None, help='data path') + parser.add_argument('--work-dir', help='the dir to save results') + parser.add_argument('--llava', default=None, help='llava name or path') + parser.add_argument( + '--visual-encoder', default=None, help='visual encoder name or path') + parser.add_argument( + '--visual-select-layer', default=-2, help='visual select layer') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + + +@master_only +def master_print(msg): + print(msg) + + +class MMBenchDataset(Dataset): + ABBRS = { + 'coarse_perception': 
'CP', + 'finegrained_perception (instance-level)': 'FP-S', + 'finegrained_perception (cross-instance)': 'FP-C', + 'logic_reasoning': 'LR', + 'relation_reasoning': 'RR', + 'attribute_reasoning': 'AR', + 'sketch_reasoning': 'Sketch Reasoning', + 'scenery_building': 'Scenery & Building', + 'food_clothes': 'Food & Clothes', + 'historical_figure': 'Historical Figure', + 'traditional_show': 'Traditional Show', + 'calligraphy_painting': 'Calligraphy Painting', + 'cultural_relic': 'Cultural Relic' + } + + def __init__(self, data_file): + self.data_file = data_file + self.df = pd.read_csv(data_file, sep='\t') + self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test' + self.has_l2_category = 'l2-category' in self.df.columns.to_list() + + def get_image(self, image): + while len(image) < 16: + image = self.df[self.df['index'] == int(image)]['image'].values + assert len(image) == 1 + image = image[0] + image = decode_base64_to_image(image) + return image + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + index = self.df.iloc[idx]['index'] + image = self.df.iloc[idx]['image'] + image = self.get_image(image) + question = self.df.iloc[idx]['question'] + answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[ + 0].keys() else None + category = self.df.iloc[idx]['category'] + + options = { + cand: self.load_from_df(idx, cand) + for cand in string.ascii_uppercase + if self.load_from_df(idx, cand) is not None + } + options_prompt = '' + for key, item in options.items(): + options_prompt += f'{key}. {item}\n' + + hint = self.load_from_df(idx, 'hint') + data = { + 'img': image, + 'question': question, + 'answer': answer, + 'options': options_prompt, + 'category': category, + 'options_dict': options, + 'index': index, + 'context': hint, + } + if self.has_l2_category: + data.update({'l2-category': self.df.iloc[idx]['l2-category']}) + return data + + def load_from_df(self, idx, key): + if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]): + return self.df.iloc[idx][key] + else: + return None + + @master_only + def eval_result(self, result_df, show=True): + + def calc_acc(df, group='category'): + assert group in ['overall', 'category', 'l2-category'] + if group == 'overall': + res = {'Average': np.mean(df['hit'])} + else: + res = {} + abilities = list(set(df[group])) + abilities.sort() + for ab in abilities: + sub_df = df[df[group] == ab] + ab = self.ABBRS[ab] if ab in self.ABBRS else ab + res[ab] = np.mean(sub_df['hit']) + return res + + def eval_sub_data(sub_data, answer_map): + lt = len(sub_data) + for i in range(lt): + item = sub_data.iloc[i] + match = re.search(r'([A-D]+)', item['prediction']) + pred = match.group(1) if match else '' + gt = answer_map[item['index']] + if gt != pred: + return 0 + return 1 + + def show_result(ret_json): + show_dict = ret_json.copy() + table = Table(title=f' MMBench ({self.data_file}) ') + console = Console() + table.add_column('Category', justify='left') + table.add_column('Accuracy (%)', justify='right') + average = show_dict.pop('Average') * 100 + table.add_row('Average', f'{average:.1f}') + table.add_section() + for cat_name, cat_acc in show_dict.items(): + table.add_row(cat_name, f'{cat_acc * 100:.1f}') + with console.capture() as capture: + console.print(table, end='') + print('\n' + capture.get()) + print('Note: Please be cautious if you use the results in papers, ' + "since we don't use ChatGPT as a helper for choice " + 'extraction') + + data = result_df.sort_values(by='index') + data['prediction'] = 
[str(x) for x in data['prediction']] + for k in data.keys(): + data[k.lower() if k not in 'ABCD' else k] = data.pop(k) + + data_main = data[data['index'] < int(1e6)] + cate_map = { + i: c + for i, c in zip(self.df['index'], self.df['category']) + } + if self.has_l2_category: + l2_cate_map = { + i: c + for i, c in zip(self.df['index'], self.df['l2-category']) + } + answer_map = { + i: c + for i, c in zip(self.df['index'], self.df['answer']) + } + + lt = len(data_main) + hit, tot = 0, 0 + result = {} + for i in range(lt): + item_main = data_main.iloc[i] + idx = item_main['index'] + assert idx not in result + sub_data = data[data['index'] % int(1e6) == idx] + ret = eval_sub_data(sub_data, answer_map) + result[idx] = ret + hit += ret + tot += 1 + + indices = data_main['index'] + data_main = data_main.copy() + data_main['hit'] = [result[i] for i in indices] + main_idx = data_main['index'] + data_main['category'] = [cate_map[i] for i in main_idx] + + ret_json = calc_acc(data_main, 'overall') + + if self.has_l2_category: + data_main['l2-category'] = [l2_cate_map[i] for i in main_idx] + l2 = calc_acc(data_main, 'l2-category') + ret_json.update(l2) + else: + leaf = calc_acc(data_main, 'category') + ret_json.update(leaf) + if show: + show_result(ret_json) + return ret_json + + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None 
else tokenizer.eos_token_id, + ) + + # work_dir + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + save_dir = args.work_dir + else: + # use config filename as default work_dir + save_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.data_path))[0]) + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + save_dir = osp.join(save_dir, timestamp) + + if rank == 0: + mkdir_or_exist(osp.abspath(save_dir)) + print('=======================================================') + print(f'Dataset path: {osp.abspath(args.data_path)}\n' + f'Results will be saved to {osp.abspath(save_dir)}') + print('=======================================================') + + args_path = osp.join(save_dir, 'args.json') + with open(args_path, 'w', encoding='utf-8') as f: + json.dump(args.__dict__, f, indent=2) + + results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx') + results_json_path = osp.join(save_dir, 'mmbench_result.json') + + dataset = MMBenchDataset(args.data_path) + + results = [] + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + data_sample = dataset[i] + if data_sample['context'] is not None: + text = data_sample['context'] + '\n' + data_sample[ + 'question'] + '\n' + data_sample['options'] + else: + text = data_sample['question'] + '\n' + data_sample['options'] + + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if is_cn_string(text): + text = text + '请直接回答选项字母。' + else: + text = text + ("Answer with the option's letter from the " + 'given choices directly.') + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=text, round=1, bot_name=args.bot_name) + else: + prompt_text = text + inputs = prompt_text + + image = data_sample['img'].convert('RGB') + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria) + + predict = tokenizer.decode( + generate_output[0], 
skip_special_tokens=True).strip() + cur_result = {} + cur_result['question'] = data_sample.get('question') + cur_result.update(data_sample.get('options_dict')) + cur_result['prediction'] = predict + if data_sample.get('category') is not None: + cur_result['category'] = data_sample.get('category') + if data_sample.get('l2-category') is not None: + cur_result['l2-category'] = data_sample.get('l2-category') + cur_result['index'] = data_sample.get('index') + cur_result['split'] = data_sample.get('split') + cur_result['answer'] = data_sample.get('answer') + results.append(cur_result) + + results = collect_results(results, n_samples) + + if get_rank() == 0: + + results_df = pd.DataFrame(results) + # with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer: + with pd.ExcelWriter(results_xlsx_path, engine='xlsxwriter') as writer: + results_df.to_excel(writer, index=False) + + if dataset.split == 'dev': + results_dict = dataset.eval_result(results_df, show=True) + with open(results_json_path, 'w', encoding='utf-8') as f: + json.dump(results_dict, f, indent=2) + else: + print('All done!') + + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/refcoco_omg_seg_llava.py b/omg_llava/tools/refcoco_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..c3f445ef3e6a6d7733a39c1f317113af67ff1822 --- /dev/null +++ b/omg_llava/tools/refcoco_omg_seg_llava.py @@ -0,0 +1,727 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import math +import os +import os.path as osp +import numpy as np +import torch +import tqdm +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +import logging +from mmengine import print_log +from PIL import Image +from pycocotools import mask +import torch.nn.functional as F +from omg_llava.dataset.utils import expand2square +from omg_llava.dataset.utils.refcoco_refer import REFER +from omg_llava.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU + + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--dataset', + choices=DATASETS_ATTRIBUTES.keys(), + default='refcoco', + help='Specify a ref dataset') + parser.add_argument( + '--split', + default='val', + help='Specify a split') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), 
+ default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + +DATASETS_ATTRIBUTES = { + 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'}, + 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'}, + 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'}, +} + +@master_only +def master_print(msg): + print(msg) + +class RefcocoReferringSegDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + dataset_name, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + pad_image_to_square=False, + debug=False, + repeats=1, + split='val', + ): + self.split = split + self._set_attribute(dataset_name) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_datas = self.json_file_preprocess(data_path) + self.json_datas = json_datas + # json_datas = self.only_get_hf_map_infos() + # json_data = DatasetDict({'train': HFDataset.from_list(json_datas)}) + # self.text_data = process_hf_dataset( + # dataset=json_data, + # tokenizer=tokenizer, + # max_length=max_length, + # dataset_map_fn=dataset_map_fn, + # template_map_fn=template_map_fn, + # split='train', + # max_dataset_length=max_dataset_length, + # remove_unused_columns=False, + # pack_to_max_length=False, + # with_image_token=True, + # map_num_proc=num_proc, # because limited mem + # ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def _set_attribute(self, dataset_name): + attr_dict = DATASETS_ATTRIBUTES[dataset_name] + + self.splitBy = attr_dict['splitBy'] + self.dataset_name = attr_dict['dataset_name'] + + def __len__(self): + return len(self.json_datas) * self.repeats + + def real_len(self): + return len(self.json_datas) + + def 
json_file_preprocess(self, data_path): + splitBy = self.splitBy + dataset_name = self.dataset_name + refer_api = REFER(data_path, dataset_name, splitBy) + ref_ids_train = refer_api.getRefIds(split=self.split) + images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train) + refs_train = refer_api.loadRefs(ref_ids=ref_ids_train) + self.img2refs = self.create_img_to_refs_mapping(refs_train) + + image_infos = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + for item in loaded_images: + item = item.copy() + image_infos.append(item) + + self.annotations = refer_api.Anns + # self.img2refs = self.create_img_to_refs_mapping(refs_train) + + refs = [self.img2refs[image_info['id']] for image_info in image_infos] + + ret = [] + for image_info, ref in zip(image_infos, refs): + if len(ref) == 0: + continue + + sents = [] + ann_ids = [] + for _ref in ref: + for sent in _ref["sentences"]: + text = sent["sent"] + sents.append(text) + ann_ids.append(_ref["ann_id"]) + + # if len(sents) >= 3: + # sampled_inds = np.random.choice( + # list(range(len(sents))), 3, replace=False + # ) + # else: + # sampled_inds = list(range(len(sents))) + + sampled_inds = list(range(len(sents))) + sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() + # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist() + sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds] + selected_labels = sampled_sents + ret.append( + {'image_info': image_info, + 'sampled_ann_id': sampled_ann_ids, + 'selected_labels': selected_labels, + 'image': image_info['file_name'] + } + ) + if self.debug: + return ret[:10] + return ret + + def load_images(self, refer_api, images_ids_train, dataset_dir, dataset_name, inference=False): + images = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + for item in loaded_images: + item = item.copy() + images.append(item) + return images + + def create_img_to_refs_mapping(self, refs_train): + img2refs = {} + for ref in refs_train: + img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ] + return img2refs + + def decode_mask(self, annotations_ids, image_info): + flag = False + masks = [] + + for ann_id in annotations_ids: + if isinstance(ann_id, list): + flag = True + if -1 in ann_id: + assert len(ann_id) == 1 + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + else: + m_final = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + for ann_id_i in ann_id: + ann = self.annotations[ann_id_i] + + if len(ann["segmentation"]) == 0: + m = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + else: + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"], ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + m_final = m_final | m + m = m_final + masks.append(m) + continue + + ann = self.annotations[ann_id] + + if len(ann["segmentation"]) == 0: + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + masks.append(m) + continue + + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], 
image_info["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + masks.append(m) + masks = np.stack(masks, axis=0) + + # if self.pad_image_to_square: + # masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + + # masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + # self.image_w // self.down_ratio), mode='nearest').squeeze(0) + + # print(image_info['file_name']) + # print(masks.shape) + # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1) + # save_masks = save_masks.numpy() * 255 + # save_masks = Image.fromarray(save_masks.astype(np.uint8)) + # save_masks.save("/root/mask.png") + # print(kkk) + return masks + + def only_get_text_infos(self, json_data): + return {'sampled_sents': json_data['selected_labels']} + + def get_questions(self, text_require_infos): + sampled_sents = text_require_infos['sampled_sents'] + ret = [] + for sent in sampled_sents: + ret.append("Please segment {} in this image.".format(sent)) + return ret + + def filter_data_dict(self, data_dict): + names = ['pixel_values', 'masks', 'ori_size', 'questions'] + ret = {name: data_dict[name] for name in names} + return ret + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = self.json_datas[index] + text_require_infos = self.only_get_text_infos(data_dict) + questions = self.get_questions(text_require_infos) + + assert data_dict.get('image', None) is not None + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_file = os.path.join(self.image_folder, image_file) + # print(image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info']) + data_dict['masks'] = masks + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = questions + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + # pixel_values, binary masks, conversation/input ids + return self.filter_data_dict(data_dict) + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + 
cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + + if os.path.exists(cfg.pretrained_pth): + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + else: + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load pre PTH model from {cfg.pretrained_pth}') + + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + # print(state_dict.keys()) + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + # # work_dir + # if args.work_dir is not None: + # # update configs according to CLI args if args.work_dir is not None + # save_dir = args.work_dir + # else: + # # use config filename as default work_dir + # save_dir = osp.join('./work_dirs', + # osp.splitext(osp.basename(args.data_path))[0]) + # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + # save_dir = osp.join(save_dir, timestamp) + + # if rank == 0: + # mkdir_or_exist(osp.abspath(save_dir)) + # print('=======================================================') + # print(f'Dataset path: {osp.abspath(args.data_path)}\n' + # f'Results will be saved to {osp.abspath(save_dir)}') + # print('=======================================================') + + # args_path = osp.join(save_dir, 'args.json') + # with open(args_path, 'w', encoding='utf-8') as f: + # json.dump(args.__dict__, f, indent=2) + + # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx') + # results_json_path = osp.join(save_dir, 'mmbench_result.json') + + dataset = RefcocoReferringSegDataset( + dataset_name=args.dataset, + image_folder='./data/glamm_data/' + 'images/coco2014/train2014/', + image_processor=image_processor, + data_path="./data/ref_seg/", + tokenizer=tokenizer, + pad_image_to_square=True, + debug=False, + split=args.split, + # debug=True, + ) + + results = [] + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) 
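+    # Accumulate pixel intersection/union (for cIoU) and per-sample IoU (for gIoU)
+    # on this rank; the per-rank sums are merged with collect_results() after the
+    # evaluation loop to produce the final ciou/giou numbers.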
+ + trackers = { + "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM), + "union": AverageMeter("Union", ":6.3f", Summary.SUM), + "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM) + } + + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + data_sample = dataset[i] + questions = data_sample['questions'] + texts = [] + for question in questions: + texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question) + + # if data_sample['context'] is not None: + # text = data_sample['context'] + '\n' + data_sample[ + # 'question'] + '\n' + data_sample['options'] + # else: + # text = data_sample['question'] + '\n' + data_sample['options'] + # + # text = DEFAULT_IMAGE_TOKEN + '\n' + text + # + # if is_cn_string(text): + # text = text + '请直接回答选项字母。' + # else: + # text = text + ("Answer with the option's letter from the " + # 'given choices directly.') + prompt_texts = [] + + if args.prompt_template: + for text in texts: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=text, round=1, bot_name=args.bot_name) + prompt_texts.append(prompt_text) + else: + prompt_texts = texts + + batch_inputs = prompt_texts + + image = data_sample['pixel_values'] # () + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + ori_size = data_sample['ori_size'] + target_masks = data_sample['masks'].cuda().to(torch.uint8) + + intersection, union, accuracy_iou = 0.0, 0.0, 0.0 + + for idx_inp, inputs in enumerate(batch_inputs): + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0]).strip() + # print("Answer:", predict) + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][-1] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + # last_hidden_states, generate_output.sequences[0], + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + # print("Mask num: ", len(seg_hidden_states)) + if len(seg_hidden_states) == 0: 
+ print("Warning, no [SEG] tokens !!!") + continue + elif len(seg_hidden_states) > 1: + print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) + seg_hidden_states = seg_hidden_states[:1] + + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0],), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + pred_masks = pred_masks_list[-1] + w, h = ori_size + masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)), + mode='bilinear', align_corners=False) + masks = masks[:, 0] + # remove padding + if w == h: + pass + elif w > h: + n_pad = w - h + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, n_pad_1: w - n_pad_2] + else: + n_pad = h - w + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, :, n_pad_1: h - n_pad_2] + # binary + masks = masks.sigmoid() > 0.5 + masks = masks.int() + _target = target_masks[idx_inp:idx_inp+1].int() + + # intersection, union, accuracy_iou = 0.0, 0.0, 0.0 + for target, prediction in zip(masks, _target): + intersect, union_, _ = intersectionAndUnionGPU( + prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255 + ) + intersection += intersect + union += union_ + accuracy_iou += intersect / (union_ + 1e-5) + # print(intersect / (union_ + 1e-5)) + # handles no-object targets + accuracy_iou[union_ == 0] += 1.0 + + intersection, union = intersection.cpu().numpy(), union.cpu().numpy() + accuracy_iou = accuracy_iou.cpu().numpy() / target_masks.shape[0] + trackers["intersection"].update(intersection) + trackers["union"].update(union) + trackers["gIoU"].update(accuracy_iou, n=target_masks.shape[0]) + + # for meter in trackers.values(): + # meter.all_reduce() + # print(trackers["intersection"].sum, ' ', trackers["union"].sum, ' ', + # trackers["gIoU"].avg, ' ', trackers["gIoU"].count) + cur_results = {'pixel_intersection': trackers["intersection"].sum[1], + 'pixel_union': trackers["union"].sum[1], + 'gIoU': trackers["gIoU"].avg[1], + 'mask_counts': trackers["gIoU"].count, + } + results.append(cur_results) + # iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10) + # class_iou = iou_per_class[1] + # global_iou = trackers["gIoU"].avg[1] + # + # print("ciou: ", class_iou) + # print("giou: ", global_iou) + + results = collect_results(results, n_samples) + + if get_rank() == 0: + pixel_intersection = [cur_result['pixel_intersection'] for cur_result in results] + pixel_union = [cur_result['pixel_union'] for cur_result in results] + gIoUs = [cur_result['gIoU'] for cur_result in results] + mask_counts = [cur_result['mask_counts'] for cur_result in results] + + class_iou = sum(pixel_intersection) / (sum(pixel_union) + 1e-10) + global_iou = sum([giou * n_masks for giou, n_masks in zip(gIoUs, mask_counts)]) / sum(mask_counts) + print("ciou: ", class_iou) + print("giou: ", global_iou) + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + return hidden_states[-n_out:][seg_mask] + + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/refcoco_omg_seg_llava_msseg.py b/omg_llava/tools/refcoco_omg_seg_llava_msseg.py new file mode 100644 index 0000000000000000000000000000000000000000..761ed557e7a9a3236dd9391e3f242dfcf0922695 --- /dev/null +++ b/omg_llava/tools/refcoco_omg_seg_llava_msseg.py @@ -0,0 +1,756 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
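+# Multi-seg-states variant of refcoco_omg_seg_llava.py: the extra --mode flag
+# ('baseline', 'mean', 'linear_cat') selects how the hidden states of the
+# generated [SEG] token are combined across decoder layers before being
+# projected by projector_text2vision and decoded into a mask.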
+import argparse +import math +import os +import os.path as osp +import numpy as np +import torch +import tqdm +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.model.utils import prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +import logging +from mmengine import print_log +from PIL import Image +from pycocotools import mask +import torch.nn.functional as F +from omg_llava.dataset.utils import expand2square +from omg_llava.dataset.utils.refcoco_refer import REFER +from omg_llava.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU + + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--dataset', + choices=DATASETS_ATTRIBUTES.keys(), + default='refcoco', + help='Specify a ref dataset') + parser.add_argument( + '--split', + default='val', + help='Specify a split') + parser.add_argument( + '--mode', + default='baseline', # baseline, mean, linear_cat + help='Specify a mode') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + +DATASETS_ATTRIBUTES = { + 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'}, + 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'}, + 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'}, +} + +@master_only +def 
master_print(msg): + print(msg) + +class RefcocoReferringSegDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + dataset_name, + data_path=None, + tokenizer=None, + offline_processed_text_folder=None, + pad_image_to_square=False, + debug=False, + repeats=1, + split='val', + ): + self.split = split + self._set_attribute(dataset_name) + self.debug = debug + if offline_processed_text_folder and data_path: + print_log( + 'Both `offline_processed_text_folder` and ' + '`data_path` are set, and we load dataset from' + '`offline_processed_text_folder` ' + f'({offline_processed_text_folder})', + logger='current', + level=logging.WARNING) + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_datas = self.json_file_preprocess(data_path) + self.json_datas = json_datas + # json_datas = self.only_get_hf_map_infos() + # json_data = DatasetDict({'train': HFDataset.from_list(json_datas)}) + # self.text_data = process_hf_dataset( + # dataset=json_data, + # tokenizer=tokenizer, + # max_length=max_length, + # dataset_map_fn=dataset_map_fn, + # template_map_fn=template_map_fn, + # split='train', + # max_dataset_length=max_dataset_length, + # remove_unused_columns=False, + # pack_to_max_length=False, + # with_image_token=True, + # map_num_proc=num_proc, # because limited mem + # ) + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def _set_attribute(self, dataset_name): + attr_dict = DATASETS_ATTRIBUTES[dataset_name] + + self.splitBy = attr_dict['splitBy'] + self.dataset_name = attr_dict['dataset_name'] + + def __len__(self): + return len(self.json_datas) * self.repeats + + def real_len(self): + return len(self.json_datas) + + def json_file_preprocess(self, data_path): + splitBy = self.splitBy + dataset_name = self.dataset_name + refer_api = REFER(data_path, dataset_name, splitBy) + ref_ids_train = refer_api.getRefIds(split=self.split) + images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train) + refs_train = refer_api.loadRefs(ref_ids=ref_ids_train) + self.img2refs = self.create_img_to_refs_mapping(refs_train) + + image_infos = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + for item in loaded_images: + item = item.copy() + image_infos.append(item) + + self.annotations = refer_api.Anns + # self.img2refs = self.create_img_to_refs_mapping(refs_train) + + refs = [self.img2refs[image_info['id']] for image_info in image_infos] + + ret = [] + for image_info, ref in zip(image_infos, refs): + if len(ref) == 0: + continue + + sents = [] + ann_ids = [] + for _ref in ref: + for sent in _ref["sentences"]: + text = sent["sent"] + sents.append(text) + ann_ids.append(_ref["ann_id"]) + + # if len(sents) >= 3: + # sampled_inds = np.random.choice( + # list(range(len(sents))), 3, replace=False + # ) + # else: + # sampled_inds = list(range(len(sents))) + + sampled_inds = list(range(len(sents))) + sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() + # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist() + sampled_ann_ids = [ann_ids[ind] for 
ind in sampled_inds] + selected_labels = sampled_sents + ret.append( + {'image_info': image_info, + 'sampled_ann_id': sampled_ann_ids, + 'selected_labels': selected_labels, + 'image': image_info['file_name'] + } + ) + if self.debug: + return ret[:10] + return ret + + def load_images(self, refer_api, images_ids_train, dataset_dir, dataset_name, inference=False): + images = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + for item in loaded_images: + item = item.copy() + images.append(item) + return images + + def create_img_to_refs_mapping(self, refs_train): + img2refs = {} + for ref in refs_train: + img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ] + return img2refs + + def decode_mask(self, annotations_ids, image_info): + flag = False + masks = [] + + for ann_id in annotations_ids: + if isinstance(ann_id, list): + flag = True + if -1 in ann_id: + assert len(ann_id) == 1 + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + else: + m_final = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + for ann_id_i in ann_id: + ann = self.annotations[ann_id_i] + + if len(ann["segmentation"]) == 0: + m = np.zeros( + (image_info["height"], image_info["width"]) + ).astype(np.uint8) + else: + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"], ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + m_final = m_final | m + m = m_final + masks.append(m) + continue + + ann = self.annotations[ann_id] + + if len(ann["segmentation"]) == 0: + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + masks.append(m) + continue + + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + masks.append(m) + masks = np.stack(masks, axis=0) + + # if self.pad_image_to_square: + # masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + + # masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + # self.image_w // self.down_ratio), mode='nearest').squeeze(0) + + # print(image_info['file_name']) + # print(masks.shape) + # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1) + # save_masks = save_masks.numpy() * 255 + # save_masks = Image.fromarray(save_masks.astype(np.uint8)) + # save_masks.save("/root/mask.png") + # print(kkk) + return masks + + def only_get_text_infos(self, json_data): + return {'sampled_sents': json_data['selected_labels']} + + def get_questions(self, text_require_infos): + sampled_sents = text_require_infos['sampled_sents'] + ret = [] + for sent in sampled_sents: + ret.append("Please segment {} in this image.".format(sent)) + return ret + + def filter_data_dict(self, data_dict): + names = ['pixel_values', 'masks', 'ori_size', 'questions'] + ret = 
{name: data_dict[name] for name in names} + return ret + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = self.json_datas[index] + text_require_infos = self.only_get_text_infos(data_dict) + questions = self.get_questions(text_require_infos) + + assert data_dict.get('image', None) is not None + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_file = os.path.join(self.image_folder, image_file) + # print(image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + # process and get masks + masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info']) + data_dict['masks'] = masks + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = questions + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + # pixel_values, binary masks, conversation/input ids + return self.filter_data_dict(data_dict) + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + + if os.path.exists(cfg.pretrained_pth): + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + else: + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load pre PTH model from {cfg.pretrained_pth}') + + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + # print(state_dict.keys()) + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + 
projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + # # work_dir + # if args.work_dir is not None: + # # update configs according to CLI args if args.work_dir is not None + # save_dir = args.work_dir + # else: + # # use config filename as default work_dir + # save_dir = osp.join('./work_dirs', + # osp.splitext(osp.basename(args.data_path))[0]) + # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + # save_dir = osp.join(save_dir, timestamp) + + # if rank == 0: + # mkdir_or_exist(osp.abspath(save_dir)) + # print('=======================================================') + # print(f'Dataset path: {osp.abspath(args.data_path)}\n' + # f'Results will be saved to {osp.abspath(save_dir)}') + # print('=======================================================') + + # args_path = osp.join(save_dir, 'args.json') + # with open(args_path, 'w', encoding='utf-8') as f: + # json.dump(args.__dict__, f, indent=2) + + # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx') + # results_json_path = osp.join(save_dir, 'mmbench_result.json') + + dataset = RefcocoReferringSegDataset( + dataset_name=args.dataset, + image_folder='./data/glamm_data/' + 'images/coco2014/train2014/', + image_processor=image_processor, + data_path="./data/ref_seg/", + tokenizer=tokenizer, + pad_image_to_square=True, + debug=False, + split=args.split, + # debug=True, + ) + + results = [] + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + + trackers = { + "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM), + "union": AverageMeter("Union", ":6.3f", Summary.SUM), + "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM) + } + + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + data_sample = dataset[i] + questions = data_sample['questions'] + texts = [] + for question in questions: + texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question) + + # if data_sample['context'] is not None: + # text = data_sample['context'] + '\n' + data_sample[ + # 'question'] + '\n' + data_sample['options'] + # else: + # text = data_sample['question'] + '\n' + data_sample['options'] + # + # text = DEFAULT_IMAGE_TOKEN + '\n' + text + # + # if is_cn_string(text): + # text = text + '请直接回答选项字母。' + # else: + # text = text + ("Answer with the option's letter from the " + # 'given choices directly.') + prompt_texts = [] + + if args.prompt_template: + for text in texts: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=text, round=1, bot_name=args.bot_name) + prompt_texts.append(prompt_text) + else: + prompt_texts = texts + + batch_inputs = prompt_texts + + image = data_sample['pixel_values'] # () + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) 
or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + ori_size = data_sample['ori_size'] + target_masks = data_sample['masks'].cuda().to(torch.uint8) + + intersection, union, accuracy_iou = 0.0, 0.0, 0.0 + + for idx_inp, inputs in enumerate(batch_inputs): + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0]).strip() + print("Answer:", predict) + + hidden_states = generate_output.hidden_states + if args.mode == 'baseline': + last_hidden_states = [item[-1][-1] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + # last_hidden_states, generate_output.sequences[0], + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx + ) + else: + hidden_states = [torch.cat(item, dim=0)[:, -1] for item in hidden_states] + last_hidden_states = torch.stack(hidden_states, dim=0) # (N, n_layer, c) + seg_hidden_states = get_seg_hidden_states_multistates( + last_hidden_states, generate_output.sequences[0][:-1], + seg_id=model.seg_token_idx, mode=args.mode, + model=model, + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + # print("Mask num: ", len(seg_hidden_states)) + if len(seg_hidden_states) == 0: + print("Warning, no [SEG] tokens !!!") + continue + elif len(seg_hidden_states) > 1: + print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) + seg_hidden_states = seg_hidden_states[:1] + + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0],), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + pred_masks = pred_masks_list[-1] + w, h = ori_size + masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)), + mode='bilinear', align_corners=False) + masks = masks[:, 0] + # remove padding + if w == h: + pass + elif w > h: + n_pad = w - h + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, n_pad_1: w - n_pad_2] + else: + n_pad = h - w + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, :, n_pad_1: h - n_pad_2] + # binary + masks = masks.sigmoid() > 0.5 + masks = masks.int() + _target = target_masks[idx_inp:idx_inp+1].int() + 
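+            # Compare the binarized prediction with the ground-truth mask:
+            # intersectionAndUnionGPU runs over 2 classes (background /
+            # foreground) with label 255 ignored. Downstream,
+            #   cIoU = sum(intersection) / sum(union)   (foreground class)
+            #   gIoU = mean of per-mask IoU              (foreground class)
+            # and targets with an empty union count as IoU = 1 (no object).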
+ # intersection, union, accuracy_iou = 0.0, 0.0, 0.0 + for target, prediction in zip(masks, _target): + intersect, union_, _ = intersectionAndUnionGPU( + prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255 + ) + intersection += intersect + union += union_ + accuracy_iou += intersect / (union_ + 1e-5) + # print(intersect / (union_ + 1e-5)) + # handles no-object targets + accuracy_iou[union_ == 0] += 1.0 + + intersection, union = intersection.cpu().numpy(), union.cpu().numpy() + accuracy_iou = accuracy_iou.cpu().numpy() / target_masks.shape[0] + trackers["intersection"].update(intersection) + trackers["union"].update(union) + trackers["gIoU"].update(accuracy_iou, n=target_masks.shape[0]) + + # for meter in trackers.values(): + # meter.all_reduce() + # print(trackers["intersection"].sum, ' ', trackers["union"].sum, ' ', + # trackers["gIoU"].avg, ' ', trackers["gIoU"].count) + cur_results = {'pixel_intersection': trackers["intersection"].sum[1], + 'pixel_union': trackers["union"].sum[1], + 'gIoU': trackers["gIoU"].avg[1], + 'mask_counts': trackers["gIoU"].count, + } + results.append(cur_results) + # iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10) + # class_iou = iou_per_class[1] + # global_iou = trackers["gIoU"].avg[1] + # + # print("ciou: ", class_iou) + # print("giou: ", global_iou) + + results = collect_results(results, n_samples) + + if get_rank() == 0: + pixel_intersection = [cur_result['pixel_intersection'] for cur_result in results] + pixel_union = [cur_result['pixel_union'] for cur_result in results] + gIoUs = [cur_result['gIoU'] for cur_result in results] + mask_counts = [cur_result['mask_counts'] for cur_result in results] + + class_iou = sum(pixel_intersection) / (sum(pixel_union) + 1e-10) + global_iou = sum([giou * n_masks for giou, n_masks in zip(gIoUs, mask_counts)]) / sum(mask_counts) + print("ciou: ", class_iou) + print("giou: ", global_iou) + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + return hidden_states[-n_out:][seg_mask] + + +def get_seg_hidden_states_multistates(hidden_states, output_ids, seg_id, mode, model): + # (N, n_layer, c) + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + hidden_states = hidden_states[-n_out:][seg_mask] # (N, n_layers, c) + if mode == 'mean': + hidden_states = torch.mean(hidden_states, dim=1) + elif mode == 'linear_cat': + hidden_states = hidden_states[:, -model.selected_layers:] + hidden_states = model.seg_token_proj_linear_cat[0](hidden_states) + hidden_states = hidden_states.flatten(1) + hidden_states = model.seg_token_proj_linear_cat[1](hidden_states) + else: + raise NotImplementedError + return hidden_states + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/refvos_omg_seg_llava.py b/omg_llava/tools/refvos_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..c3123a2c1aec3d30e6ccc65ad149280e4a0bbef3 --- /dev/null +++ b/omg_llava/tools/refvos_omg_seg_llava.py @@ -0,0 +1,527 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
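+# Referring video object segmentation (RefVOS) inference tool.
+# Loads an OMG-LLaVA config and checkpoint, reads meta_expressions.json,
+# prompts the model with "Please segment {expression} in this image." on the
+# first frame of each video, decodes the resulting [SEG] hidden state into a
+# mask with the OMG-Seg head, and saves one PNG per expression under
+# ./work_dirs/rvos/<video>/<expression_id>/.
+#
+# Example launch (paths are placeholders):
+#   python omg_llava/tools/refvos_omg_seg_llava.py <config.py> <checkpoint.pth> \
+#       --launcher none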
+import argparse +import json +import math +import os +import os.path as osp +import re +import string +import time + +import numpy as np +import pandas as pd +import torch +import tqdm +from huggingface_hub import snapshot_download +from mmengine import mkdir_or_exist +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from peft import PeftModel +from rich.console import Console +from rich.table import Table +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.dataset.utils import decode_base64_to_image, expand2square +from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal +from xtuner.tools.utils import get_stop_criteria, is_cn_string +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) +from xtuner.model.git import GitPerceptionEncoder, GitPerceptionEncoder_Clip +from importlib import import_module +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +import logging +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from PIL import Image +from pycocotools import mask +import torch.nn.functional as F +from xtuner.dataset.utils import expand2square, expand2square_mask +from xtuner.dataset.huggingface import process_hf_dataset +from xtuner.dataset.GLAMM_dataset.utils.refcoco_refer import REFER +from xtuner.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU + + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--dataset', + choices=DATASETS_ATTRIBUTES.keys(), + default='refcoco', + help='Specify a ref dataset') + parser.add_argument( + '--split', + default='val', + help='Specify a split') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated 
text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + +DATASETS_ATTRIBUTES = { + 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'}, + 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'}, + 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'}, +} + +@master_only +def master_print(msg): + print(msg) + +class RefVOSDataset(Dataset): + def __init__(self, + image_folder, + image_processor, + expressions_file=None, + tokenizer=None, + offline_processed_text_folder=None, + pad_image_to_square=False, + debug=False, + repeats=1, + ): + self.debug = debug + + if offline_processed_text_folder is not None: + raise NotImplementedError + else: + json_datas = self.json_file_preprocess(expressions_file, image_folder) + self.json_datas = json_datas + + self.image_folder = image_folder + size = image_processor.crop_size + if isinstance(size, int): + self.image_h, self.image_w = size, size + else: + self.image_w, self.image_h = size + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + self.repeats = repeats + + def __len__(self): + return len(self.json_datas) * self.repeats + + def real_len(self): + return len(self.json_datas) + + def json_file_preprocess(self, expression_json_file, image_folder): + + video_files = os.listdir(image_folder) + with open(expression_json_file, 'r') as f: + expression_data = json.load(f) + expression_data = expression_data["videos"] + expression_data_ = {} + for video_file in video_files: + expression_data_.update({video_file: expression_data[video_file]}) + + ret_items = [] + for video_file in expression_data_.keys(): + video_expression_data = expression_data_[video_file] + require_infos = {} + require_infos['video_file'] = video_file + require_infos['frames'] = video_expression_data["frames"] + require_infos['expressions_id'] = list(video_expression_data["expressions"].keys()) + require_infos['expressions'] = [video_expression_data["expressions"][id]["exp"] for id in require_infos['expressions_id']] + ret_items.append(require_infos) + + return ret_items + + def get_questions(self, expressions): + ret = [] + for sent in expressions: + ret.append("Please segment {} in this image.".format(sent)) + return ret + + def __getitem__(self, index): + index = index % self.real_len() + data_dict = self.json_datas[index] + questions = self.get_questions(data_dict['expressions']) + + data_dict['image'] = data_dict['frames'][0] + '.jpg' + assert data_dict.get('image', None) is not None + if data_dict.get('image', None) is not None: + image_file = data_dict['image'] + image_file = os.path.join(self.image_folder, data_dict['video_file'], image_file) + # print(image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = 
questions + else: + if hasattr(self.image_processor, 'crop_size'): + crop_size = self.image_processor.crop_size + else: + crop_size = self.image_processor.size + data_dict['pixel_values'] = torch.zeros(3, crop_size['height'], + crop_size['width']) + data_dict['masks'] = None + # pixel_values, binary masks, conversation/input ids + return data_dict + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + + if os.path.exists(cfg.pretrained_pth): + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + else: + state_dict = guess_load_checkpoint(cfg.pretrained_pth) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load pre PTH model from {cfg.pretrained_pth}') + + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + # print(state_dict.keys()) + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + # # work_dir + # if args.work_dir is not None: + # # update configs according to CLI args if args.work_dir is not None + # save_dir = args.work_dir + # else: + # # use config filename as default work_dir + # save_dir = osp.join('./work_dirs', + # osp.splitext(osp.basename(args.data_path))[0]) + # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + # save_dir = osp.join(save_dir, timestamp) + + # if rank == 0: + # 
mkdir_or_exist(osp.abspath(save_dir)) + # print('=======================================================') + # print(f'Dataset path: {osp.abspath(args.data_path)}\n' + # f'Results will be saved to {osp.abspath(save_dir)}') + # print('=======================================================') + + # args_path = osp.join(save_dir, 'args.json') + # with open(args_path, 'w', encoding='utf-8') as f: + # json.dump(args.__dict__, f, indent=2) + + # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx') + # results_json_path = osp.join(save_dir, 'mmbench_result.json') + + dataset = RefVOSDataset( + image_folder='./data/rvos/valid/JPRGImages/', + image_processor=image_processor, + expressions_file="./data/rvos/meta_expressions/valid/meta_expressions.json", + tokenizer=tokenizer, + pad_image_to_square=True, + debug=False, + # debug=True, + ) + + results = [] + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + + trackers = { + "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM), + "union": AverageMeter("Union", ":6.3f", Summary.SUM), + "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM) + } + + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + data_sample = dataset[i] + questions = data_sample['questions'] + texts = [] + for question in questions: + texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question) + + # if data_sample['context'] is not None: + # text = data_sample['context'] + '\n' + data_sample[ + # 'question'] + '\n' + data_sample['options'] + # else: + # text = data_sample['question'] + '\n' + data_sample['options'] + # + # text = DEFAULT_IMAGE_TOKEN + '\n' + text + # + # if is_cn_string(text): + # text = text + '请直接回答选项字母。' + # else: + # text = text + ("Answer with the option's letter from the " + # 'given choices directly.') + prompt_texts = [] + + if args.prompt_template: + for text in texts: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=text, round=1, bot_name=args.bot_name) + prompt_texts.append(prompt_text) + else: + prompt_texts = texts + + batch_inputs = prompt_texts + + image = data_sample['pixel_values'] # () + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple)\ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + ori_size = data_sample['ori_size'] + target_masks = data_sample['masks'].cuda().to(torch.uint8) + + intersection, union, accuracy_iou = 0.0, 0.0, 0.0 + + for idx_inp, inputs in enumerate(batch_inputs): + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = 
prepare_inputs_labels_for_multimodal( + llm=llm, input_ids=ids, pixel_values=pixel_values) + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0]).strip() + # print("Answer:", predict) + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0], + seg_id=model.seg_token_idx + ) + # print("Mask num: ", len(seg_hidden_states)) + if len(seg_hidden_states) == 0: + print("Warning, no [SEG] tokens !!!") + continue + elif len(seg_hidden_states) > 1: + print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) + seg_hidden_states = seg_hidden_states[:1] + + seg_hidden_states = projector_text2vision(seg_hidden_states) + batch_idxs = torch.zeros((seg_hidden_states.shape[0],), + dtype=torch.int64).to(seg_hidden_states.device) + pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + pred_masks = pred_masks_list[-1] + w, h = ori_size + masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)), + mode='bilinear', align_corners=False) + masks = masks[:, 0] + # remove padding + if w == h: + pass + elif w > h: + n_pad = w - h + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, n_pad_1: w - n_pad_2] + else: + n_pad = h - w + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + masks = masks[:, :, n_pad_1: h - n_pad_2] + # binary + masks = masks.sigmoid() > 0.5 + masks = masks.int() * 255 + + video_name = data_sample['video_file'] + expression_id = data_sample['expressions_id'][idx_inp] + if not os.path.exists('./work_dirs/rvos/'): + os.mkdir('./work_dirs/rvos/') + if not os.path.exists(os.path.join('./work_dirs/rvos/', video_name)): + os.mkdir(os.path.join('./work_dirs/rvos/', video_name)) + if not os.path.exists(os.path.join('./work_dirs/rvos/', video_name, expression_id)): + os.mkdir(os.path.join('./work_dirs/rvos/', video_name, expression_id)) + + masks = torch.cat([masks, masks, masks], dim=0).permute(1, 2, 0) + masks = masks.cpu().numpy() + image = Image.fromarray(masks) + image.save(os.path.join('./work_dirs/rvos/', video_name, expression_id, data_sample['frames'][0] + '.png')) + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + return hidden_states[-n_out:][seg_mask] + + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/region_cap_mask_omg_seg_llava.py b/omg_llava/tools/region_cap_mask_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b3b43f73912001aad2a3d045b44f3aaaec223d --- /dev/null +++ b/omg_llava/tools/region_cap_mask_omg_seg_llava.py @@ -0,0 +1,508 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
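+# Region captioning with mask visual prompts.
+# Decodes COCO-style polygon/RLE annotations into binary region masks, embeds
+# them with visual_encoder.forward_region_sam, injects the mark features into
+# the prompt through prepare_inputs_labels_for_multimodal_with_visual_prompts,
+# and asks the model for a detailed description of the marked region. The
+# per-rank predictions are gathered and dumped to --output-path
+# (default ./work_dirs/region_cap_pred.json).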
+import argparse +import json +import math +import os +import os.path as osp +import re +import torch +import tqdm + +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.model.utils import LoadWoInit +from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts +from xtuner.tools.utils import get_stop_criteria, is_cn_string +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) + +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +from PIL import Image +import torch.nn.functional as F +from omg_llava.dataset.utils import expand2square, expand2square_mask +from pycocotools import mask + +from pycocotools.coco import COCO +import numpy as np + +def bbox_to_x1y1x2y2(bbox): + x1, y1, w, h = bbox + bbox = [x1, y1, x1 + w, y1 + h] + + return bbox + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--output-path', type=str, default='./work_dirs/region_cap_pred.json', help='Name for Bot') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=300, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + + +@master_only +def master_print(msg): + print(msg) + +class RegionCap_Inference_Dataset(Dataset): + def __init__(self, + image_folder, + image_processor, + pad_image_to_square=True, + annotation_file=None, + debug=False, + ): + self.debug = debug + self.image_folder = image_folder + size = image_processor.crop_size + # if isinstance(size, int): + # self.image_h, 
self.image_w = size, size + # else: + # self.image_w, self.image_h = size + self.image_h, self.image_w = 1024, 1024 + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + + self.coco = COCO(annotation_file) + self.image_dict = self.coco.imgs + self.ann_dict = self.coco.anns + self.image_dict_keys = list(self.image_dict.keys()) + + def __len__(self): + return len(self.image_dict_keys) + + def decode_mask(self, annotation, image_info): + flag = False + masks = [] + + for ann_id in range(1): + # if isinstance(ann_id, list): + # flag = True + # if -1 in ann_id: + # assert len(ann_id) == 1 + # m = np.zeros((image_info["height"], image_info["width"])).astype( + # np.uint8 + # ) + # else: + # m_final = np.zeros( + # (image_info["height"], image_info["width"]) + # ).astype(np.uint8) + # for ann_id_i in ann_id: + # ann = self.annotations[ann_id_i] + # + # if len(ann["segmentation"]) == 0: + # m = np.zeros( + # (image_info["height"], image_info["width"]) + # ).astype(np.uint8) + # else: + # if type(ann["segmentation"][0]) == list: # polygon + # rle = mask.frPyObjects( + # ann["segmentation"], image_info["height"], image_info["width"], ) + # else: + # rle = ann["segmentation"] + # for i in range(len(rle)): + # if not isinstance(rle[i]["counts"], bytes): + # rle[i]["counts"] = rle[i]["counts"].encode() + # m = mask.decode(rle) + # m = np.sum( + # m, axis=2 + # ) # sometimes there are multiple binary map (corresponding to multiple segs) + # m = m.astype(np.uint8) # convert to np.uint8 + # m_final = m_final | m + # m = m_final + # masks.append(m) + # continue + + ann = {"segmentation": annotation} + + if len(ann["segmentation"]) == 0: + m = np.zeros((image_info["height"], image_info["width"])).astype( + np.uint8 + ) + masks.append(m) + continue + + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image_info["height"], image_info["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + masks.append(m) + masks = np.stack(masks, axis=0) + + if self.pad_image_to_square: + masks = expand2square_mask(masks) + masks = torch.from_numpy(masks) + masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, + self.image_w // self.down_ratio), mode='nearest').squeeze(0) + + # print(image_info['file_name']) + # print(masks.shape) + # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1) + # save_masks = save_masks.numpy() * 255 + # save_masks = Image.fromarray(save_masks.astype(np.uint8)) + # save_masks.save("/root/mask.png") + # print(kkk) + return masks + + def get_questions(self): + question = "Can you provide me with a detailed description of the region in the picture marked by region1 ?" 
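+        # The fixed prompt mentions "region1"; the matching mark-token
+        # positions in the tokenized prompt are later filled with mask-derived
+        # region embeddings (see get_points_embeddings and forward_model below).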
+ return question + + def __getitem__(self, index): + + data_dict = {} + + image_id = self.image_dict_keys[index] + image_file = self.image_dict[image_id]['file_name'] + + questions = self.get_questions() + + data_dict['image_file'] = image_file + image_file = os.path.join(self.image_folder, image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = questions + + masks = self.ann_dict[image_id]['segmentation'] + image_info = self.image_dict[image_id] + masks = self.decode_mask(masks, image_info) + + data_dict['regions'] = masks + data_dict['image_id'] = image_id + + return data_dict + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' or 'OMG' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + dataset = RegionCap_Inference_Dataset( + annotation_file='./data/region_caption/refcocog/finetune_refcocog_val_with_mask.json', + image_folder='./data/glamm_data/images/coco2014/train2014/', + image_processor=image_processor, + pad_image_to_square=True, + debug=False, + # debug=True, + ) + n_samples = 
len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + results = [] + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + # pixel feature + data_sample = dataset[i] + image = data_sample['pixel_values'] # () + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple) \ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + + questions = data_sample['questions'] + regions = data_sample['regions'] + texts = DEFAULT_IMAGE_TOKEN + '\n' + questions + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=texts, round=1, bot_name=args.bot_name) + else: + prompt_text = texts + + batch_inputs = prompt_text + + predict = forward_model( + batch_inputs, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria, points=regions, + mark_token_id=model.mark_token_idx, + width=image.shape[-1], height=image.shape[-2], + visual_encoder=visual_encoder, projector=projector + ) + + text_output = predict.replace("", "").replace("\n", "")\ + .replace("region1", '').replace("Region1", '')\ + .replace(':', '').replace(" ", " ").replace(" ", " ") + text_output = text_output.split("ASSISTANT: ")[-1] + + cleaned_str = re.sub(r'<.*?>', '', text_output) + + # Remove the [SEG] token + cleaned_str = cleaned_str.replace('[SEG]', '') + + # only select 1 setence for eval + # cleaned_str = cleaned_str.split('.')[0] + + # Strip unnecessary spaces + cleaned_str = ' '.join(cleaned_str.split()).strip("'") + cleaned_str = cleaned_str.strip() + + result_dict = {} + result_dict["image_id"] = data_sample['image_id'] + result_dict["caption"] = cleaned_str + result_dict["image_file"] = data_sample['image_file'] + result_dict["prediction"] = cleaned_str + results.append(result_dict) + print(cleaned_str) + + results = collect_results(results, n_samples) + + if get_rank() == 0: + with open(args.output_path, 'w') as json_file: + json.dump(results, json_file, indent=2) + +def forward_model(question, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria, + mark_token_id=None, + points=None, width=None, height=None, + visual_encoder=None, projector=None): + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + inputs = question + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + points = points.cuda() + + points_mark_embedding = get_points_embeddings( + points, ids, width, height, + mark_token_id, visual_encoder, + projector) + + mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts( + llm=llm, input_ids=ids, pixel_values=pixel_values, + 
mark_id=mark_token_id, + mark_feats=points_mark_embedding, region_id=-9999) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0], skip_special_tokens=True).strip() + return predict + + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, visual_encoder, + projector): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + # points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + # marks_embeddings = visual_encoder.forward_point_sam( + # points, batch_idxs, width=width, height=height + # )[:, 0] # (N, C) + + marks_embeddings = visual_encoder.forward_region_sam( + points, batch_idxs + )[:, 0] # (N, C) + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + return marks_embeddings # (N, C) + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/region_cap_omg_seg_llava.py b/omg_llava/tools/region_cap_omg_seg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..51315e43b99b53d7191f1f5072992f001907da48 --- /dev/null +++ b/omg_llava/tools/region_cap_omg_seg_llava.py @@ -0,0 +1,426 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
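+# Region captioning with point visual prompts.
+# Instead of full masks, this variant takes the bounding-box centre of each
+# annotation, rescales it to the padded 1024x1024 input, embeds it with
+# visual_encoder.forward_point_sam, and injects it into the prompt through
+# prepare_inputs_labels_for_multimodal_with_region before generation.
+# Predictions are written to ./work_dirs/region_cap_pred.json.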
+import argparse +import json +import math +import os +import os.path as osp +import re +import torch +import tqdm + +from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, + master_only) +from mmengine.utils.dl_utils import set_multi_processing +from torch.utils.data import Dataset +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) + +from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal, prepare_inputs_labels_for_multimodal_with_region +from xtuner.tools.utils import get_stop_criteria, is_cn_string +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE) + +from xtuner.registry import BUILDER +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from mmengine.config import Config +from mmengine.fileio import PetrelBackend, get_file_backend +from mmengine.config import ConfigDict + +from PIL import Image +import torch.nn.functional as F +from xtuner.dataset.utils import expand2square, expand2square_points +from pycocotools import mask as mask_utils + +from pycocotools.coco import COCO +import numpy as np + +def bbox_to_x1y1x2y2(bbox): + x1, y1, w, h = bbox + bbox = [x1, y1, x1 + w, y1 + h] + + return bbox + +def convert_dict2config_dict(input): + input = ConfigDict(**input) + for key in input.keys(): + if isinstance(input[key], dict): + input[key] = convert_dict2config_dict(input[key]) + return input + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +def parse_args(): + parser = argparse.ArgumentParser(description='RefCocoSeg') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default='internlm2_chat', + help='Specify a prompt template') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=100, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + args = parser.parse_args() + return args + + +@master_only +def master_print(msg): + print(msg) + +class RegionCap_Inference_Dataset(Dataset): + def __init__(self, + image_folder, + image_processor, + pad_image_to_square=True, + annotation_file=None, + debug=False, + ): + self.debug = debug + self.image_folder = image_folder + size = image_processor.crop_size + # if isinstance(size, int): + # self.image_h, self.image_w = size, size + # else: + # self.image_w, self.image_h = size + self.image_h, self.image_w = 1024, 
1024 + + if isinstance(image_processor, dict) or isinstance( + image_processor, Config) or isinstance(image_processor, + ConfigDict): + self.image_processor = BUILDER.build(image_processor) + else: + self.image_processor = image_processor + self.pad_image_to_square = pad_image_to_square + self.down_ratio = 1 + + self.coco = COCO(annotation_file) + self.image_dict = self.coco.imgs + self.ann_dict = self.coco.anns + self.image_dict_keys = list(self.image_dict.keys()) + + def __len__(self): + return len(self.image_dict_keys) + + def get_questions(self): + question = "Can you provide me with a detailed description of the region in the picture marked by region1 ?" + return question + + def __getitem__(self, index): + + data_dict = {} + + image_id = self.image_dict_keys[index] + image_file = self.image_dict[image_id]['file_name'] + gt = self.ann_dict[image_id]['caption'] + + questions = self.get_questions() + + data_dict['image_file'] = image_file + image_file = os.path.join(self.image_folder, image_file) + print(image_file) + image = Image.open(image_file).convert('RGB') + ori_width, ori_height = image.size + if self.pad_image_to_square: + image = expand2square( + image, + tuple( + int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + data_dict['pixel_values'] = image + data_dict['ori_size'] = (ori_width, ori_height) + data_dict['questions'] = questions + + bbox = bbox_to_x1y1x2y2(self.ann_dict[image_id]['bbox']) + bbox = np.array([bbox]) + points = (bbox[:, :2] + bbox[:, 2:]) / 2. + if self.pad_image_to_square: + points = expand2square_points(points, height=ori_height, width=ori_width) + points[:, 0] = points[:, 0] / max(ori_height, ori_width) * self.image_w + points[:, 1] = points[:, 1] / max(ori_height, ori_width) * self.image_h + else: + points[:, 0] = points[:, 0] / ori_width * self.image_w + points[:, 1] = points[:, 1] / ori_height * self.image_h + data_dict['points'] = torch.from_numpy(points) + print(data_dict['points']) + data_dict['image_id'] = image_id + + return data_dict + +def main(): + args = parse_args() + + torch.manual_seed(args.seed) + + if args.launcher != 'none': + set_multi_processing(distributed=True) + init_dist(args.launcher) + + rank, world_size = get_dist_info() + torch.cuda.set_device(rank) + else: + rank = 0 + world_size = 1 + + # build model + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = 
image_processor_type(**image_processor) + + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + projector.cuda() + projector.eval() + + visual_encoder.cuda() + visual_encoder.eval() + + stop_words = args.stop_words + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + dataset = RegionCap_Inference_Dataset( + annotation_file='./data/region_caption/mdetr_annotations/finetune_refcocog_val_captions.json', + image_folder='./data/glamm_data/images/coco2014/train2014/', + image_processor=image_processor, + pad_image_to_square=True, + debug=False, + # debug=True, + ) + n_samples = len(dataset) + per_rank_samples = math.ceil(n_samples / world_size) + + per_rank_ids = range(per_rank_samples * rank, + min(n_samples, per_rank_samples * (rank + 1))) + results = [] + for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'): + # pixel feature + data_sample = dataset[i] + image = data_sample['pixel_values'] # () + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + if isinstance(visual_outputs, list) or isinstance(visual_outputs, tuple) \ + or isinstance(visual_outputs, torch.Tensor): + pixel_values = projector(visual_outputs) + else: + pixel_values = projector( + visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + + questions = data_sample['questions'] + points = data_sample['points'] + texts = DEFAULT_IMAGE_TOKEN + '\n' + questions + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + prompt_text += template['INSTRUCTION'].format( + input=texts, round=1, bot_name=args.bot_name) + else: + prompt_text = texts + + batch_inputs = prompt_text + + predict = forward_model( + batch_inputs, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria, points=points, + mark_token_id=model.mark_token_idx, + width=image.shape[-1], height=image.shape[-2], + visual_encoder=visual_encoder, projector=projector + ) + + text_output = predict.replace("", "").replace("\n", "")\ + .replace("region1", '').replace(':', '').replace(" ", " ") + text_output = text_output.split("ASSISTANT: ")[-1] + + cleaned_str = re.sub(r'<.*?>', '', text_output) + + # Remove the [SEG] token + cleaned_str = cleaned_str.replace('[SEG]', '') + + # only select 1 setence for eval + # cleaned_str = cleaned_str.split('.')[0] + + # Strip unnecessary spaces + cleaned_str = ' '.join(cleaned_str.split()).strip("'") + cleaned_str = cleaned_str.strip() + + result_dict = {} + result_dict["image_id"] = data_sample['image_id'] + result_dict["caption"] = cleaned_str + results.append(result_dict) + print(cleaned_str) + + results = collect_results(results, n_samples) + + if get_rank() == 0: + with open('./work_dirs/region_cap_pred.json', 'w') as json_file: + json.dump(results, json_file, indent=2) + +def forward_model(question, pixel_values, + tokenizer, model, llm, + projector_text2vision, + gen_config, stop_criteria, + mark_token_id=None, + 
points=None, width=None, height=None, + visual_encoder=None, projector=None): + # pixel_values = projector( + # visual_outputs.hidden_states[args.visual_select_layer][:, 1:]) + + inputs = question + # print("Question: ", inputs) + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + points = points.cuda() + + points_mark_embedding = get_points_embeddings( + points, ids, width, height, + mark_token_id, visual_encoder, + projector) + + mm_inputs = prepare_inputs_labels_for_multimodal_with_region( + llm=llm, input_ids=ids, pixel_values=pixel_values, + mark_id=mark_token_id, + mark_feats=points_mark_embedding, region_id=-9999) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + predict = tokenizer.decode( + # generate_output.sequences[0], skip_special_tokens=True).strip() + generate_output.sequences[0]).strip() + return predict + + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, visual_encoder, + projector): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + marks_embeddings = visual_encoder.forward_point_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + return marks_embeddings # (N, C) + +if __name__ == '__main__': + + main() diff --git a/omg_llava/tools/seg_cap_omg_llava.py b/omg_llava/tools/seg_cap_omg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..951f7048fe51820cc7cdfe39b34c46b569fda1f3 --- /dev/null +++ b/omg_llava/tools/seg_cap_omg_llava.py @@ -0,0 +1,654 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
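+# Interactive segmentation-and-captioning chat tool.
+# Builds the model from a config and checkpoint, optionally quantizes the LLM
+# (4-bit NF4 or 8-bit via BitsAndBytesConfig), loads the image given by
+# --image, and chats using the selected prompt template; region prompts are
+# embedded with visual_encoder.forward_region_sam (see get_points_embeddings).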
+import argparse +import copy +import os +import os.path as osp +import re +import sys + +import torch +from huggingface_hub import snapshot_download +from peft import PeftModel +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image, expand2square_points +from xtuner.model.utils import prepare_inputs_labels_for_multimodal, prepare_inputs_labels_for_multimodal_with_region + +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER +import numpy as np + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook + +# def get_points_embeddings(points, input_ids, width, height, +# mark_token_idx, visual_encoder, +# projector): +# if points is None or len(points) == 0: +# return [] +# +# mark_token_mask = input_ids == mark_token_idx +# batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( +# input_ids.device) +# batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number +# +# points = points.to(torch.float32) +# # print(points.dtype, batch_idxs.dtype) +# marks_embeddings = visual_encoder.forward_point_sam( +# points, batch_idxs, width=width, height=height +# )[:, 0] # (N, C) +# +# marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) +# marks_embeddings = projector.model.query_proj(marks_embeddings) +# marks_embeddings = projector.model.model(marks_embeddings) +# return marks_embeddings # (N, C) + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, visual_encoder, + projector): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + # points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + # marks_embeddings = visual_encoder.forward_point_sam( + # points, batch_idxs, width=width, height=height + # )[:, 0] # (N, C) + + marks_embeddings = visual_encoder.forward_region_sam( + points, batch_idxs + )[:, 0] # (N, C) + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + return marks_embeddings # (N, C) + +def remove_prefix(state_dict, prefix): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + parser.add_argument('pth_model', help='pth model file') + + 
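+    # The optional arguments below control the input image, dtype /
+    # quantization, prompt and system templates, and the sampling
+    # hyper-parameters used for generation.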
parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + + +def get_input(): + """Helper function for getting input from users.""" + sentinel = '' # ends when this string is seen + result = None + while result is None: + print(('\ndouble enter to end input (EXIT: exit chat, ' + 'RESET: reset history) >>> '), + end='') + try: + result = '\n'.join(iter(input, sentinel)) + except UnicodeDecodeError: + print('Invalid characters detected. 
Please enter again.') + return result + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + print(model.state_dict().keys()) + + # pre_state_dict = torch.load("/root/omg-llava.pth") + # model.load_state_dict(pre_state_dict) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # chat_hook = EvaluateChatHook( + # tokenizer=cfg.tokenizer, + # image_processor=image_processor_cfg, + # every_n_iters=100, + # evaluation_inputs=cfg.evaluation_inputs, + # evaluation_images=cfg.evaluation_images, + # system='', + # prompt_template=PROMPT_TEMPLATE.internlm2_chat + # ) + # model.cuda() + # model.eval() + # chat_hook._eval_images_(model, model.device, max_new_tokens=200) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + if False: + pass + else: + if args.with_plugins is None: + inner_thoughts_open = False + calculate_open = False + solve_open = False + search_open = False + else: + assert args.prompt_template == args.system_template == 'moss_sft' + from plugins import plugins_api + inner_thoughts_open = True + calculate_open = 'calculate' in args.with_plugins + solve_open = 'solve' in args.with_plugins + search_open = 'search' in args.with_plugins + # pre-import for api and model preparation + if calculate_open: + from plugins import calculate # noqa: F401 + if solve_open: + from plugins import solve # noqa: F401 + if search_open: + from plugins import search # noqa: F401 + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + if args.image is not None: + image = load_image(args.image) + ori_width, ori_height = image.size + image = expand2square( + image, tuple(int(x * 255) for x in 
image_processor.image_mean)) + image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + print([item.shape for item in visual_outputs]) + pixel_values = projector(visual_outputs) + + n_obj = len(projector.model.valid_queries_embeddings[0]) + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + n_turn = 0 + inputs = '' + obj_index = 0 + while True: + # text = get_input() + if n_turn == 0: + text = 'There are some Regions:' + text = text + ' region{} '.format(1) + text = text + '.\n' + text = text + 'Please detailed describe the regions.' + else: + text = 'RESET' + obj_index += 1 + if obj_index == n_obj: + text = 'EXIT' + print('Log: Exit!') + exit(0) + n_turn = 0 + inputs = '' + continue + + while text.strip() == 'RESET': + print('Log: History responses have been removed!') + n_turn = 0 + inputs = '' + text = get_input() + if text.strip() == 'EXIT': + print('Log: Exit!') + exit(0) + + if args.image is not None and n_turn == 0: + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=n_turn + 1, bot_name=args.bot_name) + if args.prompt_template == args.system_template == 'moss_sft': + if not inner_thoughts_open: + prompt_text.replace('- Inner thoughts: enabled.', + '- Inner thoughts: disabled.') + if not calculate_open: + prompt_text.replace(('- Calculator: enabled. API: ' + 'Calculate(expression)'), + '- Calculator: disabled.') + if not solve_open: + prompt_text.replace( + '- Equation solver: enabled. API: Solve(equation)', + '- Equation solver: disabled.') + if not search_open: + prompt_text.replace( + '- Web search: enabled. 
API: Search(query)', + '- Web search: disabled.') + else: + prompt_text = text + print("prompt_text: ", prompt_text) + inputs += prompt_text + if args.image is None: + if n_turn == 0: + ids = tokenizer.encode(inputs, return_tensors='pt') + else: + ids = tokenizer.encode( + inputs, return_tensors='pt', add_special_tokens=False) + + if args.with_plugins is not None: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria).cpu() + generate_output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + if streamer is None: + end = '' if generate_output_text[-1] == '\n' else '\n' + print(generate_output_text, end=end) + pattern = r'<\|Commands\|>:(.*?)' + command_text = ', '.join( + re.findall(pattern, generate_output_text)) + extent_text = plugins_api( + command_text, + calculate_open=calculate_open, + solve_open=solve_open, + search_open=search_open) + end = '' if extent_text[-1] == '\n' else '\n' + print(extent_text, end=end) + extent_text_ids = tokenizer.encode( + extent_text, + return_tensors='pt', + add_special_tokens=False) + new_ids = torch.cat((generate_output, extent_text_ids), + dim=1) + + generate_output = llm.generate( + inputs=new_ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(new_ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + else: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + inputs = tokenizer.decode(generate_output[0]) + else: + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + + # add region token on firset conversation + if n_turn == 0: + print(obj_index) + masks = visual_encoder.vis_binary_masks[obj_index:obj_index+1].cuda().to(pixel_values[0].dtype) + mark_token_id = model.mark_token_idx + # + points_mark_embedding = get_points_embeddings( + masks, ids, 1024, 1024, + mark_token_id, visual_encoder, + projector) + # points_mark_embedding = projector.model.valid_queries_embeddings[0][obj_index:obj_index+1] + + # mm_inputs = prepare_inputs_labels_for_multimodal( + # llm=llm, input_ids=ids, pixel_values=pixel_values) + mm_inputs = prepare_inputs_labels_for_multimodal_with_region( + llm=llm, input_ids=ids, pixel_values=pixel_values, + mark_id=mark_token_id, + mark_feats=points_mark_embedding, region_id=-9999) + + print(mm_inputs['inputs_embeds'].shape) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + hidden_states = 
generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + masks = visual_encoder.vis_binary_masks + if len(masks) != 0: + print((masks.flatten(1) > 0).sum(-1)) + print(masks.shape) + output_text = tokenizer.decode(generate_output.sequences[0]) + show_mask_pred_binary(image_for_show, masks, save_dir='./output.png', text=output_text) + + + # if len(seg_hidden_states) != 0: + # seg_hidden_states = projector_text2vision(seg_hidden_states) + # batch_idxs = torch.zeros((seg_hidden_states.shape[0], ), + # dtype=torch.int64).to(seg_hidden_states.device) + # pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + # print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + # print(pred_masks_list[-1].shape) + # show_mask_pred(image_for_show, pred_masks_list[-1], save_dir='./output.png') + + + if streamer is None: + # output_text = tokenizer.decode(generate_output[0]) + output_text = tokenizer.decode(generate_output.sequences[0]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + # inputs += tokenizer.decode(generate_output[0]) + inputs += tokenizer.decode(generate_output.sequences[0]) + n_turn += 1 + inputs += sep + # if len(generate_output[0]) >= args.max_new_tokens: + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + n_turn = 0 + inputs = '' + + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +def show_mask_pred(image, masks, save_dir='./output.png'): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + print(color) + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +def show_mask_pred_binary(image, masks, save_dir='./output.png', filter_ids=[], text=''): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + print(text) + filter_ids = [] + for i in range(len(masks)): + if 'Mark {}'.format(i+1) in text: + filter_ids.append(i) + print(filter_ids) + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = masks.to(torch.float32) + masks = F.interpolate(masks.unsqueeze(0), size=image.size, mode='bilinear', align_corners=False)[0] + masks = masks > 0.5 + + masks = 
masks.to(torch.uint8).cpu().numpy() + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + # if i not in filter_ids:continue + color = colors[i % len(colors)] + print(color) + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/seg_condition_cap_omg_llava.py b/omg_llava/tools/seg_condition_cap_omg_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..84f69ef0bc053abdee34d1ae065ace4edcea78d8 --- /dev/null +++ b/omg_llava/tools/seg_condition_cap_omg_llava.py @@ -0,0 +1,628 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import copy +import os +import os.path as osp +import re +import sys + +import torch +from huggingface_hub import snapshot_download +from peft import PeftModel +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel, GenerationConfig) +from transformers.generation.streamers import TextStreamer + +from xtuner.dataset.utils import expand2square, load_image, expand2square_points +from xtuner.model.utils import prepare_inputs_labels_for_multimodal, prepare_inputs_labels_for_multimodal_with_region + +from xtuner.tools.utils import get_stop_criteria +from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, + PROMPT_TEMPLATE, SYSTEM_TEMPLATE) + +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.fileio import PetrelBackend, get_file_backend + +from xtuner.configs import cfgs_name_path +from xtuner.model.utils import guess_load_checkpoint +from xtuner.registry import BUILDER +import numpy as np + +TORCH_DTYPE_MAP = dict( + fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto') + +from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook + +def get_points_embeddings(points, input_ids, width, height, + mark_token_idx, visual_encoder, + projector): + if points is None or len(points) == 0: + return [] + + mark_token_mask = input_ids == mark_token_idx + batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( + input_ids.device) + batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number + + points = points.to(torch.float32) + # print(points.dtype, batch_idxs.dtype) + marks_embeddings = visual_encoder.forward_point_sam( + points, batch_idxs, width=width, height=height + )[:, 0] # (N, C) + + marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype) + marks_embeddings = projector.model.query_proj(marks_embeddings) + marks_embeddings = projector.model.model(marks_embeddings) + return marks_embeddings # (N, C) + +def remove_prefix(state_dict, prefix): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Chat with a HF model') + parser.add_argument('config', help='config file name or path.') + 
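+    # 'config' may be either a file path or a registered config name; main()
+    # resolves names through xtuner's cfgs_name_path before Config.fromfile is called.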
parser.add_argument('pth_model', help='pth model file') + + parser.add_argument('--image', default=None, help='image') + parser.add_argument( + '--torch-dtype', + default='fp16', + choices=TORCH_DTYPE_MAP.keys(), + help='Override the default `torch.dtype` and load the model under ' + 'a specific `dtype`.') + parser.add_argument( + '--prompt-template', + choices=PROMPT_TEMPLATE.keys(), + default="internlm2_chat", + help='Specify a prompt template') + system_group = parser.add_mutually_exclusive_group() + system_group.add_argument( + '--system', default=None, help='Specify the system text') + system_group.add_argument( + '--system-template', + choices=SYSTEM_TEMPLATE.keys(), + default=None, + help='Specify a system template') + parser.add_argument( + '--bits', + type=int, + choices=[4, 8, None], + default=None, + help='LLM bits') + parser.add_argument( + '--bot-name', type=str, default='BOT', help='Name for Bot') + parser.add_argument( + '--with-plugins', + nargs='+', + choices=['calculate', 'solve', 'search'], + help='Specify plugins to use') + parser.add_argument( + '--no-streamer', action='store_true', help='Whether to with streamer') + parser.add_argument( + '--lagent', action='store_true', help='Whether to use lagent') + parser.add_argument( + '--stop-words', nargs='+', type=str, default=[], help='Stop words') + parser.add_argument( + '--offload-folder', + default=None, + help='The folder in which to offload the model weights (or where the ' + 'model weights are already offloaded).') + parser.add_argument( + '--max-new-tokens', + type=int, + default=2048, + help='Maximum number of new tokens allowed in generated text') + parser.add_argument( + '--temperature', + type=float, + default=0.1, + help='The value used to modulate the next token probabilities.') + parser.add_argument( + '--top-k', + type=int, + default=40, + help='The number of highest probability vocabulary tokens to ' + 'keep for top-k-filtering.') + parser.add_argument( + '--top-p', + type=float, + default=0.75, + help='If set to float < 1, only the smallest set of most probable ' + 'tokens with probabilities that add up to top_p or higher are ' + 'kept for generation.') + parser.add_argument( + '--repetition-penalty', + type=float, + default=1.0, + help='The parameter for repetition penalty. 1.0 means no penalty.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Random seed for reproducible text generation') + args = parser.parse_args() + return args + + +def get_input(): + """Helper function for getting input from users.""" + sentinel = '' # ends when this string is seen + result = None + while result is None: + print(('\ndouble enter to end input (EXIT: exit chat, ' + 'RESET: reset history) >>> '), + end='') + try: + result = '\n'.join(iter(input, sentinel)) + except UnicodeDecodeError: + print('Invalid characters detected. 
Please enter again.') + return result + + +def main(): + args = parse_args() + torch.manual_seed(args.seed) + + # parse config + if not osp.isfile(args.config): + try: + args.config = cfgs_name_path[args.config] + except KeyError: + raise FileNotFoundError(f'Cannot find {args.config}') + + # load config + cfg = Config.fromfile(args.config) + # if args.cfg_options is not None: + # cfg.merge_from_dict(args.cfg_options) + + model_name = cfg.model.type if isinstance(cfg.model.type, + str) else cfg.model.type.__name__ + if 'LLaVAModel' in model_name: + cfg.model.pretrained_pth = None + + model = BUILDER.build(cfg.model) + print(model.state_dict().keys()) + + # pre_state_dict = torch.load("/root/omg-llava.pth") + # model.load_state_dict(pre_state_dict) + + backend = get_file_backend(args.pth_model) + if isinstance(backend, PetrelBackend): + from xtuner.utils.fileio import patch_fileio + with patch_fileio(): + state_dict = guess_load_checkpoint(args.pth_model) + else: + state_dict = guess_load_checkpoint(args.pth_model) + + print(state_dict.keys()) + # del state_dict['llm.base_model.model.model.tok_embeddings.weight'] + model.load_state_dict(state_dict, strict=False) + print(f'Load PTH model from {args.pth_model}') + + # image_processor_cfg = copy.deepcopy(cfg.image_processor) + image_processor = cfg.image_processor + image_processor_type = image_processor['type'] + del image_processor['type'] + image_processor = image_processor_type(**image_processor) + + # chat_hook = EvaluateChatHook( + # tokenizer=cfg.tokenizer, + # image_processor=image_processor_cfg, + # every_n_iters=100, + # evaluation_inputs=cfg.evaluation_inputs, + # evaluation_images=cfg.evaluation_images, + # system='', + # prompt_template=PROMPT_TEMPLATE.internlm2_chat + # ) + # model.cuda() + # model.eval() + # chat_hook._eval_images_(model, model.device, max_new_tokens=200) + + # build llm + quantization_config = None + load_in_8bit = False + if args.bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + elif args.bits == 8: + load_in_8bit = True + model_kwargs = { + 'quantization_config': quantization_config, + 'load_in_8bit': load_in_8bit, + 'device_map': 'auto', + 'offload_folder': args.offload_folder, + 'trust_remote_code': True, + 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype] + } + if False: + pass + else: + if args.with_plugins is None: + inner_thoughts_open = False + calculate_open = False + solve_open = False + search_open = False + else: + assert args.prompt_template == args.system_template == 'moss_sft' + from plugins import plugins_api + inner_thoughts_open = True + calculate_open = 'calculate' in args.with_plugins + solve_open = 'solve' in args.with_plugins + search_open = 'search' in args.with_plugins + # pre-import for api and model preparation + if calculate_open: + from plugins import calculate # noqa: F401 + if solve_open: + from plugins import solve # noqa: F401 + if search_open: + from plugins import search # noqa: F401 + # build llm + llm = model.llm + tokenizer = model.tokenizer + + model.cuda() + model.eval() + llm.eval() + visual_encoder = model.visual_encoder + projector = model.projector + projector_text2vision = model.projector_text2vision + + if args.image is not None: + image = load_image(args.image) + ori_width, ori_height = image.size + image = expand2square( + image, tuple(int(x * 255) for x in 
image_processor.image_mean)) + image_for_show = image + image = image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0).to(visual_encoder.dtype) + visual_outputs = visual_encoder(image, output_hidden_states=True) + print([item.shape for item in visual_outputs]) + pixel_values = projector(visual_outputs) + + n_obj = len(projector.model.valid_queries_embeddings[0]) + stop_words = args.stop_words + sep = '' + if args.prompt_template: + template = PROMPT_TEMPLATE[args.prompt_template] + stop_words += template.get('STOP_WORDS', []) + sep = template.get('SEP', '') + stop_criteria = get_stop_criteria( + tokenizer=tokenizer, stop_words=stop_words) + + if args.no_streamer: + streamer = None + else: + streamer = TextStreamer(tokenizer, skip_prompt=True) + + gen_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + do_sample=args.temperature > 0, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id + if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, + ) + + n_turn = 0 + inputs = '' + while True: + # text = get_input() + if n_turn == 0: + text = 'There are some Marks:' + for i in range(n_obj): + text = text + ' mark{} '.format(i + 1) + if i + 1 == n_obj: + text = text + '.\n' + else: + text = text + ',' + text = text + 'Please detailed describe these Marks.' + else: + text = 'EXIT' + + while text.strip() == 'RESET': + print('Log: History responses have been removed!') + n_turn = 0 + inputs = '' + text = get_input() + if text.strip() == 'EXIT': + print('Log: Exit!') + exit(0) + + if args.image is not None and n_turn == 0: + text = DEFAULT_IMAGE_TOKEN + '\n' + text + + if args.prompt_template: + prompt_text = '' + template = PROMPT_TEMPLATE[args.prompt_template] + if 'SYSTEM' in template and n_turn == 0: + system_text = None + if args.system_template is not None: + system_text = SYSTEM_TEMPLATE[ + args.system_template].format( + round=n_turn + 1, bot_name=args.bot_name) + elif args.system is not None: + system_text = args.system + if system_text is not None: + prompt_text += template['SYSTEM'].format( + system=system_text, + round=n_turn + 1, + bot_name=args.bot_name) + prompt_text += template['INSTRUCTION'].format( + input=text, round=n_turn + 1, bot_name=args.bot_name) + if args.prompt_template == args.system_template == 'moss_sft': + if not inner_thoughts_open: + prompt_text.replace('- Inner thoughts: enabled.', + '- Inner thoughts: disabled.') + if not calculate_open: + prompt_text.replace(('- Calculator: enabled. API: ' + 'Calculate(expression)'), + '- Calculator: disabled.') + if not solve_open: + prompt_text.replace( + '- Equation solver: enabled. API: Solve(equation)', + '- Equation solver: disabled.') + if not search_open: + prompt_text.replace( + '- Web search: enabled. 
API: Search(query)', + '- Web search: disabled.') + else: + prompt_text = text + print("prompt_text: ", prompt_text) + inputs += prompt_text + if args.image is None: + if n_turn == 0: + ids = tokenizer.encode(inputs, return_tensors='pt') + else: + ids = tokenizer.encode( + inputs, return_tensors='pt', add_special_tokens=False) + + if args.with_plugins is not None: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria).cpu() + generate_output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + if streamer is None: + end = '' if generate_output_text[-1] == '\n' else '\n' + print(generate_output_text, end=end) + pattern = r'<\|Commands\|>:(.*?)' + command_text = ', '.join( + re.findall(pattern, generate_output_text)) + extent_text = plugins_api( + command_text, + calculate_open=calculate_open, + solve_open=solve_open, + search_open=search_open) + end = '' if extent_text[-1] == '\n' else '\n' + print(extent_text, end=end) + extent_text_ids = tokenizer.encode( + extent_text, + return_tensors='pt', + add_special_tokens=False) + new_ids = torch.cat((generate_output, extent_text_ids), + dim=1) + + generate_output = llm.generate( + inputs=new_ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(new_ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + else: + generate_output = llm.generate( + inputs=ids.cuda(), + generation_config=gen_config, + streamer=streamer, + stopping_criteria=stop_criteria) + if streamer is None: + output_text = tokenizer.decode( + generate_output[0][len(ids[0]):]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + inputs = tokenizer.decode(generate_output[0]) + else: + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0 and n_turn == 0: + cur_encode = tokenizer.encode(chunk) + else: + cur_encode = tokenizer.encode( + chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + + # add region token on firset conversation + if n_turn == 0: + # points = np.array(prompts_points) + # points = expand2square_points(points, height=ori_height, width=ori_width) + # points[:, 0] = points[:, 0] / max(ori_height, ori_width) * 1024 + # points[:, 1] = points[:, 1] / max(ori_height, ori_width) * 1024 + # points = torch.from_numpy(points) + # points = points.cuda() + mark_token_id = model.mark_token_idx + # + # points_mark_embedding = get_points_embeddings( + # points, ids, 1024, 1024, + # mark_token_id, visual_encoder, + # projector) + points_mark_embedding = projector.model.valid_queries_embeddings[0] + + # mm_inputs = prepare_inputs_labels_for_multimodal( + # llm=llm, input_ids=ids, pixel_values=pixel_values) + + mm_inputs = prepare_inputs_labels_for_multimodal_with_region( + llm=llm, input_ids=ids, pixel_values=pixel_values, + mark_id=mark_token_id, + mark_feats=points_mark_embedding, region_id=-9999) + + print(mm_inputs['inputs_embeds'].shape) + # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16) + + generate_output = llm.generate( + **mm_inputs, + generation_config=gen_config, + 
streamer=streamer, + bos_token_id=tokenizer.bos_token_id, + stopping_criteria=stop_criteria, + output_hidden_states=True, + return_dict_in_generate=True + ) + + hidden_states = generate_output.hidden_states + last_hidden_states = [item[-1][0] for item in hidden_states] + last_hidden_states = torch.cat(last_hidden_states, dim=0) + seg_hidden_states = get_seg_hidden_states( + last_hidden_states, generate_output.sequences[0], + seg_id=model.seg_token_idx + ) + # seg_hidden_states = seg_hidden_states.to(torch.float32) + masks = visual_encoder.vis_binary_masks + if len(masks) != 0: + print((masks.flatten(2) > 0).sum(-1)) + print(masks.shape) + output_text = tokenizer.decode(generate_output.sequences[0]) + show_mask_pred_binary(image_for_show, masks, save_dir='./output.png', text=output_text) + + + # if len(seg_hidden_states) != 0: + # seg_hidden_states = projector_text2vision(seg_hidden_states) + # batch_idxs = torch.zeros((seg_hidden_states.shape[0], ), + # dtype=torch.int64).to(seg_hidden_states.device) + # pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) + # print((pred_masks_list[-1].flatten(2) > 0).sum(-1)) + # print(pred_masks_list[-1].shape) + # show_mask_pred(image_for_show, pred_masks_list[-1], save_dir='./output.png') + + + if streamer is None: + # output_text = tokenizer.decode(generate_output[0]) + output_text = tokenizer.decode(generate_output.sequences[0]) + end = '' if output_text[-1] == '\n' else '\n' + print(output_text, end=end) + # inputs += tokenizer.decode(generate_output[0]) + inputs += tokenizer.decode(generate_output.sequences[0]) + n_turn += 1 + inputs += sep + # if len(generate_output[0]) >= args.max_new_tokens: + if len(generate_output.sequences[0]) >= args.max_new_tokens: + print( + 'Remove the memory of history responses, since ' + f'it exceeds the length limitation {args.max_new_tokens}.') + n_turn = 0 + inputs = '' + + +def get_seg_hidden_states(hidden_states, output_ids, seg_id): + seg_mask = output_ids == seg_id + n_out = len(seg_mask) + print(output_ids) + return hidden_states[-n_out:][seg_mask] + +def show_mask_pred(image, masks, save_dir='./output.png'): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False) + masks = masks.sigmoid() > 0.5 + masks = masks.to(torch.uint8).cpu().numpy()[:, 0] + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + color = colors[i % len(colors)] + print(color) + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +def show_mask_pred_binary(image, masks, save_dir='./output.png', filter_ids=[], text=''): + import torch.nn.functional as F + from PIL import Image + import numpy as np + + print(text) + filter_ids = [] + for i in range(len(masks)): + if 'Mark {}'.format(i+1) in text: + filter_ids.append(i) + print(filter_ids) + + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), + (255, 255, 0), (255, 0, 255), (0, 255, 255), + (128, 128, 255)] + + masks 
= masks.to(torch.float32) + masks = F.interpolate(masks.unsqueeze(0), size=image.size, mode='bilinear', align_corners=False)[0] + masks = masks > 0.5 + + masks = masks.to(torch.uint8).cpu().numpy() + + _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8) + + for i, mask in enumerate(masks): + if i not in filter_ids:continue + color = colors[i % len(colors)] + print(color) + _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0] + _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1] + _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2] + + + image = np.array(image) + image = image * 0.5 + _mask_image * 0.5 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image.save(save_dir) + + return + +if __name__ == '__main__': + main() diff --git a/omg_llava/tools/utils_refcoco.py b/omg_llava/tools/utils_refcoco.py new file mode 100644 index 0000000000000000000000000000000000000000..5c23cb0a741df3d9d6b02b89b1de2e9437805e76 --- /dev/null +++ b/omg_llava/tools/utils_refcoco.py @@ -0,0 +1,122 @@ +from enum import Enum + +import numpy as np +import torch +import torch.distributed as dist + + +class Summary(Enum): + NONE = 0 + AVERAGE = 1 + SUM = 2 + COUNT = 3 + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): + self.name = name + self.fmt = fmt + self.summary_type = summary_type + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def all_reduce(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(self.sum, np.ndarray): + total = torch.tensor( + self.sum.tolist() + + [ + self.count, + ], + dtype=torch.float32, + device=device, + ) + else: + total = torch.tensor( + [self.sum, self.count], dtype=torch.float32, device=device + ) + + dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) + if total.shape[0] > 2: + self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item() + else: + self.sum, self.count = total.tolist() + self.avg = self.sum / (self.count + 1e-5) + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + def summary(self): + fmtstr = "" + if self.summary_type is Summary.NONE: + fmtstr = "" + elif self.summary_type is Summary.AVERAGE: + fmtstr = "{name} {avg:.3f}" + elif self.summary_type is Summary.SUM: + fmtstr = "{name} {sum:.3f}" + elif self.summary_type is Summary.COUNT: + fmtstr = "{name} {count:.3f}" + else: + raise ValueError("invalid summary type %r" % self.summary_type) + + return fmtstr.format(**self.__dict__) + + +def intersectionAndUnionGPU(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
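+    # Worked example (illustrative only): with K=3, output=[0, 1, 2, 1] and
+    # target=[0, 2, 2, 1], predictions match at indices 0, 2 and 3, so
+    # area_intersection=[1, 1, 1], area_output=[1, 2, 1], area_target=[1, 1, 2],
+    # and area_union = area_output + area_target - area_intersection = [1, 2, 2].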
+ assert output.dim() in [1, 2, 3] + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) + area_output = torch.histc(output, bins=K, min=0, max=K - 1) + area_target = torch.histc(target, bins=K, min=0, max=K - 1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def display_summary(self): + entries = [" *"] + entries += [meter.summary() for meter in self.meters] + print(" ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def dict_to_cuda(input_dict): + for k, v in input_dict.items(): + if isinstance(input_dict[k], torch.Tensor): + input_dict[k] = v.cuda(non_blocking=True) + elif isinstance(v, list) and len(v) > 0: + input_dict[k] = [ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele for ele in v] + return input_dict
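+
+
+# Usage sketch (an assumption about how these helpers are combined in a refcoco
+# evaluation loop; not part of the original evaluation code):
+#   inter_meter = AverageMeter('Intersection', ':6.3f', Summary.SUM)
+#   union_meter = AverageMeter('Union', ':6.3f', Summary.SUM)
+#   inter, union, _ = intersectionAndUnionGPU(pred_mask, gt_mask, K=2, ignore_index=255)
+#   inter_meter.update(inter.cpu().numpy())
+#   union_meter.update(union.cpu().numpy())
+#   # under DDP, call .all_reduce() on both meters before reporting
+#   ciou = inter_meter.sum / (union_meter.sum + 1e-10)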