diff --git a/omg_llava/__init__.py b/omg_llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/__pycache__/__init__.cpython-310.pyc b/omg_llava/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1dfd8482764fa057058c8ec6b172ab2ccdced4 Binary files /dev/null and b/omg_llava/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/configs/__init__.py b/omg_llava/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..41cbb83cec11555506f2bed10386487af9009346 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model 
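The configs in this directory follow the xtuner / mmengine "lazy build" convention: each component below is a plain Python dict whose `type` key holds the class or callable to instantiate, and the runner only builds the objects when training starts. A registry-free sketch of that build rule (illustrative only; the real builder is mmengine's Registry machinery, which adds string-name lookup, scopes, nested lists and validation):

```python
def build_from_cfg(cfg):
    """Recursively instantiate a dict-style config whose 'type' holds a class or callable."""
    if isinstance(cfg, dict) and 'type' in cfg:
        kwargs = dict(cfg)
        factory = kwargs.pop('type')
        return factory(**{k: build_from_cfg(v) for k, v in kwargs.items()})
    return cfg

# e.g. build_from_cfg(tokenizer) would end up calling
# AutoTokenizer.from_pretrained(pretrained_model_name_or_path=llm_name_or_path, ...)
```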
+llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
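The dataset and annotation locations above (and the remaining `mdpv_point` files below) describe an assumed local data layout under `./data/`; nothing validates them until a dataset is first opened. A small upfront check could look like the following (hypothetical helper, not part of the config):

```python
import os

def missing_paths(paths):
    """Return the configured paths that do not exist on disk."""
    return {name: p for name, p in paths.items() if not os.path.exists(p)}

# Illustrative usage with a few of the variables defined in this config:
# missing_paths({'data_path': data_path, 'image_folder': image_folder,
#                'refcocog_ann_file': refcocog_ann_file, 'psg_ann_file': psg_ann_file})
```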
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..b6da12d81e9c0127601e731fcc4528b1a4a061a3 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
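This config, `ablation_multi_seg_states_cat.py`, repeats the baseline above almost verbatim; the substantive difference sits in the `model` dict further down, which adds `using_multilayer_states=True`, `seg_token_merge_type='cat'` and `selected_layers=32`, i.e. the segmentation-driving hidden states are gathered from multiple LLM layers (presumably spanning the 7B model's 32 decoder layers) and merged by channel-wise concatenation instead of being taken from a single layer. A toy sketch of such a 'cat' merge (illustrative only, not the repository's actual implementation; `proj` is a hypothetical projection back to the decoder width):

```python
import torch
from torch import nn

def merge_seg_states_cat(per_layer_states, proj):
    """Concatenate per-layer [SEG] hidden states along channels, then project back."""
    # per_layer_states: (num_layers, num_seg_tokens, channels)
    num_layers, num_seg, channels = per_layer_states.shape
    merged = per_layer_states.permute(1, 0, 2).reshape(num_seg, num_layers * channels)
    return proj(merged)  # (num_seg, channels), ready for the mask decoder

# Toy shapes: 32 selected layers, hidden size 4096 (InternLM2-7B), 3 [SEG] tokens.
states = torch.randn(32, 3, 4096)
proj = nn.Linear(32 * 4096, 4096)
print(merge_seg_states_cat(states, proj).shape)  # torch.Size([3, 4096])
```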
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
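The sequence-length and batch settings defined a little further down encode a small token and batch budget; the arithmetic, spelled out (assuming an 8-GPU run, as the `*_8gpus` pretrained checkpoint name suggests):

```python
# Token budget for the LLM context (mirrors `max_length = int(2048 - (1024 / 64)**2 - 100)`).
image_size, total_down_ratio, context_window, margin = 1024, 64, 2048, 100
visual_tokens = (image_size // total_down_ratio) ** 2      # 16 * 16 = 256 visual tokens
max_text_tokens = context_window - visual_tokens - margin  # 1692 tokens left for text

# Effective optimization batch size with the values below (8 GPUs assumed).
per_device_batch, grad_accum, num_gpus = 8, 2, 8
effective_batch = per_device_batch * grad_accum * num_gpus  # 128 samples per optimizer step
```

The 64x reduction is consistent with the ConvNeXt backbone's stride-32 top feature map combined with the `pixel_shuffle_down_ratio=2` set on the visual encoder.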
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..ef61339363f4e7bac0a347b5ee86f24989fa5a05 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py @@ -0,0 +1,927 @@ +# Copyright (c) OpenMMLab. All rights reserved.
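+# Ablation config: seg-token states from multiple layers are used
+# (using_multilayer_states=True, selected_layers=32) and merged by
+# concatenation (seg_token_merge_type='cat'); this `_debug` variant trains
+# on the RefCOCOg GCG data only (see `train_dataset` below).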
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
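# NOTE: the fusion head only runs panoptic/instance post-processing at test time;
+ # with loss_panoptic=None it adds no training loss term. +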
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
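# Multi-point prompt data: several points are mapped to one caption
+ # (cf. Flicker30K_multi_points_to_caption.json above); reuses the brief-caption dataset class. +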
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..415bc59f269dbddfb786ff3240e11e44c730d895 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # 
+####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' 
+mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='linear_cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..782946579760f1aa0ac7d91993e0d9de8b54830d
--- /dev/null
+++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py
@@ -0,0 +1,927 @@
+# Copyright (c) OpenMMLab. All rights reserved.
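+# Debug variant of the multi-seg-states ablation: the segmentation-token hidden
+# states gathered from multiple LLM layers (`using_multilayer_states=True`) are
+# merged with `seg_token_merge_type='linear_cat'` in the `model` dict below, and
+# `train_dataset` is reduced to the RefCOCOg GCG subset so the whole pipeline
+# can be smoke-tested quickly.
+#
+# A typical launch command (a sketch, assuming the standard xtuner entry point
+# is used) would look something like:
+#   NPROC_PER_NODE=8 xtuner train \
+#       omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py \
+#       --deepspeed deepspeed_zero2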
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' 
+psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
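+# MDPV point-prompt QA and multi-point-to-caption annotations. Like the caption
+# variants above, every MDPV dataset defined later in this config is mapped with
+# `mdpv_points_map_fn`; the entries differ only in their JSON annotation file
+# and image folder.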
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='linear_cat', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcf07802c3a081dfc224af7696c33a8fce37190 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py @@ -0,0 +1,954 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # 
+####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' 
+mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + using_multilayer_states=True, + seg_token_merge_type='mean', + selected_layers=32, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + 
image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + 
type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py b/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5782d906d9e7800a423a7e1b0ac5e254639cf2 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_multi_seg_states/debug.py @@ -0,0 +1,924 @@ +# Copyright (c) OpenMMLab. All rights reserved.
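+# NOTE: this debug config appears to be a reduced variant of the ablation
+# configs in this directory: `pretrained_pth` points at an intermediate
+# pretraining checkpoint (iter_4361.pth) and `train_dataset` below combines
+# only `glamm_refcocog_dataset`, presumably so the fine-tuning pipeline can be
+# sanity-checked quickly before launching a full run. A typical launch command
+# (hypothetical; adjust the GPU count and deepspeed setting to your setup):
+#   NPROC_PER_NODE=1 xtuner train \
+#       omg_llava/configs/finetune/ablation_multi_seg_states/debug.py \
+#       --deepspeed deepspeed_zero2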
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus/iter_4361.pth' +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 
'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = 
'./data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +#                     PART 4  Scheduler & Optimizer                   # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +#                            PART 5  Runtime                          # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..b98374afc1d1350cbe2b4e6a2d4d82e23e6c9092 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py @@ -0,0 +1,953 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### 
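+# NOTE: this config appears to be the cross-attention variant of the projector
+# ablation: it loads the `*_seperate_cross_*` pretrain checkpoint and, in the
+# `model` dict below, sets `visual_prompt_proj=False` and
+# `add_cross_attn_layer=True`; the remaining data and training settings match
+# the other finetune configs in this repository.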
+# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
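+# NOTE (editorial): the MDPV point-prompt dataset configs in this section differ
+# only in their `data_path` and `image_folder`; every other field is repeated
+# verbatim. A helper such as the following could build them. This is an
+# illustrative sketch only; `make_mdpv_point_dataset` is not part of the
+# original config and is not used below.
+def make_mdpv_point_dataset(dataset_type, data_path, image_folder):
+    # Shared boilerplate for the MDPV point datasets; only the dataset class
+    # and the annotation/image locations change between entries.
+    return dict(
+        type=dataset_type,
+        data_path=data_path,
+        image_folder=image_folder,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        dataset_map_fn=mdpv_points_map_fn,
+        template_map_fn=dict(
+            type=template_map_fn_factory, template=prompt_template),
+        max_length=max_length,
+        pad_image_to_square=True,
+        debug=False,
+        repeats=1,
+    )
+# Example (equivalent to the dict defined next):
+# mdpv_brief_description_ade20k_dataset = make_mdpv_point_dataset(
+#     MDPVPointBriefCaptionDataset,
+#     mdpv_brief_caption_ade20k_data_path,
+#     mdpv_brief_caption_ade20k_image_path)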
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Default to a random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py new file mode 100644 index 0000000000000000000000000000000000000000..2690aed1cbafbb4667910d37a20ba9610e010e7b --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py @@ -0,0 +1,953 @@ +# Copyright (c) OpenMMLab. All rights reserved.
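+# Projector ablation ("seperate"): judging by the file name and the model flags
+# set in PART 2 below, this variant keeps the visual-prompt projection separate
+# (visual_prompt_proj=True) and omits the extra cross-attention layer
+# (add_cross_attn_layer=False). Data, schedule and runtime settings follow the
+# other OMG-LLaVA finetune configs.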
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=False, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Default to a random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..eb575feac689040f4b98979e56fca6c1f72f3f35 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py @@ -0,0 +1,953 @@ +# Copyright (c) OpenMMLab. All rights reserved.
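+# Projector ablation ("seperate_cross"): judging by the file name and the model
+# flags set in PART 2 below, this variant enables both the separate visual-prompt
+# projection (visual_prompt_proj=True) and the extra cross-attention layer
+# (add_cross_attn_layer=True), and it initializes from the matching
+# 'omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' pretraining weights.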
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file 
= glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
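+# Editor note: the remaining MDPV point datasets below follow the same recipe as
+# the dicts above (MDPVPointDetailedCaptionDataset / MDPVPointBriefCaptionDataset
+# with mdpv_points_map_fn); only data_path and image_folder change. A hypothetical
+# helper (not in the repo) could remove the boilerplate, e.g.:
+#   def _point_dataset(cls, data_path, image_folder):
+#       return dict(type=cls, data_path=data_path, image_folder=image_folder,
+#                   tokenizer=tokenizer, image_processor=image_processor,
+#                   dataset_map_fn=mdpv_points_map_fn,
+#                   template_map_fn=dict(type=template_map_fn_factory,
+#                                        template=prompt_template),
+#                   max_length=max_length, pad_image_to_square=True,
+#                   debug=False, repeats=1)
+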
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..383e5cf1e0bbe5e13c1e9cd4eb9994cdb7064f49 --- /dev/null +++ b/omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py @@ -0,0 +1,926 @@ +# Copyright (c) OpenMMLab. All rights reserved.
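+# Editor note (hedged): judging from the filename and the settings in this file,
+# this is the debug variant of the "seperate cross" projector ablation: it loads
+# the *_seperate_cross_* pretrain checkpoint, sets dataloader_num_workers = 0 and
+# trains on a single dataset (mdpv_brief_description_lvis_dataset) so that a run
+# can be stepped through quickly.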
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file 
= glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
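+# Editor note (hedged): `max_length`, defined below after the QA point paths,
+# reserves room in the 2048-token context for (1024 / 64)**2 = 256 image tokens
+# from a 1024-pixel input plus roughly 100 tokens of chat-template overhead; this
+# is an editorial reading of the arithmetic, not a statement from the authors.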
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=True, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
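+# Editor note (hedged): any of these dataset dicts can be smoke-tested in
+# isolation with xtuner's registry, roughly along these lines (illustrative;
+# assumes the registry resolves the dict exactly as the training entry point does):
+#   from xtuner.registry import BUILDER
+#   ds = BUILDER.build(mdpv_detailed_description_vg_dataset)
+#   print(len(ds))
+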
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[mdpv_brief_description_lvis_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/debug.py b/omg_llava/configs/finetune/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..78d2b667c30cd0fada27a2b27475f01f89b7cefb --- /dev/null +++ b/omg_llava/configs/finetune/debug.py @@ -0,0 +1,967 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # 
Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' 
+mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) 
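+# The four referring-segmentation configs (refcoco, refcoco+, refcocog and
+# refclef, below) all read their annotations from the shared './data/ref_seg/'
+# root; the first three reuse the COCO-2014 train images via
+# `refcocog_image_path`, while refclef uses the saiapr_tc-12 images. The
+# `repeats` field is assumed to oversample a dataset within one epoch (e.g.
+# `repeats=10` on the GranDf GCG dataset above), and `debug=True` is assumed
+# to truncate a dataset for quick dry runs of this debug config.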
+ +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_refcocog_dataset, + ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. 
+ timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8f9f91932249edf521af375f7da95c10924c94 --- /dev/null +++ b/omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from 
torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = 
'./data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/hf_app.py b/omg_llava/configs/finetune/hf_app.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ef9d1fcb16b9c3c96e5c0eb9b7f4992df7c19 --- /dev/null +++ b/omg_llava/configs/finetune/hf_app.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import 
OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a 
detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + 
num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + 
tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + 
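+# --- Editor's sketch (not used by this config) --------------------------------
+# The MDPV point-prompt datasets in this file differ only in their dataset
+# class, annotation json and image folder.  The helper below is a minimal
+# illustration of how these near-identical blocks could be generated; the
+# function name is the editor's own, and it only reuses keyword arguments that
+# already appear in the hand-written configs above.
+def _mdpv_dataset_cfg(dataset_type, data_path, image_folder, repeats=1):
+    # Build one MDPV point-dataset config dict in the same format as above.
+    return dict(
+        type=dataset_type,
+        data_path=data_path,
+        image_folder=image_folder,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        dataset_map_fn=mdpv_points_map_fn,
+        template_map_fn=dict(
+            type=template_map_fn_factory, template=prompt_template),
+        max_length=max_length,
+        pad_image_to_square=True,
+        debug=False,
+        repeats=repeats,
+    )
+# Example (equivalent to mdpv_brief_description_ade20k_dataset defined below):
+#   mdpv_brief_description_ade20k_dataset = _mdpv_dataset_cfg(
+#       MDPVPointBriefCaptionDataset,
+#       mdpv_brief_caption_ade20k_data_path,
+#       mdpv_brief_caption_ade20k_image_path)
+# -------------------------------------------------------------------------------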
+mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + 
data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + 
image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..3c38e99ae60a4ae2da5901ef3c6a1066507a555d --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py @@ -0,0 +1,993 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
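+# Editor's note: this 20B config follows the same five-part layout as the 7B
+# fine-tune configs earlier in this PR.  The visible differences are the
+# InternLM2-Chat-20B checkpoint paths in PART 1, accumulative_counts = 4
+# (instead of 2), and the plain semantic_seg_map_fn / referring_seg_map_fn
+# dataset map functions rather than their *_gcg_format_* variants.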
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_20b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + 
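+# Editor's note on the max_length formula a few lines below: with 1024x1024
+# inputs the visual encoder contributes (1024 / 64)**2 = 256 image tokens
+# (64 = backbone stride 32 x pixel_shuffle ratio 2 is the editor's reading),
+# and roughly 100 tokens are reserved for the chat template and special
+# tokens, leaving about 1692 of the 2048-token context for text.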
+mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + 
encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, 
+ tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..ce00b4aee77b8abf7b90bf120026394778880a79 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,952 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
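This config pairs the finetune recipe with the ConvNeXt-XXL OMG-Seg backbone (model_name='convnext_xxlarge', 3072-channel top-level features, omg_seg_convxxl.pth head weights, convnext_xxlarge_CocoPanopticOVDataset.pth class embeddings). The max_length formula used in these configs budgets the 2048-token context between visual and text tokens; below is a minimal standalone sketch of that arithmetic — only the numbers come from the config, and reading the divisor 64 as the effective visual-token stride is an assumption.

# Standalone sketch: reproduce the max_length arithmetic used in this config.
image_size = 1024
visual_tokens = int((image_size / 64) ** 2)   # 256 tokens assumed reserved for the image
max_length = 2048 - visual_tokens - 100       # 100-token safety margin -> 1692 text tokens
print(visual_tokens, max_length)              # 256 1692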
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_convnextXXL.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
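All of the entries above are plain local paths that must exist before training starts; a short standalone check can surface missing downloads before the dataloaders are built. A sketch only — the three paths are copied from the settings above; extend the list as needed.

import os

# Spot-check a few of the dataset locations named in this config.
paths_to_check = [
    './data/llava_data/',
    './data/glamm_data/',
    './data/mdpv_point/gpt4v_lvis_brief_caption_point.json',
]
for p in paths_to_check:
    if not os.path.exists(p):
        print(f'missing dataset path: {p}')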
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_xxlarge', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s34b_b82k_augreg_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + clip_feat_channel=3072, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset 
= dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c9a42e5bf34fa3ffc9464a6c708e63c2bba10e --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py @@ -0,0 +1,993 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
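This stage-1 finetune config keeps the ConvNeXt-Large OMG-Seg backbone (convnext_large_d_320), uses the plain (non-GCG-format) map functions for its segmentation datasets, and additionally defines Mapillary, Pascal-Part and PACO datasets further down. With batch_size = 8 per device and accumulative_counts = 2, the effective global batch size depends on the GPU count; the sketch below works that out, with the 8-GPU figure taken from the *_8gpus.py filename and therefore an assumption about the intended launch setup.

# Standalone sketch: effective global batch size implied by the settings below.
per_device_batch_size = 8    # batch_size in this config
accumulative_counts = 2      # gradient-accumulation steps
num_gpus = 8                 # assumed from the *_8gpus.py filename; adjust to your setup
effective_batch_size = per_device_batch_size * accumulative_counts * num_gpus
print(effective_batch_size)  # 128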
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file 
= './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = 
'./data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # 
DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, 
+ tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook_withSpecialTokens,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save a checkpoint every `save_steps` iterations.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to using a random seed with `deterministic` disabled
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46b0008af47d266d60e2066255a381e918fbd35
--- /dev/null
+++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py
@@ -0,0 +1,1007 @@
+# Copyright (c) OpenMMLab. All rights reserved.
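+# Stage-1 finetuning config for OMG-LLaVA 7B: 1024x1024 image inputs on 8 GPUs,
+# with the base LLM loaded in 4-bit (QLoRA, LoRA r=512) and the OMG-Seg visual
+# encoder kept frozen; several of the dataset dicts below additionally set
+# num_proc=32 for multi-process preprocessing.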
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = 
'./omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = 
'./data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # 
DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + num_proc=32 +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + num_proc=32 +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + num_proc=32 +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + 
image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + 
image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + num_proc=32 +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + 
pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( 
+ type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] 
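+
+# Schedule note: with warmup_ratio = 0.03 and max_epochs = 1, LinearLR ramps the
+# learning rate from lr * 1e-5 up to lr over roughly the first 3% of iterations
+# (epoch fractions are converted to iterations via convert_to_iter_based=True),
+# after which CosineAnnealingLR decays it to eta_min = 0 for the rest of training.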
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+#                           PART 5  Runtime                           #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook_withSpecialTokens,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save a checkpoint every `save_steps` iterations.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to using a random seed with `deterministic` disabled
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..510b1d2e784b189086d6e2a1b61753350aea852a
--- /dev/null
+++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py
@@ -0,0 +1,1028 @@
+# Copyright (c) OpenMMLab. All rights reserved.
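+# Stage-1 finetuning config with decoupled GCG training: each GCG data source
+# (RefCOCOg, GranDf, OpenPSG, Flickr) is instantiated twice via the
+# Decoupled*GCGDataset classes, once with mode='given_description' and once with
+# mode='given_objects', and the semantic/referring-seg datasets switch to the
+# *_gcg_format_map_fn mapping functions.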
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\ + DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\ + glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\ + glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\ + glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn + +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + 
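+# NOTE: the relative ./data/ roots below (glamm_data, semantic_seg, ref_seg,
+# region_caption, mdpv_point) must match your local data layout.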
+refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = 
'./data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset_given_description = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_description' +) + +glamm_refcocog_dataset_given_objects = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_objects' +) + +glamm_grandf_dataset_given_description = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + 
repeats=10, + mode='given_description' +) + +glamm_grandf_dataset_given_objects = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_objects' +) + +glamm_psg_dataset_given_description = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_psg_dataset_given_objects = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +glamm_flickr_dataset_given_description = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_flickr_dataset_given_objects = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, 
template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + 
image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, + glamm_flickr_dataset_given_description, glamm_flickr_dataset_given_objects, + glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description, + glamm_psg_dataset_given_description, glamm_psg_dataset_given_objects, + glamm_grandf_dataset_given_description, glamm_grandf_dataset_given_objects, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + 
mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..38ac6d9d03f228c0189922fdfae7f4acbc7c6d04 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py @@ -0,0 +1,1000 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\ + DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\ + glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\ + glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\ + glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\ + glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn + +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import 
OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
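+# NOTE: the data/annotation paths above assume a local ./data layout. A quick,
+# illustrative pre-flight check (a hypothetical sketch, not consumed by this config)
+# can catch missing downloads before launching a debug run, e.g.:
+#   import os
+#   _required = [data_path, refcocog_image_path, refcocog_ann_file]
+#   _missing = [p for p in _required if not os.path.exists(p)]
+#   assert not _missing, f'missing data paths: {_missing}'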
+mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a 
detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + 
num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset_given_description = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_description' +) + +glamm_refcocog_dataset_given_objects = dict( + type=DecoupledRefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + mode='given_objects' +) + +glamm_grandf_dataset_given_description = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + 
max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_description' +) + +glamm_grandf_dataset_given_objects = dict( + type=DecoupledGranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, + mode='given_objects' +) + +glamm_psg_dataset_given_description = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_psg_dataset_given_objects = dict( + type=DecoupledOpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +glamm_flickr_dataset_given_description = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_description' +) + +glamm_flickr_dataset_given_objects = dict( + type=DecoupledFlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, + mode='given_objects' +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + 
data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[ + glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description, + ], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: 
https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ef9d1fcb16b9c3c96e5c0eb9b7f4992df7c19 --- /dev/null +++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py @@ -0,0 +1,951 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
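+# Fine-tune config for OMG-LLaVA 7B: stage 1, 1024-px inputs, unified segmentation
+# ("uniSegFormat") target format, 8 GPUs. It follows the same five-part layout as the
+# configs above: settings, model & tokenizer & image processor, dataset & dataloader,
+# scheduler & optimizer, and runtime.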
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = 
glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + 
+mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + 
every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        evaluation_images=evaluation_images,
+        system=SYSTEM,
+        prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per `save_steps`.
+    checkpoint=dict(
+        type=CheckpointHook,
+        by_epoch=False,
+        interval=save_steps,
+        max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Use a random seed by default and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ed4dc066a7efca8aef6f6dc665fd6cd459767b
--- /dev/null
+++ b/omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py
@@ -0,0 +1,994 @@
+# Copyright (c) OpenMMLab. All rights reserved.
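+#
+# NOTE: relative to the stage-1 finetune config, this stage-2 config resumes
+# from the stage-1 checkpoint referenced by `pretrained_pth` (iter_27600 of
+# omg_llava_7b_finetune_stage1_1024image_8gpus), enables the OMG decoder
+# (`require_omg_decoder=True`, `freeze_llm_with_lora=False`), and uses the
+# plain semantic_seg / referring_seg / osprey region map functions instead of
+# the *_gcg_format_map_fn variants used in stage 1.
+#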
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_7b_finetune_stage1_1024image_8gpus/iter_27600.pth' # noqa: E501 +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' 
+ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + 
+mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + 
encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + freeze_llm_with_lora=False, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + 
label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_mapillary_dataset = dict( + type=MapillarySemanticSegDataset, + data_path=mapillary_class_file, + image_folder=mapillary_image_path, + label_path=mapillary_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_pascal_part_dataset = dict( + type=PascalPartSemanticSegDataset, + data_path=pascal_file, + image_folder=pascal_part_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +semantic_seg_paco_dataset = dict( + type=PacoSemanticSegDataset, + data_path=paco_file, + image_folder=paco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=pascal_part_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_map_fn, + template_map_fn=dict( + 
type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + 
image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + region_cap_osprey_dataset, region_conversation_osprey_dataset, + mdpv_detailed_description_ade20k_dataset, + mdpv_detailed_description_cocostuff_10k_dataset, + mdpv_detailed_description_cocostuff_164k_dataset, + mdpv_detailed_description_vg_dataset, + mdpv_brief_description_lvis_dataset, + mdpv_brief_description_vg_dataset, + mdpv_brief_description_ade20k_dataset, + mdpv_brief_description_cocostuff10k_dataset, + mdpv_brief_description_cocostuff164k_dataset, + mdpv_qa_vg_dataset, + mdpv_qa_lvis_dataset, + mdpv_qa_ade20k_dataset, + mdpv_qa_cocostuff10k_dataset, + mdpv_qa_cocostuff164k_dataset, + mdpv_multi_points_flicker30k_dataset, + mdpv_multi_points_openpsg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # 
+####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3dbef5b5583359795e05591509e9d79660455a --- /dev/null +++ b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py @@ -0,0 +1,925 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
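+# How a config like this is typically consumed (a hedged sketch, not the
+# project's documented entry point): mmengine parses the file into a Config,
+# and the `type=...` entries below are built lazily through the registry, so
+# nothing is instantiated at import time.
+#
+#   from mmengine.config import Config
+#   cfg = Config.fromfile(
+#       'omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py')
+#   print(cfg.batch_size, cfg.max_epochs)  # 8 1
+#
+# Training is usually launched through the xtuner CLI, e.g.
+#   NPROC_PER_NODE=8 xtuner train \
+#       omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py \
+#       --deepspeed deepspeed_zero2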
+import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 
'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + +mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = 
'./data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[glamm_flickr_dataset, glamm_refcocog_dataset, + glamm_grandf_dataset, glamm_psg_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py new file mode 100644 index 0000000000000000000000000000000000000000..a0126ce51774c06d72694f61fb67b2690d2ffa9c --- /dev/null +++ b/omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py @@ -0,0 +1,929 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\ + CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\ + ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\ + PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\ + RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\ + Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\ + OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\ + OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\ + MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\ + semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\ + referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model import OpenCLIPBackbone_omgseg +from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = 
'./pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth' +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' + +glamm_data_root = './data/glamm_data/' + +refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/' +refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json' + +grandf_image_path = glamm_data_root + 'images/grandf/train/' +grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json' + +flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/' +flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json' + +psg_image_path = glamm_data_root + 'images/coco2017/' +psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json' + +ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json' + +cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/' +cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt' +cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/' + +mapillary_image_path = './data/semantic_seg/mapillary/training/images/' +mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json' +mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/' + +pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/' +pascal_file = './data/semantic_seg/pascal_part/train.json' + +paco_image_path = './data/glamm_data/images/coco2017/' +paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json' + +referring_refcoco_image_path = refcocog_image_path +referring_refcoco_data_path = "./data/ref_seg/" + +referring_refcoco_plus_image_path = refcocog_image_path +referring_refcoco_plus_data_path = "./data/ref_seg/" + +referring_refcocog_image_path = refcocog_image_path +referring_refcocog_data_path = "./data/ref_seg/" + +referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/" +referring_refclef_data_path = "./data/ref_seg/" + +region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json" + +region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/' +region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json" + +mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json' + +mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json' + 
+mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json' + +mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json' + +mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json' + +mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json' + +mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json' + +mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json' + +mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K' +mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json' + +mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/' +mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json' + +mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json' + +mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017' +mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json' + +mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/' +mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json' + +mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/' +mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json' + +mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017' +mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json' + +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2 - 100) + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 2000 +save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 2000 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture', + 'Could you please give me a detailed description of the image? 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + 
init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + require_omg_decoder=True, + pretrained_pth=pretrained_pth, + text2vision_projector=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +debug=False +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True) + +glamm_refcocog_dataset = dict( + type=RefCOCOgGCGDataset, + data_path=refcocog_ann_file, + image_folder=refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_refcocog_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +glamm_grandf_dataset = dict( + type=GranDfGCGDataset, + data_path=grandf_ann_file, + image_folder=grandf_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_granf_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=10, +) + +glamm_psg_dataset = dict( + type=OpenPsgGCGDataset, + data_path=psg_ann_file, + image_folder=psg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_openpsg_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +glamm_flickr_dataset = dict( + type=FlickrGCGDataset, + data_path=flickr_ann_file, + image_folder=flickr_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=glamm_flickr_map_fn, + 
template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=debug, + repeats=1, +) + +semantic_seg_ade20k_dataset = dict( + type=ADE20kSemanticSegDataset, + data_path=ade20k_class_file, + image_folder=ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +semantic_seg_cocostuff_dataset = dict( + type=COCOStuffSemanticSegDataset, + data_path=cocostuff_class_file, + image_folder=cocostuff_image_path, + label_path=cocostuff_label_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=semantic_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, + gcg_format=True, +) + +referring_seg_refcoco_dataset = dict( + type=RefcocoReferringSegDataset, + data_path=referring_refcoco_data_path, + image_folder=referring_refcoco_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcoco_plus_dataset = dict( + type=Refcoco_plus_ReferringSegDataset, + data_path=referring_refcoco_plus_data_path, + image_folder=referring_refcoco_plus_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refcocog_dataset = dict( + type=Refcocog_ReferringSegDataset, + data_path=referring_refcocog_data_path, + image_folder=referring_refcocog_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +referring_seg_refclef_dataset = dict( + type=Refclef_ReferringSegDataset, + data_path=referring_refclef_data_path, + image_folder=referring_refclef_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=referring_seg_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_cap_osprey_dataset = dict( + type=OspreyRegionCaptionDataset, + data_path=region_cap_osprey_data_path, + image_folder=region_cap_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=osprey_region_caption_gcg_format_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +region_conversation_osprey_dataset = dict( + type=OspreyRegionConversationDataset, + data_path=region_conversation_osprey_data_path, + image_folder=region_conversation_osprey_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + 
dataset_map_fn=osprey_region_conversation_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_ade20k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_ade20k_data_path, + image_folder=mdpv_detailed_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_10k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_10k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_cocostuff_164k_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_cocostuff_164k_data_path, + image_folder=mdpv_detailed_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_detailed_description_vg_dataset = dict( + type=MDPVPointDetailedCaptionDataset, + data_path=mdpv_detailed_caption_vg_data_path, + image_folder=mdpv_detailed_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_vg_data_path, + image_folder=mdpv_brief_caption_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_10k_data_path, + image_folder=mdpv_brief_caption_cocostuff_10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_cocostuff_164k_data_path, + image_folder=mdpv_brief_caption_cocostuff_164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_ade20k_dataset = dict( + 
type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_ade20k_data_path, + image_folder=mdpv_brief_caption_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_brief_description_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_brief_caption_lvis_data_path, + image_folder=mdpv_brief_caption_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_vg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_vg_data_path, + image_folder=mdpv_qa_vg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_ade20k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_ade20k_data_path, + image_folder=mdpv_qa_ade20k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_lvis_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_lvis_data_path, + image_folder=mdpv_qa_lvis_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff10k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff10k_data_path, + image_folder=mdpv_qa_cocostuff10k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_qa_cocostuff164k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_qa_cocostuff164k_data_path, + image_folder=mdpv_qa_cocostuff164k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_openpsg_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_openpsg_data_path, + image_folder=mdpv_multi_points_openpsg_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +mdpv_multi_points_flicker30k_dataset = dict( + type=MDPVPointBriefCaptionDataset, + data_path=mdpv_multi_points_flicker30k_data_path, + 
image_folder=mdpv_multi_points_flicker30k_image_path, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=mdpv_points_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, + repeats=1, +) + +train_dataset = dict( + type=CombineDataset, + datasets_cfgs=[referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset, + referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset, + referring_seg_refcocog_dataset, referring_seg_refclef_dataset,], +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb68e3552b3094466bdab4f7bbc49981251d0c1 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 
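+# With `by_epoch=False` in the CheckpointHook (PART 5 Runtime below), a checkpoint is written
+# every `save_steps` iterations and at most `save_total_limit` checkpoints are kept on disk.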
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + 
use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=False, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, 
max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py new file mode 100644 index 0000000000000000000000000000000000000000..b49b3c19287bb6146b9940bd483c229eaf542e57 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
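+# Projector ablation: the main visible switch relative to ablation_projector_baseline.py is
+# `add_cross_attn_layer=True` in the model dict below; the rest of the recipe is unchanged.
+# A typical launch, assuming xtuner's standard distributed entry point (adjust the GPU count
+# and the DeepSpeed flag to your environment):
+#   NPROC_PER_NODE=8 xtuner train \
+#       omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py \
+#       --deepspeed deepspeed_zero2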
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + 
type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py new file mode 100644 index 0000000000000000000000000000000000000000..3e709027e948ba534ba334ce087e3175a0ab1aa4 --- /dev/null +++ b/omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device 
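+# Effective global batch size = batch_size * accumulative_counts * world_size,
+# e.g. 16 * 4 * 8 = 512 if this ablation is run on 8 GPUs.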
+accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + 
reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + visual_prompt_proj=False, + add_cross_attn_layer=True, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( 
+ type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py b/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..7e78c928accbb1cf0733fd978b108a437db7a1fd --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
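+# 20B pretrain config: the main visible differences from omg_llava_7b_pretrain_1024image_8gpus.py
+# are the internlm2-chat-20b checkpoint path and `dataloader_num_workers = 0`.
+# The `max_length` below, int(2048 - (1024 / 64)**2) = 2048 - 256 = 1792, leaves room for roughly
+# 256 visual tokens from a 1024-pixel input at an effective stride of 64 (stride-32 ConvNeXt
+# features pixel-shuffled by 2); this reading of the formula is an interpretation, not a
+# documented value.
+# A typical launch, assuming xtuner's standard distributed entry point:
+#   NPROC_PER_NODE=8 xtuner train omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py --deepspeed deepspeed_zero2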
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + + + + + +omgseg_model = dict( + type=OMGSegVisualEncoder, + 
data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..65321b73fe90a8c51ad3ff102a1be01d534a5e17 --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py @@ -0,0 +1,375 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 2 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) 
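+# betas=(0.9, 0.999) are PyTorch AdamW's defaults; the next line disables weight decay for this
+# pretrain stage.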
+weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + 
reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test 
setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..33f6780bb24f9396ec28597c91dc2dd445dd265e --- /dev/null +++ b/omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
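Throughout these configs, entries such as tokenizer, image_processor and the model sub-dicts are plain dict(type=..., ...) descriptions that the runner builds lazily; CombineDataset later in this diff applies the same convention by popping 'type' and calling it with the remaining keys. A minimal sketch of that convention, using a hypothetical helper rather than mmengine's actual registry machinery:

# Illustrative only: the real training entry point uses the mmengine/xtuner
# builders, not this function. It just shows the "pop 'type', call it with
# the remaining kwargs" pattern that the config dicts in these files rely on.
def build_cfg_dict(cfg: dict):
    cfg = dict(cfg)             # copy so the original config dict stays intact
    builder = cfg.pop('type')   # a class or callable, e.g. AutoTokenizer.from_pretrained
    return builder(**cfg)

# e.g. build_cfg_dict(tokenizer) would amount to
# AutoTokenizer.from_pretrained(pretrained_model_name_or_path=llm_name_or_path,
#                               trust_remote_code=True, padding_side='right')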
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from omg_llava.dataset import LLaVADataset +from omg_llava.dataset.collect_fns import omg_llava_collate_fn +from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory +from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens +from xtuner.engine.runner import TrainLoop +from omg_llava.model import OMG_LLaVA +from xtuner.utils import PROMPT_TEMPLATE +from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg +from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead + +from torch.nn import GroupNorm, ReLU + +from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \ + DiceLoss, MaskFormerFusionHead, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model or model paths +llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path +omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path +omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path + +# Data paths +data_root = './data/llava_data/' +data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json' +image_folder = data_root + 'LLaVA-Pretrain/images' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = int(2048 - (1024 / 64)**2) + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 1e-3 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 200 +SYSTEM = '' +evaluation_images = './work_dirs/test.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor, + do_resize=True, + size=1024, + resample=3, + do_center_crop=True, + crop_size=1024, + do_rescale=True, + do_normalize=True, + image_mean=[0.4814, 0.4578, 0.4082], + image_std=[0.2686, 0.2613, 0.2757], + do_convert_rgb=True +) + +# using coco class as the class classifier +class_embed = 'convnext_large_d_320_CocoPanopticOVDataset' +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +omgseg_model = dict( + type=OMGSegVisualEncoder, + 
data_preprocessor=None, + pixel_shuffle_down_ratio=2, + backbone=dict( + type=OpenCLIPBackbone_omgseg, + model_name='convnext_xxlarge', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s34b_b82k_augreg_soup' + ) + ), + panoptic_head=dict( + type=Mask2FormerVideoSemSamHead, + sphere_cls=True, + ov_path=omg_ov_class_embed_path, + enable_box_query=False, + ov_classifier_name=class_embed, + logit=None, + in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 240 + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean') + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + # dict(type=FlexibleClassificationCost, weight=2.0), + dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint=omg_head_pretrain_pth_path, + ) +) + +model = dict( + type=OMG_LLaVA, + freeze_llm=True, + freeze_visual_encoder=True, + text2vision_projector=True, + keep_omg_decoder_frozen=True, + add_seg_pretrain=True, + pixel_shuffle_ratio=2, + clip_feat_channel=3072, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + visual_encoder=omgseg_model, + tokenizer=tokenizer, +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=LLaVADataset, + data_path=data_path, + image_folder=image_folder, + tokenizer=tokenizer, + image_processor=image_processor, + dataset_map_fn=llava_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + max_length=max_length, + pad_image_to_square=True, + debug=False, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=omg_llava_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer), + dict( + type=EvaluateChatHook_withSpecialTokens, + tokenizer=tokenizer, + image_processor=image_processor, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + evaluation_images=evaluation_images, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/omg_llava/dataset/CombineDataset.py b/omg_llava/dataset/CombineDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db46456e27ffd52a66d36cb1ec0536af60b19657 --- /dev/null +++ b/omg_llava/dataset/CombineDataset.py @@ -0,0 +1,81 @@ +from torch.utils.data import Dataset +import numpy as np + +class CombineDataset(Dataset): + def __init__(self, + datasets_cfgs, + ): + super().__init__() + + self.datasets = [] + self.datasets_length = [] + + self.tokenizer = datasets_cfgs[0].tokenizer + tokenizer_type = self.tokenizer['type'] + del self.tokenizer['type'] + self.tokenizer = tokenizer_type(**self.tokenizer) + + self._add_special_tokens() + + for i in range(len(datasets_cfgs)): + datasets_cfgs[i].tokenizer = self.tokenizer + + for dataset_cfg in datasets_cfgs: + dataset = dataset_cfg['type'] + del dataset_cfg['type'] + dataset = dataset(**dataset_cfg) + self.datasets.append(dataset) + self.datasets_length.append(len(dataset)) + + self.dataset_threthold = [] + for i, length in enumerate(self.datasets_length): + if i == 0: + self.dataset_threthold.append(length) + else: + self.dataset_threthold.append(length + self.dataset_threthold[i - 1]) + + np.random.seed(42) + self.shuffled_index = np.arange(self.dataset_threthold[-1]) + np.random.shuffle(self.shuffled_index) + + @property + def modality_length(self): + length_list = [] + for dataset in self.datasets: + for data_dict in dataset.text_data: + cur_len = len(data_dict['input_ids']) + if data_dict.get('image', None) is None: + cur_len = -cur_len + length_list.append(cur_len) + return length_list + + def __len__(self): + return self.dataset_threthold[-1] + + def __getitem__(self, index): + index = int(self.shuffled_index[index]) + for i, thred in enumerate(self.dataset_threthold): + if index < thred: + break + + + if i == 0: + _index = index + else: + _index = index - self.dataset_threthold[i - 1] + + return self.datasets[i][_index] + + def _add_special_tokens(self): + assert hasattr(self, "tokenizer") + # Adding special tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['
<p>', '</p>
'] +    # add for visual prompt +    region_tokens = ['<region>', '<mark>
'] + # add for visual prompt + region_tokens = ['and [seg] to caption and select a question + question = random.choice(GCG_QUESTIONS).strip().format(caption) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess_decoupled_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation_decoupled_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + + return example + +def refcocog_conversation_decoupled_given_objects(caption, tokens_positive): + # insertand [seg] to caption and select a question + object_tokens = '' + for i in range(len(tokens_positive)): + object_tokens = object_tokens + '
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess_decoupled_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation_decoupled_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_refcocog_decoupled_given_description_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess_decoupled_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_refcocog_decoupled_given_objects_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess_decoupled_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def grandf_parse_annotations(example): + image_path = example['file_name'] + annotations = { + 'labels': [], 'caption': [], 'masks': [], + 'tokens_positive': [], 
'file_name': image_path, + 'image': image_path} + annotations['caption'] = example['caption'].strip('"').strip() + + for word, grounding in example["groundings"].items(): + if grounding is None: + continue + annotations['labels'].append(word) + annotations['tokens_positive'].append(grounding["token_positives"]) + annotations['masks'].append(grounding["rle_masks"]) + + return annotations + +def grandf_conversation_given_description(caption, tokens_positive): + question = random.choice(GCG_QUESTIONS).strip().format(caption) + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}{caption[start:end]}
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def grandf_conversation_given_objects(caption, tokens_positive): + object_tokens = '' + for i in range(len(tokens_positive)): + object_tokens = object_tokens + '{caption[start:end]}
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def grandf_preprocess_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def grandf_preprocess_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_granf_decoupled_given_description_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_granf_decoupled_given_objects_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 
'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +glamm_openpsg_decoupled_given_objects_map_fn = glamm_granf_decoupled_given_objects_map_fn +glamm_openpsg_decoupled_given_description_map_fn = glamm_granf_decoupled_given_description_map_fn + +def flickr_parse_annotations(example): + annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [], + 'tokens_positive': [], 'image': example['file_name']} + ann_info = example["ann_info"] + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0)) + if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + annotations['bboxes'].append(bbox) + tokens_positive = ann['tokens_positive'] + gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive] + annotations['labels'].append(gt_label[0]) + annotations['tokens_positive'].append(tokens_positive[0]) + + rle = ann['sam_mask'] + annotations['masks'].append(rle) + + # Convert bounding boxes to numpy arrays + annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[ + 'bboxes'] else np.zeros((0, 4), dtype=np.float32) + annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[ + 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32) + return annotations + +def flickr_preprocess_given_description(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_description(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def flickr_preprocess_given_objects(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort 
phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation_given_objects(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_flickr_decoupled_given_description_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess_given_description(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def glamm_flickr_decoupled_given_objects_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess_given_objects(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + + + + diff --git a/omg_llava/dataset/process_functions/gcg_process.py b/omg_llava/dataset/process_functions/gcg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..8befcbcc48375184f4d557b38791cce5b78360d8 --- /dev/null +++ b/omg_llava/dataset/process_functions/gcg_process.py @@ -0,0 +1,297 @@ +import numpy as np +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +GCG_QUESTIONS = [ + DEFAULT_IMAGE_TOKEN + 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + DEFAULT_IMAGE_TOKEN + 'Can you provide a thorough description of the this image? Please output with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Please describe in detail the contents of the image. 
Please respond with interleaved segmentation masks for the corresponding parts of the answer.', + DEFAULT_IMAGE_TOKEN + 'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.', + DEFAULT_IMAGE_TOKEN + 'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.', +] + +def refcocog_parse_annotations(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [], + 'file_name': example['img_file_name'], 'image': example['img_file_name']} + + orig_caption = example['caption'].strip('"').strip() + annotations['caption'] = orig_caption.lower() + + for detail in example['refs']: + phrase = detail['sentence'] + if phrase.lower() in annotations['caption']: + annotations['labels'].append(phrase) + index = annotations['caption'].find(phrase) + end_index = index + len(phrase) if index != -1 else -1 + annotations['tokens_positive'].append([index, end_index]) + # still polygon or rle + annotations['masks'].append(detail["segmentation"]) + + # Sort tokens_positive and corresponding lists + tokens_positive = annotations['tokens_positive'] + sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0]) + annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices] + annotations['masks'] = [annotations['masks'][i] for i in sorted_indices] + annotations['labels'] = [annotations['labels'][i] for i in sorted_indices] + + # Trimming overlapping intervals + for i in range(len(tokens_positive)): + for j in range(i + 1, len(tokens_positive)): + # If there is overlap + if tokens_positive[i][1] >= tokens_positive[j][0]: + # Modify the end index of phrase i to be one less than the start index of phrase j + tokens_positive[i][1] = tokens_positive[j][0] - 1 + # Modify the phrases to reflect the change in indices + annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1] + break # Exit inner loop since i was modified + + return annotations + +def refcocog_conversation(caption, tokens_positive): + # insertand [seg] to caption and select a question + question = random.choice(GCG_QUESTIONS).strip() + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}
<p> {caption[start:end]} </p>
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations + +def refcocog_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = refcocog_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + + return example + +def glamm_refcocog_map_fn(example): + # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str} + + example = refcocog_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = refcocog_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def grandf_parse_annotations(example): + image_path = example['file_name'] + annotations = { + 'labels': [], 'caption': [], 'masks': [], + 'tokens_positive': [], 'file_name': image_path, + 'image': image_path} + annotations['caption'] = example['caption'].strip('"').strip() + + for word, grounding in example["groundings"].items(): + if grounding is None: + continue + annotations['labels'].append(word) + annotations['tokens_positive'].append(grounding["token_positives"]) + annotations['masks'].append(grounding["rle_masks"]) + + return annotations + +def grandf_conversation(caption, tokens_positive): + question = random.choice(GCG_QUESTIONS).strip() + + # Prepare caption with tags + def tag_caption(caption, tokens): + for start, end in sorted(tokens, key=lambda x: x[0], reverse=True): + caption = f"{caption[:start]}{caption[start:end]}
[SEG]{caption[end:]}" + return caption + + detailed_answer = tag_caption(caption, tokens_positive) + + conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}] + return conversations +def grandf_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_granf_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + example = grandf_parse_annotations(example) + # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file + + example = grandf_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +glamm_openpsg_map_fn = glamm_granf_map_fn + +def flickr_parse_annotations(example): + annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [], + 'tokens_positive': [], 'image': example['file_name']} + ann_info = example["ann_info"] + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0)) + if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + annotations['bboxes'].append(bbox) + tokens_positive = ann['tokens_positive'] + gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive] + annotations['labels'].append(gt_label[0]) + annotations['tokens_positive'].append(tokens_positive[0]) + + rle = ann['sam_mask'] + annotations['masks'].append(rle) + + # Convert bounding boxes to numpy arrays + annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[ + 'bboxes'] else np.zeros((0, 4), dtype=np.float32) + annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[ + 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32) + return annotations + +def 
flickr_preprocess(example): + data_labels = example['labels'] + masks = example['masks'] + caption = example['caption'] + tokens_positive = example['tokens_positive'] + + # Function to sort elements based on the start index of each phrase + def sort_by_start_index(items, order): + return [items[i] for i in order] + + # Sort phrases based on their appearance in the sentence + phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0]) + masks = sort_by_start_index(masks, phrase_order) + data_labels = sort_by_start_index(data_labels, phrase_order) + tokens_positive = sort_by_start_index(tokens_positive, phrase_order) + + conversations = grandf_conversation(caption, tokens_positive) + example['conversations'] = conversations + example['labels'] = data_labels + example['masks'] = masks + example['tokens_positive'] = tokens_positive + return example + +def glamm_flickr_map_fn(example): + # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str", + # "groundings": {ground_words: {'token_positives', 'rle_masks', }}} + + example = flickr_parse_annotations(example) + + example = flickr_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + + + + diff --git a/omg_llava/dataset/process_functions/mdpv_points_process.py b/omg_llava/dataset/process_functions/mdpv_points_process.py new file mode 100644 index 0000000000000000000000000000000000000000..4f966703b53356e2da6e604441d931b175477619 --- /dev/null +++ b/omg_llava/dataset/process_functions/mdpv_points_process.py @@ -0,0 +1,52 @@ +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +def mdpv_points_preprocess(example): + conversations = example['conversations'] + num_marks = example['num_marks'] + + for i, conversation in enumerate(conversations): + if i == 0: + role = conversation['from'] + assert role == 'human' + question = DEFAULT_IMAGE_TOKEN + 'There are some marks:' + for i in range(num_marks): + question = question + ' Mark {} '.format(i + 1) + if i + 1 == num_marks: + question = question + '.\n' + else: + question = question + ',' + question = question + conversation['value'].replace('<', '').replace('>', '') + conversation['value'] = question + else: + conversation['value'] = conversation['value'].replace('<', '').replace('>', '') + + example['conversations'] = conversations + return example + +def mdpv_points_map_fn(example): + # examples {'image', 'conversations'} + example = mdpv_points_preprocess(example) + + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = 
DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/referring_seg_process.py b/omg_llava/dataset/process_functions/referring_seg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..679f7a32ffba957a8f27ee8bf2e4b7715859dd0c --- /dev/null +++ b/omg_llava/dataset/process_functions/referring_seg_process.py @@ -0,0 +1,135 @@ +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN + +SEG_QUESTIONS = [ + "Can you segment the {class_name} in this image?", + "Please segment {class_name} in this image.", + "What is {class_name} in this image? Please respond with segmentation mask.", + "What is {class_name} in this image? Please output segmentation mask.", + + "Can you segment the {class_name} in this image", + "Please segment {class_name} in this image", + "What is {class_name} in this image? Please respond with segmentation mask", + "What is {class_name} in this image? Please output segmentation mask", + + "Could you provide a segmentation mask for the {class_name} in this image?", + "Please identify and segment the {class_name} in this image.", + "Where is the {class_name} in this picture? Please respond with a segmentation mask.", + "Can you highlight the {class_name} in this image with a segmentation mask?", + + "Could you provide a segmentation mask for the {class_name} in this image", + "Please identify and segment the {class_name} in this image", + "Where is the {class_name} in this picture? Please respond with a segmentation mask", + "Can you highlight the {class_name} in this image with a segmentation mask", +] + +ANSWER_LIST = [ + "It is [SEG].", + "Sure, [SEG].", + "Sure, it is [SEG].", + "Sure, the segmentation result is [SEG].", + "[SEG].", +] + +ANSWER_LIST_GCG_FORMAT = [ + "{}
[SEG].", +] + +def referring_seg_conversations(labels): + questions = [] + answers = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + question_template = random.choice(SEG_QUESTIONS) + questions.append(question_template.format(class_name=label.lower())) + answers.append(random.choice(ANSWER_LIST)) + ret = [] + for i, (question, answer) in enumerate(zip(questions, answers)): + if i == 0: + ret.append( + {'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question} + ) + else: + ret.append( + {'from': 'human', 'value': question} + ) + ret.append( + {'from': 'gpt', 'value': answer} + ) + return ret + +def referring_seg_map_fn(example): + # example {'sampled_sents'} + messages = referring_seg_conversations(example['sampled_sents']) + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def referring_seg_gcg_format_conversations(labels): + questions = [] + answers = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + question_template = random.choice(SEG_QUESTIONS) + questions.append(question_template.format(class_name=label.lower())) + answers.append(random.choice(ANSWER_LIST_GCG_FORMAT).format(label.lower().capitalize())) + ret = [] + for i, (question, answer) in enumerate(zip(questions, answers)): + if i == 0: + ret.append( + {'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question} + ) + else: + ret.append( + {'from': 'human', 'value': question} + ) + ret.append( + {'from': 'gpt', 'value': answer} + ) + return ret + +def referring_seg_gcg_format_map_fn(example): + # example {'sampled_sents'} + + messages = referring_seg_gcg_format_conversations(example['sampled_sents']) + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example \ No newline at end of file diff --git a/omg_llava/dataset/process_functions/region_caption_process.py b/omg_llava/dataset/process_functions/region_caption_process.py new file mode 100644 index 0000000000000000000000000000000000000000..847abefeaac32351cd3e0849ed2f588f2901f7d3 --- /dev/null +++ b/omg_llava/dataset/process_functions/region_caption_process.py @@ -0,0 +1,223 @@ +import numpy as np +import random +from xtuner.utils import DEFAULT_IMAGE_TOKEN +import re + +REGION_QUESTIONS = [ + 'Can you provide me with a detailed description of the region in the picture marked byRegion{}
[SEG].".format(selected_seg_idx) + questions.append(question) + answers.append(answer) + + conversations = [] + for question, answer in zip(questions, answers): + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations, [selected_seg_idx - 1] + +def region_caption_preprocess(example): + descriptions = example['description'] + + # random select some labels + if len(descriptions) >= 3: + sampled_inds = np.random.choice( + list(range(len(descriptions))), size=3, replace=False + ) + else: + sampled_inds = list(range(len(descriptions))) + + selected_descriptions = [descriptions[idx] for idx in sampled_inds] + selected_descriptions = [re.sub(r'<[^>]*>', '{}
[SEG].", +] + +def semantic_seg_conversations(labels): + ret = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + for question_template in SEG_QUESTIONS: + for answer_template in ANSWER_LIST: + item = {} + item['conversations'] = [{'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question_template.format(class_name=label.lower())}, + {'from': 'gpt', 'value': answer_template}] + item['class_id'] = i + ret.append(item) + return ret + +def semantic_seg_map_fn(example): + # example {'conversations', 'class_id'} + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def pascal_part_conversation(selected_labels): + conversations = [] + for i, selected_label in enumerate(selected_labels): + question = random.choice(SEG_QUESTIONS).format(class_name=selected_label.lower()).strip() + answer = random.choice(ANSWER_LIST) + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations + +def pascal_part_preprocess(example): + selected_labels = example["selected_labels"] + conversations = pascal_part_conversation(selected_labels) + example['conversations'] = conversations + return example + +def pascal_part_map_fn(example): + example = pascal_part_preprocess(example) + example['image'] = example["file_name"] + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + +def semantic_seg_gcg_format_conversations(labels): + ret = [] + for i, label in enumerate(labels): + label = label.strip() + assert len(label.split("||")) == 1 + for question_template in SEG_QUESTIONS: + for answer_template in ANSWER_LIST_GCG_FORMAT: + item = {} + item['conversations'] = [{'from': 'human', 'value': DEFAULT_IMAGE_TOKEN+question_template.format(class_name=label.lower())}, + {'from': 'gpt', 'value': answer_template.format(label.lower().capitalize())}] + item['class_id'] = i + ret.append(item) + return ret + +def semantic_seg_gcg_format_map_fn(example): + # example {'conversations', 'class_id'} + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + 
for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + +def pascal_part_gcg_format_conversation(selected_labels): + conversations = [] + for i, selected_label in enumerate(selected_labels): + question = random.choice(SEG_QUESTIONS).format(class_name=selected_label.lower()).strip() + answer = random.choice(ANSWER_LIST).format(selected_label.lower().capitalize()) + if i == 0: + question = DEFAULT_IMAGE_TOKEN + question + conversations.append({'from': 'human', 'value': question}) + conversations.append({'from': 'gpt', 'value': answer}) + return conversations + +def pascal_part_gcg_format_preprocess(example): + selected_labels = example["selected_labels"] + conversations = pascal_part_gcg_format_conversation(selected_labels) + example['conversations'] = conversations + return example + +def pascal_part_gcg_format_map_fn(example): + example = pascal_part_gcg_format_preprocess(example) + example['image'] = example["file_name"] + # do llava preprocess + messages = example['conversations'] + input = '' + conversation = [] + while messages and messages[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + messages = messages[1:] + for msg in messages: + if msg['from'] == 'human': + if DEFAULT_IMAGE_TOKEN in msg['value']: + msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, + '').strip() + msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value'] + msg['value'] = msg['value'].strip() + input += msg['value'] + + elif msg['from'] == 'gpt': + conversation.append({'input': input, 'output': msg['value']}) + input = '' + else: + raise NotImplementedError + example.update({'conversation': conversation}) + return example + + diff --git a/omg_llava/dataset/utils/__init__.py b/omg_llava/dataset/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..085858dbf4fcac7d3432bd9d1e308e33b0cd3e23 --- /dev/null +++ b/omg_llava/dataset/utils/__init__.py @@ -0,0 +1 @@ +from .utils import expand2square, expand2square_mask, expand2square_points, expand2square_bbox \ No newline at end of file diff --git a/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..361a6b6b14324d36eae3cfd6d4a9fec5b3acdc1f Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff42d24c2743cb09f65512ccea331bad4f88708e Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/refcoco_refer.cpython-310.pyc differ diff --git a/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc b/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94360a333f042ea28f815f77b477a83f2c688743 Binary files /dev/null and b/omg_llava/dataset/utils/__pycache__/utils.cpython-310.pyc differ diff --git 
a/omg_llava/dataset/utils/ade20k_classes.json b/omg_llava/dataset/utils/ade20k_classes.json new file mode 100644 index 0000000000000000000000000000000000000000..1f96e616bc3fd2f8c0ec4caea975d77c680f44bb --- /dev/null +++ b/omg_llava/dataset/utils/ade20k_classes.json @@ -0,0 +1,30 @@ +[ + "wall", "building", "sky", "floor", "tree", "ceiling", "road", + "bed", "windowpane", "grass", "cabinet", "sidewalk", + "person", "earth", "door", "table", "mountain", "plant", + "curtain", "chair", "car", "water", "painting", "sofa", + "shelf", "house", "sea", "mirror", "rug", "field", "armchair", + "seat", "fence", "desk", "rock", "wardrobe", "lamp", + "bathtub", "railing", "cushion", "base", "box", "column", + "signboard", "chest of drawers", "counter", "sand", "sink", + "skyscraper", "fireplace", "refrigerator", "grandstand", + "path", "stairs", "runway", "case", "pool table", "pillow", + "screen door", "stairway", "river", "bridge", "bookcase", + "blind", "coffee table", "toilet", "flower", "book", "hill", + "bench", "countertop", "stove", "palm", "kitchen island", + "computer", "swivel chair", "boat", "bar", "arcade machine", + "hovel", "bus", "towel", "light", "truck", "tower", + "chandelier", "awning", "streetlight", "booth", + "television receiver", "airplane", "dirt track", "apparel", + "pole", "land", "bannister", "escalator", "ottoman", "bottle", + "buffet", "poster", "stage", "van", "ship", "fountain", + "conveyer belt", "canopy", "washer", "plaything", + "swimming pool", "stool", "barrel", "basket", "waterfall", + "tent", "bag", "minibike", "cradle", "oven", "ball", "food", + "step", "tank", "trade name", "microwave", "pot", "animal", + "bicycle", "lake", "dishwasher", "screen", "blanket", + "sculpture", "hood", "sconce", "vase", "traffic light", + "tray", "ashcan", "fan", "pier", "crt screen", "plate", + "monitor", "bulletin board", "shower", "radiator", "glass", + "clock", "flag" +] \ No newline at end of file diff --git a/omg_llava/dataset/utils/cocostuff_classes.txt b/omg_llava/dataset/utils/cocostuff_classes.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d5a692b83ac8eead2bfffa805e1115cef737bae --- /dev/null +++ b/omg_llava/dataset/utils/cocostuff_classes.txt @@ -0,0 +1,183 @@ +0: unlabeled +1: person +2: bicycle +3: car +4: motorcycle +5: airplane +6: bus +7: train +8: truck +9: boat +10: traffic light +11: fire hydrant +12: street sign +13: stop sign +14: parking meter +15: bench +16: bird +17: cat +18: dog +19: horse +20: sheep +21: cow +22: elephant +23: bear +24: zebra +25: giraffe +26: hat +27: backpack +28: umbrella +29: shoe +30: eye glasses +31: handbag +32: tie +33: suitcase +34: frisbee +35: skis +36: snowboard +37: sports ball +38: kite +39: baseball bat +40: baseball glove +41: skateboard +42: surfboard +43: tennis racket +44: bottle +45: plate +46: wine glass +47: cup +48: fork +49: knife +50: spoon +51: bowl +52: banana +53: apple +54: sandwich +55: orange +56: broccoli +57: carrot +58: hot dog +59: pizza +60: donut +61: cake +62: chair +63: couch +64: potted plant +65: bed +66: mirror +67: dining table +68: window +69: desk +70: toilet +71: door +72: tv +73: laptop +74: mouse +75: remote +76: keyboard +77: cell phone +78: microwave +79: oven +80: toaster +81: sink +82: refrigerator +83: blender +84: book +85: clock +86: vase +87: scissors +88: teddy bear +89: hair drier +90: toothbrush +91: hair brush +92: banner +93: blanket +94: branch +95: bridge +96: building-other +97: bush +98: cabinet +99: cage +100: cardboard +101: carpet +102: 
ceiling-other +103: ceiling-tile +104: cloth +105: clothes +106: clouds +107: counter +108: cupboard +109: curtain +110: desk-stuff +111: dirt +112: door-stuff +113: fence +114: floor-marble +115: floor-other +116: floor-stone +117: floor-tile +118: floor-wood +119: flower +120: fog +121: food-other +122: fruit +123: furniture-other +124: grass +125: gravel +126: ground-other +127: hill +128: house +129: leaves +130: light +131: mat +132: metal +133: mirror-stuff +134: moss +135: mountain +136: mud +137: napkin +138: net +139: paper +140: pavement +141: pillow +142: plant-other +143: plastic +144: platform +145: playingfield +146: railing +147: railroad +148: river +149: road +150: rock +151: roof +152: rug +153: salad +154: sand +155: sea +156: shelf +157: sky +158: skyscraper +159: snow +160: solid-other +161: stairs +162: stone +163: straw +164: structural-other +165: table +166: tent +167: textile-other +168: towel +169: tree +170: vegetable +171: wall-brick +172: wall-concrete +173: wall-other +174: wall-panel +175: wall-stone +176: wall-tile +177: wall-wood +178: water-other +179: waterdrops +180: window-blind +181: window-other +182: wood diff --git a/omg_llava/dataset/utils/grefer.py b/omg_llava/dataset/utils/grefer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c881c5860a2bbfc89eb91b8fcf91cc32c27fbbf --- /dev/null +++ b/omg_llava/dataset/utils/grefer.py @@ -0,0 +1,352 @@ +""" +grefer v0.1 +This interface provides access to gRefCOCO. + +The following API functions are defined: +G_REFER - REFER api class +getRefIds - get ref ids that satisfy given filter conditions. +getAnnIds - get ann ids that satisfy given filter conditions. +getImgIds - get image ids that satisfy given filter conditions. +getCatIds - get category ids that satisfy given filter conditions. +loadRefs - load refs with the specified ref ids. +loadAnns - load anns with the specified ann ids. +loadImgs - load images with the specified image ids. +loadCats - load category names with the specified category ids. +getRefBox - get ref's bounding box [x, y, w, h] given the ref_id +showRef - show image, segmentation or box of the referred object with the ref +getMaskByRef - get mask and area of the referred object given ref or ref ids +getMask - get mask and area of the referred object given ref +showMask - show mask of the referred object given ref +""" + +import itertools +import json +import os.path as osp +import pickle +import time + +import matplotlib.pyplot as plt +import numpy as np +import skimage.io as io +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from pycocotools import mask + + +class G_REFER: + def __init__(self, data_root, dataset="grefcoco", splitBy="unc"): + # provide data_root folder which contains grefcoco + print("loading dataset %s into memory..." 
% dataset) + self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) + self.DATA_DIR = osp.join(data_root, dataset) + if dataset in ["grefcoco"]: + self.IMAGE_DIR = osp.join(data_root, "images/train2014") + else: + raise KeyError("No refer dataset is called [%s]" % dataset) + + tic = time.time() + + # load refs from data/dataset/refs(dataset).json + self.data = {} + self.data["dataset"] = dataset + + ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).p") + if osp.exists(ref_file): + self.data["refs"] = pickle.load(open(ref_file, "rb"), fix_imports=True) + else: + ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).json") + if osp.exists(ref_file): + self.data["refs"] = json.load(open(ref_file, "rb")) + else: + raise FileNotFoundError("JSON file not found") + + # load annotations from data/dataset/instances.json + instances_file = osp.join(self.DATA_DIR, "instances.json") + instances = json.load(open(instances_file, "r")) + self.data["images"] = instances["images"] + self.data["annotations"] = instances["annotations"] + self.data["categories"] = instances["categories"] + + # create index + self.createIndex() + print("DONE (t=%.2fs)" % (time.time() - tic)) + + @staticmethod + def _toList(x): + return x if isinstance(x, list) else [x] + + @staticmethod + def match_any(a, b): + a = a if isinstance(a, list) else [a] + b = b if isinstance(b, list) else [b] + return set(a) & set(b) + + def createIndex(self): + # create sets of mapping + # 1) Refs: {ref_id: ref} + # 2) Anns: {ann_id: ann} + # 3) Imgs: {image_id: image} + # 4) Cats: {category_id: category_name} + # 5) Sents: {sent_id: sent} + # 6) imgToRefs: {image_id: refs} + # 7) imgToAnns: {image_id: anns} + # 8) refToAnn: {ref_id: ann} + # 9) annToRef: {ann_id: ref} + # 10) catToRefs: {category_id: refs} + # 11) sentToRef: {sent_id: ref} + # 12) sentToTokens: {sent_id: tokens} + print("creating index...") + # fetch info from instances + Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} + Anns[-1] = None + for ann in self.data["annotations"]: + Anns[ann["id"]] = ann + imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] + for img in self.data["images"]: + Imgs[img["id"]] = img + for cat in self.data["categories"]: + Cats[cat["id"]] = cat["name"] + + # fetch info from refs + Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} + Sents, sentToRef, sentToTokens = {}, {}, {} + availableSplits = [] + for ref in self.data["refs"]: + # ids + ref_id = ref["ref_id"] + ann_id = ref["ann_id"] + category_id = ref["category_id"] + image_id = ref["image_id"] + + if ref["split"] not in availableSplits: + availableSplits.append(ref["split"]) + + # add mapping related to ref + if ref_id in Refs: + print("Duplicate ref id") + Refs[ref_id] = ref + imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] + + category_id = self._toList(category_id) + added_cats = [] + for cat in category_id: + if cat not in added_cats: + added_cats.append(cat) + catToRefs[cat] = catToRefs.get(cat, []) + [ref] + + ann_id = self._toList(ann_id) + refToAnn[ref_id] = [Anns[ann] for ann in ann_id] + for ann_id_n in ann_id: + annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref] + + # add mapping of sent + for sent in ref["sentences"]: + Sents[sent["sent_id"]] = sent + sentToRef[sent["sent_id"]] = ref + sentToTokens[sent["sent_id"]] = sent["tokens"] + + # create class members + self.Refs = Refs + self.Anns = Anns + self.Imgs = Imgs + self.Cats = Cats + self.Sents = Sents + self.imgToRefs = imgToRefs + self.imgToAnns = imgToAnns + self.refToAnn = refToAnn + 
self.annToRef = annToRef + self.catToRefs = catToRefs + self.sentToRef = sentToRef + self.sentToTokens = sentToTokens + self.availableSplits = availableSplits + print("index created.") + + def getRefIds(self, image_ids=[], cat_ids=[], split=[]): + image_ids = self._toList(image_ids) + cat_ids = self._toList(cat_ids) + split = self._toList(split) + + for s in split: + if s not in self.availableSplits: + raise ValueError(f"Invalid split name: {s}") + + refs = self.data["refs"] + + if len(image_ids) > 0: + lists = [self.imgToRefs[image_id] for image_id in image_ids] + refs = list(itertools.chain.from_iterable(lists)) + if len(cat_ids) > 0: + refs = [ref for ref in refs if self.match_any(ref["category_id"], cat_ids)] + if len(split) > 0: + refs = [ref for ref in refs if ref["split"] in split] + + ref_ids = [ref["ref_id"] for ref in refs] + return ref_ids + + def getAnnIds(self, image_ids=[], ref_ids=[]): + image_ids = self._toList(image_ids) + ref_ids = self._toList(ref_ids) + + if any([len(image_ids), len(ref_ids)]): + if len(image_ids) > 0: + lists = [ + self.imgToAnns[image_id] + for image_id in image_ids + if image_id in self.imgToAnns + ] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.data["annotations"] + ann_ids = [ann["id"] for ann in anns] + if len(ref_ids) > 0: + lists = [self.Refs[ref_id]["ann_id"] for ref_id in ref_ids] + anns_by_ref_id = list(itertools.chain.from_iterable(lists)) + ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id))) + else: + ann_ids = [ann["id"] for ann in self.data["annotations"]] + + return ann_ids + + def getImgIds(self, ref_ids=[]): + ref_ids = self._toList(ref_ids) + + if len(ref_ids) > 0: + image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) + else: + image_ids = self.Imgs.keys() + return image_ids + + def getCatIds(self): + return self.Cats.keys() + + def loadRefs(self, ref_ids=[]): + return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)] + + def loadAnns(self, ann_ids=[]): + if isinstance(ann_ids, str): + ann_ids = int(ann_ids) + return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)] + + def loadImgs(self, image_ids=[]): + return [self.Imgs[image_id] for image_id in self._toList(image_ids)] + + def loadCats(self, cat_ids=[]): + return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)] + + def getRefBox(self, ref_id): + anns = self.refToAnn[ref_id] + return [ann["bbox"] for ann in anns] # [x, y, w, h] + + def showRef(self, ref, seg_box="seg"): + ax = plt.gca() + # show image + image = self.Imgs[ref["image_id"]] + I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"])) + ax.imshow(I) + # show refer expression + for sid, sent in enumerate(ref["sentences"]): + print("%s. 
%s" % (sid + 1, sent["sent"])) + # show segmentations + if seg_box == "seg": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + polygons = [] + color = [] + c = "none" + if type(ann["segmentation"][0]) == list: + # polygon used for refcoco* + for seg in ann["segmentation"]: + poly = np.array(seg).reshape((len(seg) / 2, 2)) + polygons.append(Polygon(poly, True, alpha=0.4)) + color.append(c) + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 1, 0, 0), + linewidths=3, + alpha=1, + ) + ax.add_collection(p) # thick yellow polygon + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 0, 0, 0), + linewidths=1, + alpha=1, + ) + ax.add_collection(p) # thin red polygon + else: + # mask used for refclef + rle = ann["segmentation"] + m = mask.decode(rle) + img = np.ones((m.shape[0], m.shape[1], 3)) + color_mask = np.array([2.0, 166.0, 101.0]) / 255 + for i in range(3): + img[:, :, i] = color_mask[i] + ax.imshow(np.dstack((img, m * 0.5))) + # show bounding-box + elif seg_box == "box": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + bbox = self.getRefBox(ref["ref_id"]) + box_plot = Rectangle( + (bbox[0], bbox[1]), + bbox[2], + bbox[3], + fill=False, + edgecolor="green", + linewidth=3, + ) + ax.add_patch(box_plot) + + def getMask(self, ann): + if not ann: + return None + if ann["iscrowd"]: + raise ValueError("Crowd object") + image = self.Imgs[ann["image_id"]] + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"]) + else: + rle = ann["segmentation"] + + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + # compute area + area = sum(mask.area(rle)) # should be close to ann['area'] + return {"mask": m, "area": area} + + def getMaskByRef(self, ref=None, ref_id=None, merge=False): + if not ref and not ref_id: + raise ValueError + if ref: + ann_ids = ref["ann_id"] + ref_id = ref["ref_id"] + else: + ann_ids = self.getAnnIds(ref_ids=ref_id) + + if ann_ids == [-1]: + img = self.Imgs[self.Refs[ref_id]["image_id"]] + return { + "mask": np.zeros([img["height"], img["width"]], dtype=np.uint8), + "empty": True, + } + + anns = self.loadAnns(ann_ids) + mask_list = [self.getMask(ann) for ann in anns if not ann["iscrowd"]] + + if merge: + merged_masks = sum([mask["mask"] for mask in mask_list]) + merged_masks[np.where(merged_masks > 1)] = 1 + return {"mask": merged_masks, "empty": False} + else: + return mask_list + + def showMask(self, ref): + M = self.getMask(ref) + msk = M["mask"] + ax = plt.gca() + ax.imshow(msk) diff --git a/omg_llava/dataset/utils/refcoco_refer.py b/omg_llava/dataset/utils/refcoco_refer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4cea716e40e73d0b5aa118143eb076392f5eb1 --- /dev/null +++ b/omg_llava/dataset/utils/refcoco_refer.py @@ -0,0 +1,391 @@ +__author__ = "licheng" + +""" +This interface provides access to four datasets: +1) refclef +2) refcoco +3) refcoco+ +4) refcocog +split by unc and google + +The following API functions are defined: +REFER - REFER api class +getRefIds - get ref ids that satisfy given filter conditions. +getAnnIds - get ann ids that satisfy given filter conditions. +getImgIds - get image ids that satisfy given filter conditions. +getCatIds - get category ids that satisfy given filter conditions. +loadRefs - load refs with the specified ref ids. +loadAnns - load anns with the specified ann ids. 
+loadImgs - load images with the specified image ids. +loadCats - load category names with the specified category ids. +getRefBox - get ref's bounding box [x, y, w, h] given the ref_id +showRef - show image, segmentation or box of the referred object with the ref +getMask - get mask and area of the referred object given ref +showMask - show mask of the referred object given ref +""" + +import itertools +import json +import os.path as osp +import pickle +import sys +import time +from pprint import pprint + +import matplotlib.pyplot as plt +import numpy as np +import skimage.io as io +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from pycocotools import mask + + +class REFER: + def __init__(self, data_root, dataset="refcoco", splitBy="unc"): + # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog + # also provide dataset name and splitBy information + # e.g., dataset = 'refcoco', splitBy = 'unc' + print("loading dataset %s into memory..." % dataset) + self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) + self.DATA_DIR = osp.join(data_root, dataset) + if dataset in ["refcoco", "refcoco+", "refcocog"]: + self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014") + elif dataset == "refclef": + self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12") + else: + print("No refer dataset is called [%s]" % dataset) + sys.exit() + + self.dataset = dataset + + # load refs from data/dataset/refs(dataset).json + tic = time.time() + + ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p") + print("ref_file: ", ref_file) + self.data = {} + self.data["dataset"] = dataset + self.data["refs"] = pickle.load(open(ref_file, "rb")) + + # load annotations from data/dataset/instances.json + instances_file = osp.join(self.DATA_DIR, "instances.json") + instances = json.load(open(instances_file, "rb")) + self.data["images"] = instances["images"] + self.data["annotations"] = instances["annotations"] + self.data["categories"] = instances["categories"] + + # create index + self.createIndex() + print("DONE (t=%.2fs)" % (time.time() - tic)) + + def createIndex(self): + # create sets of mapping + # 1) Refs: {ref_id: ref} + # 2) Anns: {ann_id: ann} + # 3) Imgs: {image_id: image} + # 4) Cats: {category_id: category_name} + # 5) Sents: {sent_id: sent} + # 6) imgToRefs: {image_id: refs} + # 7) imgToAnns: {image_id: anns} + # 8) refToAnn: {ref_id: ann} + # 9) annToRef: {ann_id: ref} + # 10) catToRefs: {category_id: refs} + # 11) sentToRef: {sent_id: ref} + # 12) sentToTokens: {sent_id: tokens} + print("creating index...") + # fetch info from instances + Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} + for ann in self.data["annotations"]: + Anns[ann["id"]] = ann + imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] + for img in self.data["images"]: + Imgs[img["id"]] = img + for cat in self.data["categories"]: + Cats[cat["id"]] = cat["name"] + + # fetch info from refs + Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} + Sents, sentToRef, sentToTokens = {}, {}, {} + for ref in self.data["refs"]: + # ids + ref_id = ref["ref_id"] + ann_id = ref["ann_id"] + category_id = ref["category_id"] + image_id = ref["image_id"] + + # add mapping related to ref + Refs[ref_id] = ref + imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] + catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] + refToAnn[ref_id] = Anns[ann_id] + annToRef[ann_id] = ref + + # add mapping of sent + for sent in 
ref["sentences"]: + Sents[sent["sent_id"]] = sent + sentToRef[sent["sent_id"]] = ref + sentToTokens[sent["sent_id"]] = sent["tokens"] + + # create class members + self.Refs = Refs + self.Anns = Anns + self.Imgs = Imgs + self.Cats = Cats + self.Sents = Sents + self.imgToRefs = imgToRefs + self.imgToAnns = imgToAnns + self.refToAnn = refToAnn + self.annToRef = annToRef + self.catToRefs = catToRefs + self.sentToRef = sentToRef + self.sentToTokens = sentToTokens + print("index created.") + + def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: + refs = self.data["refs"] + else: + if not len(image_ids) == 0: + refs = [self.imgToRefs[image_id] for image_id in image_ids] + else: + refs = self.data["refs"] + if not len(cat_ids) == 0: + refs = [ref for ref in refs if ref["category_id"] in cat_ids] + if not len(ref_ids) == 0: + refs = [ref for ref in refs if ref["ref_id"] in ref_ids] + if not len(split) == 0: + if split in ["testA", "testB", "testC"]: + refs = [ + ref for ref in refs if split[-1] in ref["split"] + ] # we also consider testAB, testBC, ... + elif split in ["testAB", "testBC", "testAC"]: + refs = [ + ref for ref in refs if ref["split"] == split + ] # rarely used I guess... + elif split == "test": + refs = [ref for ref in refs if "test" in ref["split"]] + elif split == "train" or split == "val": + refs = [ref for ref in refs if ref["split"] == split] + else: + print("No such split [%s]" % split) + sys.exit() + ref_ids = [ref["ref_id"] for ref in refs] + return ref_ids + + def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: + ann_ids = [ann["id"] for ann in self.data["annotations"]] + else: + if not len(image_ids) == 0: + lists = [ + self.imgToAnns[image_id] + for image_id in image_ids + if image_id in self.imgToAnns + ] # list of [anns] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.data["annotations"] + if not len(cat_ids) == 0: + anns = [ann for ann in anns if ann["category_id"] in cat_ids] + ann_ids = [ann["id"] for ann in anns] + if not len(ref_ids) == 0: + ids = set(ann_ids).intersection( + set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]) + ) + return ann_ids + + def getImgIds(self, ref_ids=[]): + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if not len(ref_ids) == 0: + image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) + else: + image_ids = self.Imgs.keys() + return image_ids + + def getCatIds(self): + return self.Cats.keys() + + def loadRefs(self, ref_ids=[]): + if type(ref_ids) == list: + return [self.Refs[ref_id] for ref_id in ref_ids] + elif type(ref_ids) == int: + return [self.Refs[ref_ids]] + + def loadAnns(self, ann_ids=[]): + if type(ann_ids) == list: + return [self.Anns[ann_id] for ann_id in ann_ids] + elif type(ann_ids) == int or type(ann_ids) == unicode: + return [self.Anns[ann_ids]] + + def loadImgs(self, image_ids=[]): + if type(image_ids) == list: + return [self.Imgs[image_id] for image_id in image_ids] + elif type(image_ids) == int: + return [self.Imgs[image_ids]] + + 
def loadCats(self, cat_ids=[]): + if type(cat_ids) == list: + return [self.Cats[cat_id] for cat_id in cat_ids] + elif type(cat_ids) == int: + return [self.Cats[cat_ids]] + + def getRefBox(self, ref_id): + ref = self.Refs[ref_id] + ann = self.refToAnn[ref_id] + return ann["bbox"] # [x, y, w, h] + + def showRef(self, ref, seg_box="seg"): + ax = plt.gca() + # show image + image = self.Imgs[ref["image_id"]] + I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"])) + ax.imshow(I) + # show refer expression + for sid, sent in enumerate(ref["sentences"]): + print("%s. %s" % (sid + 1, sent["sent"])) + # show segmentations + if seg_box == "seg": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + polygons = [] + color = [] + c = "none" + if type(ann["segmentation"][0]) == list: + # polygon used for refcoco* + for seg in ann["segmentation"]: + poly = np.array(seg).reshape((len(seg) / 2, 2)) + polygons.append(Polygon(poly, True, alpha=0.4)) + color.append(c) + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 1, 0, 0), + linewidths=3, + alpha=1, + ) + ax.add_collection(p) # thick yellow polygon + p = PatchCollection( + polygons, + facecolors=color, + edgecolors=(1, 0, 0, 0), + linewidths=1, + alpha=1, + ) + ax.add_collection(p) # thin red polygon + else: + # mask used for refclef + rle = ann["segmentation"] + m = mask.decode(rle) + img = np.ones((m.shape[0], m.shape[1], 3)) + color_mask = np.array([2.0, 166.0, 101.0]) / 255 + for i in range(3): + img[:, :, i] = color_mask[i] + ax.imshow(np.dstack((img, m * 0.5))) + # show bounding-box + elif seg_box == "box": + ann_id = ref["ann_id"] + ann = self.Anns[ann_id] + bbox = self.getRefBox(ref["ref_id"]) + box_plot = Rectangle( + (bbox[0], bbox[1]), + bbox[2], + bbox[3], + fill=False, + edgecolor="green", + linewidth=3, + ) + ax.add_patch(box_plot) + + def getMask(self, ref): + # return mask, area and mask-center + ann = self.refToAnn[ref["ref_id"]] + image = self.Imgs[ref["image_id"]] + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"]) + else: + rle = ann["segmentation"] + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + # compute area + area = sum(mask.area(rle)) # should be close to ann['area'] + return {"mask": m, "area": area} + # # position + # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style) + # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style) + # # mass position (if there were multiple regions, we use the largest one.) 
+ # label_m = label(m, connectivity=m.ndim) + # regions = regionprops(label_m) + # if len(regions) > 0: + # largest_id = np.argmax(np.array([props.filled_area for props in regions])) + # largest_props = regions[largest_id] + # mass_y, mass_x = largest_props.centroid + # else: + # mass_x, mass_y = position_x, position_y + # # if centroid is not in mask, we find the closest point to it from mask + # if m[mass_y, mass_x] != 1: + # print('Finding closes mask point ...') + # kernel = np.ones((10, 10),np.uint8) + # me = cv2.erode(m, kernel, iterations = 1) + # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style + # points = np.array(points) + # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) + # id = np.argsort(dist)[0] + # mass_y, mass_x = points[id] + # # return + # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} + # # show image and mask + # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) + # plt.figure() + # plt.imshow(I) + # ax = plt.gca() + # img = np.ones( (m.shape[0], m.shape[1], 3) ) + # color_mask = np.array([2.0,166.0,101.0])/255 + # for i in range(3): + # img[:,:,i] = color_mask[i] + # ax.imshow(np.dstack( (img, m*0.5) )) + # plt.show() + + def showMask(self, ref): + M = self.getMask(ref) + msk = M["mask"] + ax = plt.gca() + ax.imshow(msk) + + +if __name__ == "__main__": + refer = REFER(dataset="refcocog", splitBy="google") + ref_ids = refer.getRefIds() + print(len(ref_ids)) + + print(len(refer.Imgs)) + print(len(refer.imgToRefs)) + + ref_ids = refer.getRefIds(split="train") + print("There are %s training referred objects." % len(ref_ids)) + + for ref_id in ref_ids: + ref = refer.loadRefs(ref_id)[0] + if len(ref["sentences"]) < 2: + continue + + pprint(ref) + print("The label is %s." 
% refer.Cats[ref["category_id"]]) + plt.figure() + refer.showRef(ref, seg_box="box") + plt.show() + + # plt.figure() + # refer.showMask(ref) + # plt.show() diff --git a/omg_llava/dataset/utils/utils.py b/omg_llava/dataset/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4a6cdd89e09c4f3a78fc2270b6c12cc5b6d77d --- /dev/null +++ b/omg_llava/dataset/utils/utils.py @@ -0,0 +1,71 @@ +import numpy as np +from PIL import Image + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +def expand2square_mask(mask): + # mask (n, h, w) + n_mask, width, height = mask.shape + if width == height: + return mask + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + pad_mask_1 = np.zeros((n_mask, width, n_pad_1), dtype=np.uint8) + pad_mask_2 = np.zeros((n_mask, width, n_pad_2), dtype=np.uint8) + result = np.concatenate([pad_mask_1, mask, pad_mask_2], axis=2) + return result + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + pad_mask_1 = np.zeros((n_mask, n_pad_1, height), dtype=np.uint8) + pad_mask_2 = np.zeros((n_mask, n_pad_2, height), dtype=np.uint8) + result = np.concatenate([pad_mask_1, mask, pad_mask_2], axis=1) + return result + +def expand2square_bbox(bboxes, width, height): + bboxes = np.array(bboxes) + if width == height: + return bboxes + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + bboxes[:, 1] += n_pad_1 + return bboxes + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + bboxes[:, 0] += n_pad_1 + return bboxes + +def expand2square_points(points, width, height): + if width == height: + return points + elif width > height: + n_pad = width - height + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + points[:, 1] += n_pad_1 + return points + else: + n_pad = height - width + n_pad_1 = n_pad // 2 + n_pad_2 = n_pad - n_pad_1 + points[:, 0] += n_pad_1 + return points + diff --git a/omg_llava/engine/__init__.py b/omg_llava/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d1e8d0a1ceb1bd7a89f55661902b5f3c895191 --- /dev/null +++ b/omg_llava/engine/__init__.py @@ -0,0 +1,2 @@ +from .dataset_info_hook import DatasetInfoHook_withSpecoalTokens +from .evaluate_chat_hook import EvaluateChatHook_withSpecialTokens \ No newline at end of file diff --git a/omg_llava/engine/dataset_info_hook.py b/omg_llava/engine/dataset_info_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7aa45dafbf97a231c139c8a19caaed32232db5 --- /dev/null +++ b/omg_llava/engine/dataset_info_hook.py @@ -0,0 +1,17 @@ +from xtuner.registry import BUILDER +from xtuner.engine.hooks import DatasetInfoHook + +class DatasetInfoHook_withSpecoalTokens(DatasetInfoHook): + def __init__(self, tokenizer, is_intern_repo_dataset=False): + self.tokenizer = BUILDER.build(tokenizer) + self.is_intern_repo_dataset = is_intern_repo_dataset + # add special tokens + # Adding special tokens for pixel grounding + segmentation_tokens = ['[SEG]'] + # Adding tokens for GCG + phrase_tokens = ['', '
<p>', '</p>']
+ # add for visual prompt
+ region_tokens = ['<mark>', '<region>']
+ special_tokens = segmentation_tokens + phrase_tokens + region_tokens
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
+ self.eop_token_idx = self.tokenizer('</p>', add_special_tokens=False).input_ids[0]
+ self.region_token_idx = self.tokenizer('<region>', add_special_tokens=False).input_ids[0]
+
+ for color in colors:
+ markdown_out = markdown_out.replace("[COLOR]", str(desaturate(tuple(color))), 1)
+
+ markdown_out = f"""
+ {markdown_out}
+ """
+ markdown_out = markdown_default + "" + markdown_out
+ return markdown_out
+
+def show_mask_pred(image, masks, crop_range=(0, 1024, 0, 1024)):
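+ # crop_range is (x_start, x_end, y_start, y_end); the default keeps the full 1024x1024 padded image.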
+ print(crop_range)
+
+ selected_colors = []
+
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
+ (255, 255, 0), (255, 0, 255), (0, 255, 255),
+ (128, 128, 255), [255, 192, 203], # Pink
+ [165, 42, 42], # Brown
+ [255, 165, 0], # Orange
+ [128, 0, 128], # Purple
+ [0, 0, 128], # Navy
+ [128, 0, 0], # Maroon
+ [128, 128, 0], # Olive
+ [70, 130, 180], # Steel Blue
+ [173, 216, 230], # Light Blue
+ [255, 192, 0], # Gold
+ [255, 165, 165], # Light Salmon
+ [255, 20, 147], # Deep Pink
+ ]
+
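+ # Upsample the mask logits to the image size, then binarize them with sigmoid > 0.5.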
+ masks = F.interpolate(masks, size=image.size, mode='bilinear', align_corners=False)
+ masks = masks.sigmoid() > 0.5
+ masks = masks.to(torch.uint8).cpu().numpy()[:, 0]
+
+ _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)
+
+ for i, mask in enumerate(masks):
+ color = colors[i % len(colors)]
+ selected_colors.append(color)
+ _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
+ _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
+ _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]
+
+
+ image = np.array(image)
+ image = image * 0.5 + _mask_image * 0.5
+ image = image.astype(np.uint8)
+ image = image[crop_range[2]: crop_range[3], crop_range[0]: crop_range[1], :]
+ # image = Image.fromarray(image)
+ # image.save(save_dir)
+ return image, selected_colors
+
+def parse_visual_prompts(points):
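+ # item[2] encodes the prompt type: 1.0 is a click point (x, y); 2.0 or 3.0 is a box whose corners sit at indices (0, 1) and (3, 4).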
+ ret = {'points': [], 'boxes': []}
+ for item in points:
+ if item[2] == 1.0:
+ ret['points'].append([item[0], item[1]])
+ elif item[2] == 2.0 or item[2] == 3.0:
+ ret['boxes'].append(item[[0, 1, 3, 4]])
+ else:
+ raise NotImplementedError
+ return ret
\ No newline at end of file
diff --git a/omg_llava/tools/chat_omg_llava.py b/omg_llava/tools/chat_omg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f0cde989e8ca42055bd9e90b38b8a46ab25fc65
--- /dev/null
+++ b/omg_llava/tools/chat_omg_llava.py
@@ -0,0 +1,518 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import os
+import os.path as osp
+import re
+import sys
+
+import torch
+from huggingface_hub import snapshot_download
+from peft import PeftModel
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+from transformers.generation.streamers import TextStreamer
+
+from xtuner.dataset.utils import expand2square, load_image
+from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+from xtuner.tools.utils import get_stop_criteria
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
+
+import argparse
+import os.path as osp
+
+from mmengine.config import Config, DictAction
+from mmengine.fileio import PetrelBackend, get_file_backend
+
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from xtuner.registry import BUILDER
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook
+
+def remove_prefix(state_dict, prefix):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if key.startswith(prefix):
+ new_key = key[len(prefix):]
+ new_state_dict[new_key] = value
+ else:
+ new_state_dict[key] = value
+ return new_state_dict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Chat with a HF model')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+
+ parser.add_argument('--image', default=None, help='image')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default="internlm2_chat",
+ help='Specify a prompt template')
+ system_group = parser.add_mutually_exclusive_group()
+ system_group.add_argument(
+ '--system', default=None, help='Specify the system text')
+ system_group.add_argument(
+ '--system-template',
+ choices=SYSTEM_TEMPLATE.keys(),
+ default=None,
+ help='Specify a system template')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--with-plugins',
+ nargs='+',
+ choices=['calculate', 'solve', 'search'],
+ help='Specify plugins to use')
+ parser.add_argument(
+ '--no-streamer', action='store_true', help='Disable the text streamer during generation')
+ parser.add_argument(
+ '--lagent', action='store_true', help='Whether to use lagent')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=2048,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--temperature',
+ type=float,
+ default=0.1,
+ help='The value used to modulate the next token probabilities.')
+ parser.add_argument(
+ '--top-k',
+ type=int,
+ default=40,
+ help='The number of highest probability vocabulary tokens to '
+ 'keep for top-k-filtering.')
+ parser.add_argument(
+ '--top-p',
+ type=float,
+ default=0.75,
+ help='If set to float < 1, only the smallest set of most probable '
+ 'tokens with probabilities that add up to top_p or higher are '
+ 'kept for generation.')
+ parser.add_argument(
+ '--repetition-penalty',
+ type=float,
+ default=1.0,
+ help='The parameter for repetition penalty. 1.0 means no penalty.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ args = parser.parse_args()
+ return args
+
+
+def get_input():
+ """Helper function for getting input from users."""
+ sentinel = '' # ends when this string is seen
+ result = None
+ while result is None:
+ print(('\ndouble enter to end input (EXIT: exit chat, '
+ 'RESET: reset history) >>> '),
+ end='')
+ try:
+ result = '\n'.join(iter(input, sentinel))
+ except UnicodeDecodeError:
+ print('Invalid characters detected. Please enter again.')
+ return result
+
+
+def main():
+ args = parse_args()
+ torch.manual_seed(args.seed)
+
+ # parse config
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name or 'OMG' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ print(model.state_dict().keys())
+
+ # pre_state_dict = torch.load("/root/omg-llava.pth")
+ # model.load_state_dict(pre_state_dict)
+
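+ # Checkpoints on a Petrel (object storage) backend are loaded through the patched file IO context.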
+ backend = get_file_backend(args.pth_model)
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ print(state_dict.keys())
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ # chat_hook = EvaluateChatHook(
+ # tokenizer=cfg.tokenizer,
+ # image_processor=image_processor_cfg,
+ # every_n_iters=100,
+ # evaluation_inputs=cfg.evaluation_inputs,
+ # evaluation_images=cfg.evaluation_images,
+ # system='',
+ # prompt_template=PROMPT_TEMPLATE.internlm2_chat
+ # )
+ # model.cuda()
+ # model.eval()
+ # chat_hook._eval_images_(model, model.device, max_new_tokens=200)
+
+ # build llm
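+ # Optional quantized loading: 4-bit uses bitsandbytes NF4 with double quantization and fp16 compute; 8-bit simply sets load_in_8bit.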
+ quantization_config = None
+ load_in_8bit = False
+ if args.bits == 4:
+ quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ load_in_8bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4')
+ elif args.bits == 8:
+ load_in_8bit = True
+ model_kwargs = {
+ 'quantization_config': quantization_config,
+ 'load_in_8bit': load_in_8bit,
+ 'device_map': 'auto',
+ 'offload_folder': args.offload_folder,
+ 'trust_remote_code': True,
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
+ }
+ if False:
+ pass
+ else:
+ if args.with_plugins is None:
+ inner_thoughts_open = False
+ calculate_open = False
+ solve_open = False
+ search_open = False
+ else:
+ assert args.prompt_template == args.system_template == 'moss_sft'
+ from plugins import plugins_api
+ inner_thoughts_open = True
+ calculate_open = 'calculate' in args.with_plugins
+ solve_open = 'solve' in args.with_plugins
+ search_open = 'search' in args.with_plugins
+ # pre-import for api and model preparation
+ if calculate_open:
+ from plugins import calculate # noqa: F401
+ if solve_open:
+ from plugins import solve # noqa: F401
+ if search_open:
+ from plugins import search # noqa: F401
+ # build llm
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
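+ # Image path: pad to a square with the processor's mean color, preprocess, encode with the visual encoder, then project the features into the LLM embedding space.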
+ if args.image is not None:
+ image = load_image(args.image)
+ image = expand2square(
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image_for_show = image
+ image = image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ print([item.shape for item in visual_outputs])
+ pixel_values = projector(visual_outputs)
+
+ stop_words = args.stop_words
+ sep = ''
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ sep = template.get('SEP', '')
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ if args.no_streamer:
+ streamer = None
+ else:
+ streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.temperature > 0,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
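+ # Interactive chat state: 'inputs' accumulates the multi-turn prompt text and 'n_turn' counts completed rounds.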
+ n_turn = 0
+ inputs = ''
+ while True:
+ text = get_input()
+ while text.strip() == 'RESET':
+ print('Log: History responses have been removed!')
+ n_turn = 0
+ inputs = ''
+ text = get_input()
+ if text.strip() == 'EXIT':
+ print('Log: Exit!')
+ exit(0)
+
+ if args.image is not None and n_turn == 0:
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
+
+ if args.prompt_template:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ if 'SYSTEM' in template and n_turn == 0:
+ system_text = None
+ if args.system_template is not None:
+ system_text = SYSTEM_TEMPLATE[
+ args.system_template].format(
+ round=n_turn + 1, bot_name=args.bot_name)
+ elif args.system is not None:
+ system_text = args.system
+ if system_text is not None:
+ prompt_text += template['SYSTEM'].format(
+ system=system_text,
+ round=n_turn + 1,
+ bot_name=args.bot_name)
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=n_turn + 1, bot_name=args.bot_name)
+ if args.prompt_template == args.system_template == 'moss_sft':
+ if not inner_thoughts_open:
+ prompt_text = prompt_text.replace('- Inner thoughts: enabled.',
+ '- Inner thoughts: disabled.')
+ if not calculate_open:
+ prompt_text = prompt_text.replace(('- Calculator: enabled. API: '
+ 'Calculate(expression)'),
+ '- Calculator: disabled.')
+ if not solve_open:
+ prompt_text = prompt_text.replace(
+ '- Equation solver: enabled. API: Solve(equation)',
+ '- Equation solver: disabled.')
+ if not search_open:
+ prompt_text = prompt_text.replace(
+ '- Web search: enabled. API: Search(query)',
+ '- Web search: disabled.')
+ else:
+ prompt_text = text
+ print("prompt_text: ", prompt_text)
+ inputs += prompt_text
+ if args.image is None:
+ if n_turn == 0:
+ ids = tokenizer.encode(inputs, return_tensors='pt')
+ else:
+ ids = tokenizer.encode(
+ inputs, return_tensors='pt', add_special_tokens=False)
+
+ if args.with_plugins is not None:
+ generate_output = llm.generate(
+ inputs=ids.cuda(),
+ generation_config=gen_config,
+ streamer=streamer,
+ stopping_criteria=stop_criteria).cpu()
+ generate_output_text = tokenizer.decode(
+ generate_output[0][len(ids[0]):])
+ if streamer is None:
+ end = '' if generate_output_text[-1] == '\n' else '\n'
+ print(generate_output_text, end=end)
+ pattern = re.compile(r'<p>(.*?)<\/p>')
+ phrases = pattern.findall(text_output)
+ phrases = [p.strip() for p in phrases]
+
+ # Remove the [SEG] token
+ cleaned_str = cleaned_str.replace('[SEG]', '')
+
+ # Strip unnecessary spaces
+ cleaned_str = ' '.join(cleaned_str.split()).strip("'")
+ cleaned_str = cleaned_str.strip()
+
+ # Convert the predicted masks into RLE format
+ pred_masks_tensor = pred_masks.cpu()
+ uncompressed_mask_rles = mask_to_rle_pytorch(pred_masks_tensor)
+ rle_masks = []
+ for m in uncompressed_mask_rles:
+ rle_masks.append(coco_encode_rle(m))
+
+ # Create results dictionary
+ result_dict = {
+ "image_id": image_name[:-4],
+ "caption": cleaned_str,
+ "phrases": phrases,
+ "pred_masks": rle_masks
+ }
+
+ # print(cleaned_str)
+ # print(phrases)
+
+ output_path = f"{output_dir}/{image_name[:-4]}.json"
+
+ with open(output_path, 'w') as f:
+ json.dump(result_dict, f)
+
+ return
+
+def mask_to_rle_pytorch(tensor: torch.Tensor):
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
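+ # COCO-style RLE counts start with the number of background pixels, so prepend a zero-length run when the first pixel is foreground.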
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+
+ return out
+
+def coco_encode_rle(uncompressed_rle):
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+
+ return rle
+
+if __name__ == '__main__':
+
+ main()
diff --git a/omg_llava/tools/mmbench_omg_seg_llava.py b/omg_llava/tools/mmbench_omg_seg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..23de5842ec68e2bc38e835f4a782047cc523cb87
--- /dev/null
+++ b/omg_llava/tools/mmbench_omg_seg_llava.py
@@ -0,0 +1,498 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import math
+import os.path as osp
+import re
+import string
+import time
+
+import numpy as np
+import pandas as pd
+import torch
+import tqdm
+from huggingface_hub import snapshot_download
+from mmengine import mkdir_or_exist
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from peft import PeftModel
+from rich.console import Console
+from rich.table import Table
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.dataset.utils import decode_base64_to_image, expand2square
+from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
+from xtuner.tools.utils import get_stop_criteria, is_cn_string
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+from importlib import import_module
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='MMBench')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument('--data-path', default=None, help='data path')
+ parser.add_argument('--work-dir', help='the dir to save results')
+ parser.add_argument('--llava', default=None, help='llava name or path')
+ parser.add_argument(
+ '--visual-encoder', default=None, help='visual encoder name or path')
+ parser.add_argument(
+ '--visual-select-layer', default=-2, help='visual select layer')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=100,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+
+class MMBenchDataset(Dataset):
+ ABBRS = {
+ 'coarse_perception': 'CP',
+ 'finegrained_perception (instance-level)': 'FP-S',
+ 'finegrained_perception (cross-instance)': 'FP-C',
+ 'logic_reasoning': 'LR',
+ 'relation_reasoning': 'RR',
+ 'attribute_reasoning': 'AR',
+ 'sketch_reasoning': 'Sketch Reasoning',
+ 'scenery_building': 'Scenery & Building',
+ 'food_clothes': 'Food & Clothes',
+ 'historical_figure': 'Historical Figure',
+ 'traditional_show': 'Traditional Show',
+ 'calligraphy_painting': 'Calligraphy Painting',
+ 'cultural_relic': 'Cultural Relic'
+ }
+
+ def __init__(self, data_file):
+ self.data_file = data_file
+ self.df = pd.read_csv(data_file, sep='\t')
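+ # Dev TSVs carry ground-truth answers; their absence marks the test split.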
+ self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test'
+ self.has_l2_category = 'l2-category' in self.df.columns.to_list()
+
+ def get_image(self, image):
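+ # Entries shorter than 16 characters are row indices rather than base64 data; follow them to the row holding the actual image.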
+ while len(image) < 16:
+ image = self.df[self.df['index'] == int(image)]['image'].values
+ assert len(image) == 1
+ image = image[0]
+ image = decode_base64_to_image(image)
+ return image
+
+ def __len__(self):
+ return len(self.df)
+
+ def __getitem__(self, idx):
+ index = self.df.iloc[idx]['index']
+ image = self.df.iloc[idx]['image']
+ image = self.get_image(image)
+ question = self.df.iloc[idx]['question']
+ answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[
+ 0].keys() else None
+ category = self.df.iloc[idx]['category']
+
+ options = {
+ cand: self.load_from_df(idx, cand)
+ for cand in string.ascii_uppercase
+ if self.load_from_df(idx, cand) is not None
+ }
+ options_prompt = ''
+ for key, item in options.items():
+ options_prompt += f'{key}. {item}\n'
+
+ hint = self.load_from_df(idx, 'hint')
+ data = {
+ 'img': image,
+ 'question': question,
+ 'answer': answer,
+ 'options': options_prompt,
+ 'category': category,
+ 'options_dict': options,
+ 'index': index,
+ 'context': hint,
+ }
+ if self.has_l2_category:
+ data.update({'l2-category': self.df.iloc[idx]['l2-category']})
+ return data
+
+ def load_from_df(self, idx, key):
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
+ return self.df.iloc[idx][key]
+ else:
+ return None
+
+ @master_only
+ def eval_result(self, result_df, show=True):
+
+ def calc_acc(df, group='category'):
+ assert group in ['overall', 'category', 'l2-category']
+ if group == 'overall':
+ res = {'Average': np.mean(df['hit'])}
+ else:
+ res = {}
+ abilities = list(set(df[group]))
+ abilities.sort()
+ for ab in abilities:
+ sub_df = df[df[group] == ab]
+ ab = self.ABBRS[ab] if ab in self.ABBRS else ab
+ res[ab] = np.mean(sub_df['hit'])
+ return res
+
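+ # A question counts as correct only if every circularly-shifted copy of
+ # its options is answered correctly (MMBench circular evaluation).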
+ def eval_sub_data(sub_data, answer_map):
+ lt = len(sub_data)
+ for i in range(lt):
+ item = sub_data.iloc[i]
+ match = re.search(r'([A-D]+)', item['prediction'])
+ pred = match.group(1) if match else ''
+ gt = answer_map[item['index']]
+ if gt != pred:
+ return 0
+ return 1
+
+ def show_result(ret_json):
+ show_dict = ret_json.copy()
+ table = Table(title=f' MMBench ({self.data_file}) ')
+ console = Console()
+ table.add_column('Category', justify='left')
+ table.add_column('Accuracy (%)', justify='right')
+ average = show_dict.pop('Average') * 100
+ table.add_row('Average', f'{average:.1f}')
+ table.add_section()
+ for cat_name, cat_acc in show_dict.items():
+ table.add_row(cat_name, f'{cat_acc * 100:.1f}')
+ with console.capture() as capture:
+ console.print(table, end='')
+ print('\n' + capture.get())
+ print('Note: Please be cautious if you use the results in papers, '
+ "since we don't use ChatGPT as a helper for choice "
+ 'extraction')
+
+ data = result_df.sort_values(by='index')
+ data['prediction'] = [str(x) for x in data['prediction']]
+ for k in data.keys():
+ data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
+
+ data_main = data[data['index'] < int(1e6)]
+ cate_map = {
+ i: c
+ for i, c in zip(self.df['index'], self.df['category'])
+ }
+ if self.has_l2_category:
+ l2_cate_map = {
+ i: c
+ for i, c in zip(self.df['index'], self.df['l2-category'])
+ }
+ answer_map = {
+ i: c
+ for i, c in zip(self.df['index'], self.df['answer'])
+ }
+
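+ # Rows with index >= 1e6 are the circular-eval copies; they share
+ # `index % 1e6` with their base question kept in data_main.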
+ lt = len(data_main)
+ hit, tot = 0, 0
+ result = {}
+ for i in range(lt):
+ item_main = data_main.iloc[i]
+ idx = item_main['index']
+ assert idx not in result
+ sub_data = data[data['index'] % int(1e6) == idx]
+ ret = eval_sub_data(sub_data, answer_map)
+ result[idx] = ret
+ hit += ret
+ tot += 1
+
+ indices = data_main['index']
+ data_main = data_main.copy()
+ data_main['hit'] = [result[i] for i in indices]
+ main_idx = data_main['index']
+ data_main['category'] = [cate_map[i] for i in main_idx]
+
+ ret_json = calc_acc(data_main, 'overall')
+
+ if self.has_l2_category:
+ data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
+ l2 = calc_acc(data_main, 'l2-category')
+ ret_json.update(l2)
+ else:
+ leaf = calc_acc(data_main, 'category')
+ ret_json.update(leaf)
+ if show:
+ show_result(ret_json)
+ return ret_json
+
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name or 'OMG' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ print(state_dict.keys())
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ # work_dir
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ save_dir = args.work_dir
+ else:
+ # use config filename as default work_dir
+ save_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.data_path))[0])
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
+ save_dir = osp.join(save_dir, timestamp)
+
+ if rank == 0:
+ mkdir_or_exist(osp.abspath(save_dir))
+ print('=======================================================')
+ print(f'Dataset path: {osp.abspath(args.data_path)}\n'
+ f'Results will be saved to {osp.abspath(save_dir)}')
+ print('=======================================================')
+
+ args_path = osp.join(save_dir, 'args.json')
+ with open(args_path, 'w', encoding='utf-8') as f:
+ json.dump(args.__dict__, f, indent=2)
+
+ results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
+ results_json_path = osp.join(save_dir, 'mmbench_result.json')
+
+ dataset = MMBenchDataset(args.data_path)
+
+ results = []
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
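+ # Each rank evaluates a contiguous shard of the dataset; the per-rank
+ # results are merged with collect_results after the loop.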
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ data_sample = dataset[i]
+ if data_sample['context'] is not None:
+ text = data_sample['context'] + '\n' + data_sample[
+ 'question'] + '\n' + data_sample['options']
+ else:
+ text = data_sample['question'] + '\n' + data_sample['options']
+
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
+
+ if is_cn_string(text):
+ text = text + '请直接回答选项字母。'
+ else:
+ text = text + ("Answer with the option's letter from the "
+ 'given choices directly.')
+
+ if args.prompt_template:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=1, bot_name=args.bot_name)
+ else:
+ prompt_text = text
+ inputs = prompt_text
+
+ image = data_sample['img'].convert('RGB')
+ image = expand2square(
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
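+ # Handle both encoder types: one that returns feature tensors/tuples
+ # directly and a CLIP-style encoder that exposes `hidden_states`.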
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
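+ # Tokenize the prompt around DEFAULT_IMAGE_TOKEN and splice in
+ # IMAGE_TOKEN_INDEX as the placeholder that
+ # prepare_inputs_labels_for_multimodal replaces with image features.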
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ mm_inputs = prepare_inputs_labels_for_multimodal(
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
+
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria)
+
+ predict = tokenizer.decode(
+ generate_output[0], skip_special_tokens=True).strip()
+ cur_result = {}
+ cur_result['question'] = data_sample.get('question')
+ cur_result.update(data_sample.get('options_dict'))
+ cur_result['prediction'] = predict
+ if data_sample.get('category') is not None:
+ cur_result['category'] = data_sample.get('category')
+ if data_sample.get('l2-category') is not None:
+ cur_result['l2-category'] = data_sample.get('l2-category')
+ cur_result['index'] = data_sample.get('index')
+ cur_result['split'] = data_sample.get('split')
+ cur_result['answer'] = data_sample.get('answer')
+ results.append(cur_result)
+
+ results = collect_results(results, n_samples)
+
+ if get_rank() == 0:
+
+ results_df = pd.DataFrame(results)
+ # with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
+ with pd.ExcelWriter(results_xlsx_path, engine='xlsxwriter') as writer:
+ results_df.to_excel(writer, index=False)
+
+ if dataset.split == 'dev':
+ results_dict = dataset.eval_result(results_df, show=True)
+ with open(results_json_path, 'w', encoding='utf-8') as f:
+ json.dump(results_dict, f, indent=2)
+ else:
+ print('All done!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/omg_llava/tools/refcoco_omg_seg_llava.py b/omg_llava/tools/refcoco_omg_seg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f445ef3e6a6d7733a39c1f317113af67ff1822
--- /dev/null
+++ b/omg_llava/tools/refcoco_omg_seg_llava.py
@@ -0,0 +1,727 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import math
+import os
+import os.path as osp
+import numpy as np
+import torch
+import tqdm
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+from xtuner.tools.utils import get_stop_criteria
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+import logging
+from mmengine import print_log
+from PIL import Image
+from pycocotools import mask
+import torch.nn.functional as F
+from omg_llava.dataset.utils import expand2square
+from omg_llava.dataset.utils.refcoco_refer import REFER
+from omg_llava.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU
+
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='RefCocoSeg')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument(
+ '--dataset',
+ choices=DATASETS_ATTRIBUTES.keys(),
+ default='refcoco',
+ help='Specify a ref dataset')
+ parser.add_argument(
+ '--split',
+ default='val',
+ help='Specify a split')
+ parser.add_argument(
+ '--visual-select-layer', default=-2, type=int, help='visual select layer')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=100,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
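+# Mapping from CLI dataset names to the REFER API's dataset/splitBy pair
+# (RefCOCO and RefCOCO+ use the UNC split, RefCOCOg uses the UMD split).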
+DATASETS_ATTRIBUTES = {
+ 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'},
+ 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'},
+ 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'},
+}
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+class RefcocoReferringSegDataset(Dataset):
+ def __init__(self,
+ image_folder,
+ image_processor,
+ dataset_name,
+ data_path=None,
+ tokenizer=None,
+ offline_processed_text_folder=None,
+ pad_image_to_square=False,
+ debug=False,
+ repeats=1,
+ split='val',
+ ):
+ self.split = split
+ self._set_attribute(dataset_name)
+ self.debug = debug
+ if offline_processed_text_folder and data_path:
+ print_log(
+ 'Both `offline_processed_text_folder` and '
+ '`data_path` are set; loading the dataset from '
+ '`offline_processed_text_folder` '
+ f'({offline_processed_text_folder})',
+ logger='current',
+ level=logging.WARNING)
+
+ if offline_processed_text_folder is not None:
+ raise NotImplementedError
+ else:
+ json_datas = self.json_file_preprocess(data_path)
+ self.json_datas = json_datas
+ # json_datas = self.only_get_hf_map_infos()
+ # json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
+ # self.text_data = process_hf_dataset(
+ # dataset=json_data,
+ # tokenizer=tokenizer,
+ # max_length=max_length,
+ # dataset_map_fn=dataset_map_fn,
+ # template_map_fn=template_map_fn,
+ # split='train',
+ # max_dataset_length=max_dataset_length,
+ # remove_unused_columns=False,
+ # pack_to_max_length=False,
+ # with_image_token=True,
+ # map_num_proc=num_proc, # because limited mem
+ # )
+
+ self.image_folder = image_folder
+ size = image_processor.crop_size
+ if isinstance(size, int):
+ self.image_h, self.image_w = size, size
+ else:
+ self.image_w, self.image_h = size
+
+ if isinstance(image_processor, (dict, Config, ConfigDict)):
+ self.image_processor = BUILDER.build(image_processor)
+ else:
+ self.image_processor = image_processor
+ self.pad_image_to_square = pad_image_to_square
+ self.down_ratio = 1
+ self.repeats = repeats
+
+ def _set_attribute(self, dataset_name):
+ attr_dict = DATASETS_ATTRIBUTES[dataset_name]
+
+ self.splitBy = attr_dict['splitBy']
+ self.dataset_name = attr_dict['dataset_name']
+
+ def __len__(self):
+ return len(self.json_datas) * self.repeats
+
+ def real_len(self):
+ return len(self.json_datas)
+
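+ # Build one record per image from the REFER annotations: all referring
+ # sentences for that image plus the annotation id of each target mask.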
+ def json_file_preprocess(self, data_path):
+ splitBy = self.splitBy
+ dataset_name = self.dataset_name
+ refer_api = REFER(data_path, dataset_name, splitBy)
+ ref_ids_train = refer_api.getRefIds(split=self.split)
+ images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
+ refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
+ self.img2refs = self.create_img_to_refs_mapping(refs_train)
+
+ image_infos = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
+ for item in loaded_images:
+ item = item.copy()
+ image_infos.append(item)
+
+ self.annotations = refer_api.Anns
+ # self.img2refs = self.create_img_to_refs_mapping(refs_train)
+
+ refs = [self.img2refs[image_info['id']] for image_info in image_infos]
+
+ ret = []
+ for image_info, ref in zip(image_infos, refs):
+ if len(ref) == 0:
+ continue
+
+ sents = []
+ ann_ids = []
+ for _ref in ref:
+ for sent in _ref["sentences"]:
+ text = sent["sent"]
+ sents.append(text)
+ ann_ids.append(_ref["ann_id"])
+
+ # if len(sents) >= 3:
+ # sampled_inds = np.random.choice(
+ # list(range(len(sents))), 3, replace=False
+ # )
+ # else:
+ # sampled_inds = list(range(len(sents)))
+
+ sampled_inds = list(range(len(sents)))
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
+ # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
+ sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
+ selected_labels = sampled_sents
+ ret.append(
+ {'image_info': image_info,
+ 'sampled_ann_id': sampled_ann_ids,
+ 'selected_labels': selected_labels,
+ 'image': image_info['file_name']
+ }
+ )
+ if self.debug:
+ return ret[:10]
+ return ret
+
+ def load_images(self, refer_api, images_ids_train, dataset_dir, dataset_name, inference=False):
+ images = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
+ for item in loaded_images:
+ item = item.copy()
+ images.append(item)
+ return images
+
+ def create_img_to_refs_mapping(self, refs_train):
+ img2refs = {}
+ for ref in refs_train:
+ img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
+ return img2refs
+
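+ # Convert the COCO-style annotations (polygons or RLE) referenced by
+ # `annotations_ids` into a stacked tensor of binary masks; list-valued
+ # ids are unioned into a single mask.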
+ def decode_mask(self, annotations_ids, image_info):
+ flag = False
+ masks = []
+
+ for ann_id in annotations_ids:
+ if isinstance(ann_id, list):
+ flag = True
+ if -1 in ann_id:
+ assert len(ann_id) == 1
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
+ np.uint8
+ )
+ else:
+ m_final = np.zeros(
+ (image_info["height"], image_info["width"])
+ ).astype(np.uint8)
+ for ann_id_i in ann_id:
+ ann = self.annotations[ann_id_i]
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros(
+ (image_info["height"], image_info["width"])
+ ).astype(np.uint8)
+ else:
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image_info["height"], image_info["width"], )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(
+ m, axis=2
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ m_final = m_final | m
+ m = m_final
+ masks.append(m)
+ continue
+
+ ann = self.annotations[ann_id]
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
+ np.uint8
+ )
+ masks.append(m)
+ continue
+
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image_info["height"], image_info["width"]
+ )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ masks.append(m)
+ masks = np.stack(masks, axis=0)
+
+ # if self.pad_image_to_square:
+ # masks = expand2square_mask(masks)
+ masks = torch.from_numpy(masks)
+
+ # masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
+ # self.image_w // self.down_ratio), mode='nearest').squeeze(0)
+
+ # print(image_info['file_name'])
+ # print(masks.shape)
+ # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1)
+ # save_masks = save_masks.numpy() * 255
+ # save_masks = Image.fromarray(save_masks.astype(np.uint8))
+ # save_masks.save("/root/mask.png")
+ # print(kkk)
+ return masks
+
+ def only_get_text_infos(self, json_data):
+ return {'sampled_sents': json_data['selected_labels']}
+
+ def get_questions(self, text_require_infos):
+ sampled_sents = text_require_infos['sampled_sents']
+ ret = []
+ for sent in sampled_sents:
+ ret.append("Please segment {} in this image.".format(sent))
+ return ret
+
+ def filter_data_dict(self, data_dict):
+ names = ['pixel_values', 'masks', 'ori_size', 'questions']
+ ret = {name: data_dict[name] for name in names}
+ return ret
+
+ def __getitem__(self, index):
+ index = index % self.real_len()
+ data_dict = self.json_datas[index]
+ text_require_infos = self.only_get_text_infos(data_dict)
+ questions = self.get_questions(text_require_infos)
+
+ assert data_dict.get('image', None) is not None
+ if data_dict.get('image', None) is not None:
+ image_file = data_dict['image']
+ image_file = os.path.join(self.image_folder, image_file)
+ # print(image_file)
+ image = Image.open(image_file).convert('RGB')
+ ori_width, ori_height = image.size
+ if self.pad_image_to_square:
+ image = expand2square(
+ image,
+ tuple(
+ int(x * 255) for x in self.image_processor.image_mean))
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ data_dict['pixel_values'] = image
+
+ # process and get masks
+ masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info'])
+ data_dict['masks'] = masks
+ data_dict['ori_size'] = (ori_width, ori_height)
+ data_dict['questions'] = questions
+ else:
+ if hasattr(self.image_processor, 'crop_size'):
+ crop_size = self.image_processor.crop_size
+ else:
+ crop_size = self.image_processor.size
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
+ crop_size['width'])
+ data_dict['masks'] = None
+ # pixel_values, binary masks, conversation/input ids
+ return self.filter_data_dict(data_dict)
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name or 'OMG' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+
+ if os.path.exists(cfg.pretrained_pth):
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+ else:
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Loaded pretrained PTH weights from {cfg.pretrained_pth}')
+
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ # print(state_dict.keys())
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ # # work_dir
+ # if args.work_dir is not None:
+ # # update configs according to CLI args if args.work_dir is not None
+ # save_dir = args.work_dir
+ # else:
+ # # use config filename as default work_dir
+ # save_dir = osp.join('./work_dirs',
+ # osp.splitext(osp.basename(args.data_path))[0])
+ # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
+ # save_dir = osp.join(save_dir, timestamp)
+
+ # if rank == 0:
+ # mkdir_or_exist(osp.abspath(save_dir))
+ # print('=======================================================')
+ # print(f'Dataset path: {osp.abspath(args.data_path)}\n'
+ # f'Results will be saved to {osp.abspath(save_dir)}')
+ # print('=======================================================')
+
+ # args_path = osp.join(save_dir, 'args.json')
+ # with open(args_path, 'w', encoding='utf-8') as f:
+ # json.dump(args.__dict__, f, indent=2)
+
+ # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
+ # results_json_path = osp.join(save_dir, 'mmbench_result.json')
+
+ dataset = RefcocoReferringSegDataset(
+ dataset_name=args.dataset,
+ image_folder='./data/glamm_data/' + 'images/coco2014/train2014/',
+ image_processor=image_processor,
+ data_path="./data/ref_seg/",
+ tokenizer=tokenizer,
+ pad_image_to_square=True,
+ debug=False,
+ split=args.split,
+ # debug=True,
+ )
+
+ results = []
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+
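+ # Per-rank accumulators: the summed intersection/union of class 1
+ # (foreground) gives cIoU, while gIoU averages the per-sample IoU.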
+ trackers = {
+ "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM),
+ "union": AverageMeter("Union", ":6.3f", Summary.SUM),
+ "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM)
+ }
+
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ data_sample = dataset[i]
+ questions = data_sample['questions']
+ texts = []
+ for question in questions:
+ texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question)
+
+ # if data_sample['context'] is not None:
+ # text = data_sample['context'] + '\n' + data_sample[
+ # 'question'] + '\n' + data_sample['options']
+ # else:
+ # text = data_sample['question'] + '\n' + data_sample['options']
+ #
+ # text = DEFAULT_IMAGE_TOKEN + '\n' + text
+ #
+ # if is_cn_string(text):
+ # text = text + '请直接回答选项字母。'
+ # else:
+ # text = text + ("Answer with the option's letter from the "
+ # 'given choices directly.')
+ prompt_texts = []
+
+ if args.prompt_template:
+ for text in texts:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=1, bot_name=args.bot_name)
+ prompt_texts.append(prompt_text)
+ else:
+ prompt_texts = texts
+
+ batch_inputs = prompt_texts
+
+ image = data_sample['pixel_values'] # ()
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+ ori_size = data_sample['ori_size']
+ target_masks = data_sample['masks'].cuda().to(torch.uint8)
+
+ intersection, union, accuracy_iou = 0.0, 0.0, 0.0
+
+ for idx_inp, inputs in enumerate(batch_inputs):
+ # print("Question: ", inputs)
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ mm_inputs = prepare_inputs_labels_for_multimodal(
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
+
+ # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16)
+
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria,
+ output_hidden_states=True,
+ return_dict_in_generate=True
+ )
+ predict = tokenizer.decode(
+ # generate_output.sequences[0], skip_special_tokens=True).strip()
+ generate_output.sequences[0]).strip()
+ # print("Answer:", predict)
+
+ hidden_states = generate_output.hidden_states
+ last_hidden_states = [item[-1][-1] for item in hidden_states]
+ last_hidden_states = torch.cat(last_hidden_states, dim=0)
+ seg_hidden_states = get_seg_hidden_states(
+ # last_hidden_states, generate_output.sequences[0],
+ last_hidden_states, generate_output.sequences[0][:-1],
+ seg_id=model.seg_token_idx
+ )
+ # seg_hidden_states = seg_hidden_states.to(torch.float32)
+ # print("Mask num: ", len(seg_hidden_states))
+ if len(seg_hidden_states) == 0:
+ print("Warning, no [SEG] tokens !!!")
+ continue
+ elif len(seg_hidden_states) > 1:
+ print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states)))
+ seg_hidden_states = seg_hidden_states[:1]
+
+ seg_hidden_states = projector_text2vision(seg_hidden_states)
+ batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
+ dtype=torch.int64).to(seg_hidden_states.device)
+ pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
+ pred_masks = pred_masks_list[-1]
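+ # Predicted masks live in the square (padded) input frame; upsample to
+ # max(w, h) and then crop away the expand2square padding below.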
+ w, h = ori_size
+ masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)),
+ mode='bilinear', align_corners=False)
+ masks = masks[:, 0]
+ # remove padding
+ if w == h:
+ pass
+ elif w > h:
+ n_pad = w - h
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, n_pad_1: w - n_pad_2]
+ else:
+ n_pad = h - w
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, :, n_pad_1: h - n_pad_2]
+ # binary
+ masks = masks.sigmoid() > 0.5
+ masks = masks.int()
+ _target = target_masks[idx_inp:idx_inp+1].int()
+
+ # intersection, union, accuracy_iou = 0.0, 0.0, 0.0
+ for target, prediction in zip(masks, _target):
+ intersect, union_, _ = intersectionAndUnionGPU(
+ prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255
+ )
+ intersection += intersect
+ union += union_
+ accuracy_iou += intersect / (union_ + 1e-5)
+ # print(intersect / (union_ + 1e-5))
+ # handles no-object targets
+ accuracy_iou[union_ == 0] += 1.0
+
+ intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
+ accuracy_iou = accuracy_iou.cpu().numpy() / target_masks.shape[0]
+ trackers["intersection"].update(intersection)
+ trackers["union"].update(union)
+ trackers["gIoU"].update(accuracy_iou, n=target_masks.shape[0])
+
+ # for meter in trackers.values():
+ # meter.all_reduce()
+ # print(trackers["intersection"].sum, ' ', trackers["union"].sum, ' ',
+ # trackers["gIoU"].avg, ' ', trackers["gIoU"].count)
+ cur_results = {'pixel_intersection': trackers["intersection"].sum[1],
+ 'pixel_union': trackers["union"].sum[1],
+ 'gIoU': trackers["gIoU"].avg[1],
+ 'mask_counts': trackers["gIoU"].count,
+ }
+ results.append(cur_results)
+ # iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10)
+ # class_iou = iou_per_class[1]
+ # global_iou = trackers["gIoU"].avg[1]
+ #
+ # print("ciou: ", class_iou)
+ # print("giou: ", global_iou)
+
+ results = collect_results(results, n_samples)
+
+ if get_rank() == 0:
+ pixel_intersection = [cur_result['pixel_intersection'] for cur_result in results]
+ pixel_union = [cur_result['pixel_union'] for cur_result in results]
+ gIoUs = [cur_result['gIoU'] for cur_result in results]
+ mask_counts = [cur_result['mask_counts'] for cur_result in results]
+
+ class_iou = sum(pixel_intersection) / (sum(pixel_union) + 1e-10)
+ global_iou = sum([giou * n_masks for giou, n_masks in zip(gIoUs, mask_counts)]) / sum(mask_counts)
+ print("ciou: ", class_iou)
+ print("giou: ", global_iou)
+
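+# Select the hidden states of the generated [SEG] tokens: take the last
+# len(output_ids) hidden vectors and keep the positions whose token id
+# equals `seg_id`.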
+def get_seg_hidden_states(hidden_states, output_ids, seg_id):
+ seg_mask = output_ids == seg_id
+ n_out = len(seg_mask)
+ return hidden_states[-n_out:][seg_mask]
+
+
+if __name__ == '__main__':
+ main()
diff --git a/omg_llava/tools/refcoco_omg_seg_llava_msseg.py b/omg_llava/tools/refcoco_omg_seg_llava_msseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..761ed557e7a9a3236dd9391e3f242dfcf0922695
--- /dev/null
+++ b/omg_llava/tools/refcoco_omg_seg_llava_msseg.py
@@ -0,0 +1,756 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import math
+import os
+import os.path as osp
+import numpy as np
+import torch
+import tqdm
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+from xtuner.tools.utils import get_stop_criteria
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+import logging
+from mmengine import print_log
+from PIL import Image
+from pycocotools import mask
+import torch.nn.functional as F
+from omg_llava.dataset.utils import expand2square
+from omg_llava.dataset.utils.refcoco_refer import REFER
+from omg_llava.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU
+
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='RefCocoSeg')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument(
+ '--dataset',
+ choices=DATASETS_ATTRIBUTES.keys(),
+ default='refcoco',
+ help='Specify a ref dataset')
+ parser.add_argument(
+ '--split',
+ default='val',
+ help='Specify a split')
+ parser.add_argument(
+ '--visual-select-layer', default=-2, type=int, help='visual select layer')
+ parser.add_argument(
+ '--mode',
+ default='baseline',
+ choices=['baseline', 'mean', 'linear_cat'],
+ help='How the [SEG] hidden states are aggregated')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=100,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+DATASETS_ATTRIBUTES = {
+ 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'},
+ 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'},
+ 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'},
+}
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+class RefcocoReferringSegDataset(Dataset):
+ def __init__(self,
+ image_folder,
+ image_processor,
+ dataset_name,
+ data_path=None,
+ tokenizer=None,
+ offline_processed_text_folder=None,
+ pad_image_to_square=False,
+ debug=False,
+ repeats=1,
+ split='val',
+ ):
+ self.split = split
+ self._set_attribute(dataset_name)
+ self.debug = debug
+ if offline_processed_text_folder and data_path:
+ print_log(
+ 'Both `offline_processed_text_folder` and '
+ '`data_path` are set; loading the dataset from '
+ '`offline_processed_text_folder` '
+ f'({offline_processed_text_folder})',
+ logger='current',
+ level=logging.WARNING)
+
+ if offline_processed_text_folder is not None:
+ raise NotImplementedError
+ else:
+ json_datas = self.json_file_preprocess(data_path)
+ self.json_datas = json_datas
+ # json_datas = self.only_get_hf_map_infos()
+ # json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
+ # self.text_data = process_hf_dataset(
+ # dataset=json_data,
+ # tokenizer=tokenizer,
+ # max_length=max_length,
+ # dataset_map_fn=dataset_map_fn,
+ # template_map_fn=template_map_fn,
+ # split='train',
+ # max_dataset_length=max_dataset_length,
+ # remove_unused_columns=False,
+ # pack_to_max_length=False,
+ # with_image_token=True,
+ # map_num_proc=num_proc, # because limited mem
+ # )
+
+ self.image_folder = image_folder
+ size = image_processor.crop_size
+ if isinstance(size, int):
+ self.image_h, self.image_w = size, size
+ else:
+ self.image_w, self.image_h = size
+
+ if isinstance(image_processor, (dict, Config, ConfigDict)):
+ self.image_processor = BUILDER.build(image_processor)
+ else:
+ self.image_processor = image_processor
+ self.pad_image_to_square = pad_image_to_square
+ self.down_ratio = 1
+ self.repeats = repeats
+
+ def _set_attribute(self, dataset_name):
+ attr_dict = DATASETS_ATTRIBUTES[dataset_name]
+
+ self.splitBy = attr_dict['splitBy']
+ self.dataset_name = attr_dict['dataset_name']
+
+ def __len__(self):
+ return len(self.json_datas) * self.repeats
+
+ def real_len(self):
+ return len(self.json_datas)
+
+ def json_file_preprocess(self, data_path):
+ splitBy = self.splitBy
+ dataset_name = self.dataset_name
+ refer_api = REFER(data_path, dataset_name, splitBy)
+ ref_ids_train = refer_api.getRefIds(split=self.split)
+ images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
+ refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
+ self.img2refs = self.create_img_to_refs_mapping(refs_train)
+
+ image_infos = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
+ for item in loaded_images:
+ item = item.copy()
+ image_infos.append(item)
+
+ self.annotations = refer_api.Anns
+ # self.img2refs = self.create_img_to_refs_mapping(refs_train)
+
+ refs = [self.img2refs[image_info['id']] for image_info in image_infos]
+
+ ret = []
+ for image_info, ref in zip(image_infos, refs):
+ if len(ref) == 0:
+ continue
+
+ sents = []
+ ann_ids = []
+ for _ref in ref:
+ for sent in _ref["sentences"]:
+ text = sent["sent"]
+ sents.append(text)
+ ann_ids.append(_ref["ann_id"])
+
+ # if len(sents) >= 3:
+ # sampled_inds = np.random.choice(
+ # list(range(len(sents))), 3, replace=False
+ # )
+ # else:
+ # sampled_inds = list(range(len(sents)))
+
+ sampled_inds = list(range(len(sents)))
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
+ # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
+ sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
+ selected_labels = sampled_sents
+ ret.append(
+ {'image_info': image_info,
+ 'sampled_ann_id': sampled_ann_ids,
+ 'selected_labels': selected_labels,
+ 'image': image_info['file_name']
+ }
+ )
+ if self.debug:
+ return ret[:10]
+ return ret
+
+ def load_images(self, refer_api, images_ids_train, dataset_dir, dataset_name, inference=False):
+ images = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
+ for item in loaded_images:
+ item = item.copy()
+ images.append(item)
+ return images
+
+ def create_img_to_refs_mapping(self, refs_train):
+ img2refs = {}
+ for ref in refs_train:
+ img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
+ return img2refs
+
+ def decode_mask(self, annotations_ids, image_info):
+ flag = False
+ masks = []
+
+ for ann_id in annotations_ids:
+ if isinstance(ann_id, list):
+ flag = True
+ if -1 in ann_id:
+ assert len(ann_id) == 1
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
+ np.uint8
+ )
+ else:
+ m_final = np.zeros(
+ (image_info["height"], image_info["width"])
+ ).astype(np.uint8)
+ for ann_id_i in ann_id:
+ ann = self.annotations[ann_id_i]
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros(
+ (image_info["height"], image_info["width"])
+ ).astype(np.uint8)
+ else:
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image_info["height"], image_info["width"], )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(
+ m, axis=2
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ m_final = m_final | m
+ m = m_final
+ masks.append(m)
+ continue
+
+ ann = self.annotations[ann_id]
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
+ np.uint8
+ )
+ masks.append(m)
+ continue
+
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image_info["height"], image_info["width"]
+ )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ masks.append(m)
+ masks = np.stack(masks, axis=0)
+
+ # if self.pad_image_to_square:
+ # masks = expand2square_mask(masks)
+ masks = torch.from_numpy(masks)
+
+ # masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
+ # self.image_w // self.down_ratio), mode='nearest').squeeze(0)
+
+ # print(image_info['file_name'])
+ # print(masks.shape)
+ # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1)
+ # save_masks = save_masks.numpy() * 255
+ # save_masks = Image.fromarray(save_masks.astype(np.uint8))
+ # save_masks.save("/root/mask.png")
+ # print(kkk)
+ return masks
+
+ def only_get_text_infos(self, json_data):
+ return {'sampled_sents': json_data['selected_labels']}
+
+ def get_questions(self, text_require_infos):
+ sampled_sents = text_require_infos['sampled_sents']
+ ret = []
+ for sent in sampled_sents:
+ ret.append("Please segment {} in this image.".format(sent))
+ return ret
+
+ def filter_data_dict(self, data_dict):
+ names = ['pixel_values', 'masks', 'ori_size', 'questions']
+ ret = {name: data_dict[name] for name in names}
+ return ret
+
+ def __getitem__(self, index):
+ index = index % self.real_len()
+ data_dict = self.json_datas[index]
+ text_require_infos = self.only_get_text_infos(data_dict)
+ questions = self.get_questions(text_require_infos)
+
+ assert data_dict.get('image', None) is not None
+ if data_dict.get('image', None) is not None:
+ image_file = data_dict['image']
+ image_file = os.path.join(self.image_folder, image_file)
+ # print(image_file)
+ image = Image.open(image_file).convert('RGB')
+ ori_width, ori_height = image.size
+ if self.pad_image_to_square:
+ image = expand2square(
+ image,
+ tuple(
+ int(x * 255) for x in self.image_processor.image_mean))
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ data_dict['pixel_values'] = image
+
+ # process and get masks
+ masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info'])
+ data_dict['masks'] = masks
+ data_dict['ori_size'] = (ori_width, ori_height)
+ data_dict['questions'] = questions
+ else:
+ if hasattr(self.image_processor, 'crop_size'):
+ crop_size = self.image_processor.crop_size
+ else:
+ crop_size = self.image_processor.size
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
+ crop_size['width'])
+ data_dict['masks'] = None
+ # pixel_values, binary masks, conversation/input ids
+ return self.filter_data_dict(data_dict)
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name or 'OMG' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+
+ if os.path.exists(cfg.pretrained_pth):
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+ else:
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Loaded pretrained PTH weights from {cfg.pretrained_pth}')
+
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ # print(state_dict.keys())
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ # # work_dir
+ # if args.work_dir is not None:
+ # # update configs according to CLI args if args.work_dir is not None
+ # save_dir = args.work_dir
+ # else:
+ # # use config filename as default work_dir
+ # save_dir = osp.join('./work_dirs',
+ # osp.splitext(osp.basename(args.data_path))[0])
+ # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
+ # save_dir = osp.join(save_dir, timestamp)
+
+ # if rank == 0:
+ # mkdir_or_exist(osp.abspath(save_dir))
+ # print('=======================================================')
+ # print(f'Dataset path: {osp.abspath(args.data_path)}\n'
+ # f'Results will be saved to {osp.abspath(save_dir)}')
+ # print('=======================================================')
+
+ # args_path = osp.join(save_dir, 'args.json')
+ # with open(args_path, 'w', encoding='utf-8') as f:
+ # json.dump(args.__dict__, f, indent=2)
+
+ # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
+ # results_json_path = osp.join(save_dir, 'mmbench_result.json')
+
+ dataset = RefcocoReferringSegDataset(
+ dataset_name=args.dataset,
+ image_folder='./data/glamm_data/' + 'images/coco2014/train2014/',
+ image_processor=image_processor,
+ data_path="./data/ref_seg/",
+ tokenizer=tokenizer,
+ pad_image_to_square=True,
+ debug=False,
+ split=args.split,
+ # debug=True,
+ )
+
+ results = []
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+
+ trackers = {
+ "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM),
+ "union": AverageMeter("Union", ":6.3f", Summary.SUM),
+ "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM)
+ }
+
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ data_sample = dataset[i]
+ questions = data_sample['questions']
+ texts = []
+ for question in questions:
+ texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question)
+
+ # if data_sample['context'] is not None:
+ # text = data_sample['context'] + '\n' + data_sample[
+ # 'question'] + '\n' + data_sample['options']
+ # else:
+ # text = data_sample['question'] + '\n' + data_sample['options']
+ #
+ # text = DEFAULT_IMAGE_TOKEN + '\n' + text
+ #
+ # if is_cn_string(text):
+ # text = text + '请直接回答选项字母。'
+ # else:
+ # text = text + ("Answer with the option's letter from the "
+ # 'given choices directly.')
+ prompt_texts = []
+
+ if args.prompt_template:
+ for text in texts:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=1, bot_name=args.bot_name)
+ prompt_texts.append(prompt_text)
+ else:
+ prompt_texts = texts
+
+ batch_inputs = prompt_texts
+
+ image = data_sample['pixel_values'] # ()
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+ ori_size = data_sample['ori_size']
+ target_masks = data_sample['masks'].cuda().to(torch.uint8)
+
+ intersection, union, accuracy_iou = 0.0, 0.0, 0.0
+
+ for idx_inp, inputs in enumerate(batch_inputs):
+ # print("Question: ", inputs)
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ mm_inputs = prepare_inputs_labels_for_multimodal(
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
+
+ # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16)
+
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria,
+ output_hidden_states=True,
+ return_dict_in_generate=True
+ )
+ predict = tokenizer.decode(
+ # generate_output.sequences[0], skip_special_tokens=True).strip()
+ generate_output.sequences[0]).strip()
+ print("Answer:", predict)
+
+ hidden_states = generate_output.hidden_states
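+ # 'baseline' uses only the final-layer hidden state of each generated
+ # token; the other modes stack all layers per token and aggregate them
+ # in get_seg_hidden_states_multistates.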
+ if args.mode == 'baseline':
+ last_hidden_states = [item[-1][-1] for item in hidden_states]
+ last_hidden_states = torch.cat(last_hidden_states, dim=0)
+ seg_hidden_states = get_seg_hidden_states(
+ # last_hidden_states, generate_output.sequences[0],
+ last_hidden_states, generate_output.sequences[0][:-1],
+ seg_id=model.seg_token_idx
+ )
+ else:
+ hidden_states = [torch.cat(item, dim=0)[:, -1] for item in hidden_states]
+ last_hidden_states = torch.stack(hidden_states, dim=0) # (N, n_layer, c)
+ seg_hidden_states = get_seg_hidden_states_multistates(
+ last_hidden_states, generate_output.sequences[0][:-1],
+ seg_id=model.seg_token_idx, mode=args.mode,
+ model=model,
+ )
+ # seg_hidden_states = seg_hidden_states.to(torch.float32)
+ # print("Mask num: ", len(seg_hidden_states))
+ if len(seg_hidden_states) == 0:
+ print("Warning, no [SEG] tokens !!!")
+ continue
+ elif len(seg_hidden_states) > 1:
+ print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states)))
+ seg_hidden_states = seg_hidden_states[:1]
+
+ seg_hidden_states = projector_text2vision(seg_hidden_states)
+ batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
+ dtype=torch.int64).to(seg_hidden_states.device)
+ pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
+ pred_masks = pred_masks_list[-1]
+ w, h = ori_size
+ masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)),
+ mode='bilinear', align_corners=False)
+ masks = masks[:, 0]
+ # remove padding
+ if w == h:
+ pass
+ elif w > h:
+ n_pad = w - h
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, n_pad_1: w - n_pad_2]
+ else:
+ n_pad = h - w
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, :, n_pad_1: h - n_pad_2]
+ # binary
+ masks = masks.sigmoid() > 0.5
+ masks = masks.int()
+ _target = target_masks[idx_inp:idx_inp+1].int()
+
+ # intersection, union, accuracy_iou = 0.0, 0.0, 0.0
+ for target, prediction in zip(masks, _target):
+ intersect, union_, _ = intersectionAndUnionGPU(
+ prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255
+ )
+ intersection += intersect
+ union += union_
+ accuracy_iou += intersect / (union_ + 1e-5)
+ # print(intersect / (union_ + 1e-5))
+ # handles no-object targets
+ accuracy_iou[union_ == 0] += 1.0
+
+ intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
+ accuracy_iou = accuracy_iou.cpu().numpy() / target_masks.shape[0]
+ trackers["intersection"].update(intersection)
+ trackers["union"].update(union)
+ trackers["gIoU"].update(accuracy_iou, n=target_masks.shape[0])
+
+ # for meter in trackers.values():
+ # meter.all_reduce()
+ # print(trackers["intersection"].sum, ' ', trackers["union"].sum, ' ',
+ # trackers["gIoU"].avg, ' ', trackers["gIoU"].count)
+ cur_results = {'pixel_intersection': trackers["intersection"].sum[1],
+ 'pixel_union': trackers["union"].sum[1],
+ 'gIoU': trackers["gIoU"].avg[1],
+ 'mask_counts': trackers["gIoU"].count,
+ }
+ results.append(cur_results)
+ # iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10)
+ # class_iou = iou_per_class[1]
+ # global_iou = trackers["gIoU"].avg[1]
+ #
+ # print("ciou: ", class_iou)
+ # print("giou: ", global_iou)
+
+ results = collect_results(results, n_samples)
+
+ if get_rank() == 0:
+ pixel_intersection = [cur_result['pixel_intersection'] for cur_result in results]
+ pixel_union = [cur_result['pixel_union'] for cur_result in results]
+ gIoUs = [cur_result['gIoU'] for cur_result in results]
+ mask_counts = [cur_result['mask_counts'] for cur_result in results]
+
+ class_iou = sum(pixel_intersection) / (sum(pixel_union) + 1e-10)
+ global_iou = sum([giou * n_masks for giou, n_masks in zip(gIoUs, mask_counts)]) / sum(mask_counts)
+ print("ciou: ", class_iou)
+ print("giou: ", global_iou)
+
+def get_seg_hidden_states(hidden_states, output_ids, seg_id):
+ seg_mask = output_ids == seg_id
+ n_out = len(seg_mask)
+ return hidden_states[-n_out:][seg_mask]
+
+
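+# Multi-layer variant: `hidden_states` has shape (num_tokens, n_layers, c).
+# 'mean' averages the layers of each [SEG] token; 'linear_cat' keeps the last
+# `model.selected_layers` layers, projects them, flattens, and projects again.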
+def get_seg_hidden_states_multistates(hidden_states, output_ids, seg_id, mode, model):
+ # (N, n_layer, c)
+ seg_mask = output_ids == seg_id
+ n_out = len(seg_mask)
+ hidden_states = hidden_states[-n_out:][seg_mask] # (N, n_layers, c)
+ if mode == 'mean':
+ hidden_states = torch.mean(hidden_states, dim=1)
+ elif mode == 'linear_cat':
+ hidden_states = hidden_states[:, -model.selected_layers:]
+ hidden_states = model.seg_token_proj_linear_cat[0](hidden_states)
+ hidden_states = hidden_states.flatten(1)
+ hidden_states = model.seg_token_proj_linear_cat[1](hidden_states)
+ else:
+ raise NotImplementedError
+ return hidden_states
+
+if __name__ == '__main__':
+
+ main()
diff --git a/omg_llava/tools/refvos_omg_seg_llava.py b/omg_llava/tools/refvos_omg_seg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3123a2c1aec3d30e6ccc65ad149280e4a0bbef3
--- /dev/null
+++ b/omg_llava/tools/refvos_omg_seg_llava.py
@@ -0,0 +1,527 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import math
+import os
+import os.path as osp
+import re
+import string
+import time
+
+import numpy as np
+import pandas as pd
+import torch
+import tqdm
+from huggingface_hub import snapshot_download
+from mmengine import mkdir_or_exist
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from peft import PeftModel
+from rich.console import Console
+from rich.table import Table
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.dataset.utils import decode_base64_to_image, expand2square
+from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
+from xtuner.tools.utils import get_stop_criteria, is_cn_string
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+from xtuner.model.git import GitPerceptionEncoder, GitPerceptionEncoder_Clip
+from importlib import import_module
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+import logging
+from datasets import Dataset as HFDataset
+from datasets import DatasetDict, load_from_disk
+from mmengine import print_log
+from PIL import Image
+from pycocotools import mask
+import torch.nn.functional as F
+from xtuner.dataset.utils import expand2square, expand2square_mask
+from xtuner.dataset.huggingface import process_hf_dataset
+from xtuner.dataset.GLAMM_dataset.utils.refcoco_refer import REFER
+from xtuner.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU
+
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='RefCocoSeg')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument(
+ '--dataset',
+ choices=DATASETS_ATTRIBUTES.keys(),
+ default='refcoco',
+ help='Specify a ref dataset')
+ parser.add_argument(
+ '--split',
+ default='val',
+ help='Specify a split')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=100,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+DATASETS_ATTRIBUTES = {
+ 'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'},
+ 'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'},
+ 'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'},
+}
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+class RefVOSDataset(Dataset):
+ def __init__(self,
+ image_folder,
+ image_processor,
+ expressions_file=None,
+ tokenizer=None,
+ offline_processed_text_folder=None,
+ pad_image_to_square=False,
+ debug=False,
+ repeats=1,
+ ):
+ self.debug = debug
+
+ if offline_processed_text_folder is not None:
+ raise NotImplementedError
+ else:
+ json_datas = self.json_file_preprocess(expressions_file, image_folder)
+ self.json_datas = json_datas
+
+ self.image_folder = image_folder
+ size = image_processor.crop_size
+ if isinstance(size, int):
+ self.image_h, self.image_w = size, size
+ else:
+ self.image_w, self.image_h = size
+
+ if isinstance(image_processor, dict) or isinstance(
+ image_processor, Config) or isinstance(image_processor,
+ ConfigDict):
+ self.image_processor = BUILDER.build(image_processor)
+ else:
+ self.image_processor = image_processor
+ self.pad_image_to_square = pad_image_to_square
+ self.down_ratio = 1
+ self.repeats = repeats
+
+ def __len__(self):
+ return len(self.json_datas) * self.repeats
+
+ def real_len(self):
+ return len(self.json_datas)
+
+ def json_file_preprocess(self, expression_json_file, image_folder):
+
+ video_files = os.listdir(image_folder)
+ with open(expression_json_file, 'r') as f:
+ expression_data = json.load(f)
+ expression_data = expression_data["videos"]
+ expression_data_ = {}
+ for video_file in video_files:
+ expression_data_.update({video_file: expression_data[video_file]})
+
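+ # Build one item per video that bundles its frame names, expression ids, and expression texts.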
+ ret_items = []
+ for video_file in expression_data_.keys():
+ video_expression_data = expression_data_[video_file]
+ require_infos = {}
+ require_infos['video_file'] = video_file
+ require_infos['frames'] = video_expression_data["frames"]
+ require_infos['expressions_id'] = list(video_expression_data["expressions"].keys())
+ require_infos['expressions'] = [video_expression_data["expressions"][id]["exp"] for id in require_infos['expressions_id']]
+ ret_items.append(require_infos)
+
+ return ret_items
+
+ def get_questions(self, expressions):
+ ret = []
+ for sent in expressions:
+ ret.append("Please segment {} in this image.".format(sent))
+ return ret
+
+ def __getitem__(self, index):
+ index = index % self.real_len()
+ data_dict = self.json_datas[index]
+ questions = self.get_questions(data_dict['expressions'])
+
+ data_dict['image'] = data_dict['frames'][0] + '.jpg'
+ assert data_dict.get('image', None) is not None
+ if data_dict.get('image', None) is not None:
+ image_file = data_dict['image']
+ image_file = os.path.join(self.image_folder, data_dict['video_file'], image_file)
+ # print(image_file)
+ image = Image.open(image_file).convert('RGB')
+ ori_width, ori_height = image.size
+ if self.pad_image_to_square:
+ image = expand2square(
+ image,
+ tuple(
+ int(x * 255) for x in self.image_processor.image_mean))
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ data_dict['pixel_values'] = image
+
+ data_dict['ori_size'] = (ori_width, ori_height)
+ data_dict['questions'] = questions
+ else:
+ if hasattr(self.image_processor, 'crop_size'):
+ crop_size = self.image_processor.crop_size
+ else:
+ crop_size = self.image_processor.size
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
+ crop_size['width'])
+ data_dict['masks'] = None
+ # pixel_values, binary masks, conversation/input ids
+ return data_dict
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+
+ if os.path.exists(cfg.pretrained_pth):
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+ else:
+ state_dict = guess_load_checkpoint(cfg.pretrained_pth)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load pre PTH model from {cfg.pretrained_pth}')
+
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ # print(state_dict.keys())
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ # # work_dir
+ # if args.work_dir is not None:
+ # # update configs according to CLI args if args.work_dir is not None
+ # save_dir = args.work_dir
+ # else:
+ # # use config filename as default work_dir
+ # save_dir = osp.join('./work_dirs',
+ # osp.splitext(osp.basename(args.data_path))[0])
+ # timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
+ # save_dir = osp.join(save_dir, timestamp)
+
+ # if rank == 0:
+ # mkdir_or_exist(osp.abspath(save_dir))
+ # print('=======================================================')
+ # print(f'Dataset path: {osp.abspath(args.data_path)}\n'
+ # f'Results will be saved to {osp.abspath(save_dir)}')
+ # print('=======================================================')
+
+ # args_path = osp.join(save_dir, 'args.json')
+ # with open(args_path, 'w', encoding='utf-8') as f:
+ # json.dump(args.__dict__, f, indent=2)
+
+ # results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
+ # results_json_path = osp.join(save_dir, 'mmbench_result.json')
+
+ dataset = RefVOSDataset(
+ image_folder='./data/rvos/valid/JPEGImages/',
+ image_processor=image_processor,
+ expressions_file="./data/rvos/meta_expressions/valid/meta_expressions.json",
+ tokenizer=tokenizer,
+ pad_image_to_square=True,
+ debug=False,
+ # debug=True,
+ )
+
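+ # Shard the videos across ranks; each rank processes a contiguous slice of the dataset.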
+ results = []
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+
+ trackers = {
+ "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM),
+ "union": AverageMeter("Union", ":6.3f", Summary.SUM),
+ "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM)
+ }
+
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ data_sample = dataset[i]
+ questions = data_sample['questions']
+ texts = []
+ for question in questions:
+ texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question)
+
+ # if data_sample['context'] is not None:
+ # text = data_sample['context'] + '\n' + data_sample[
+ # 'question'] + '\n' + data_sample['options']
+ # else:
+ # text = data_sample['question'] + '\n' + data_sample['options']
+ #
+ # text = DEFAULT_IMAGE_TOKEN + '\n' + text
+ #
+ # if is_cn_string(text):
+ # text = text + '请直接回答选项字母。'
+ # else:
+ # text = text + ("Answer with the option's letter from the "
+ # 'given choices directly.')
+ prompt_texts = []
+
+ if args.prompt_template:
+ for text in texts:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=1, bot_name=args.bot_name)
+ prompt_texts.append(prompt_text)
+ else:
+ prompt_texts = texts
+
+ batch_inputs = prompt_texts
+
+ image = data_sample['pixel_values'] # ()
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+ ori_size = data_sample['ori_size']
+ # RefVOSDataset provides no ground-truth masks at inference time (masks is None), so guard the cast.
+ target_masks = data_sample['masks']
+ if target_masks is not None:
+ target_masks = target_masks.cuda().to(torch.uint8)
+
+ intersection, union, accuracy_iou = 0.0, 0.0, 0.0
+
+ for idx_inp, inputs in enumerate(batch_inputs):
+ # print("Question: ", inputs)
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ mm_inputs = prepare_inputs_labels_for_multimodal(
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria,
+ output_hidden_states=True,
+ return_dict_in_generate=True
+ )
+ predict = tokenizer.decode(
+ # generate_output.sequences[0], skip_special_tokens=True).strip()
+ generate_output.sequences[0]).strip()
+ # print("Answer:", predict)
+
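+ # Concatenate the last-layer hidden state of every generated token so the [SEG] embeddings can be looked up by position.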
+ hidden_states = generate_output.hidden_states
+ last_hidden_states = [item[-1][0] for item in hidden_states]
+ last_hidden_states = torch.cat(last_hidden_states, dim=0)
+ seg_hidden_states = get_seg_hidden_states(
+ last_hidden_states, generate_output.sequences[0],
+ seg_id=model.seg_token_idx
+ )
+ # print("Mask num: ", len(seg_hidden_states))
+ if len(seg_hidden_states) == 0:
+ print("Warning, no [SEG] tokens !!!")
+ continue
+ elif len(seg_hidden_states) > 1:
+ print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states)))
+ seg_hidden_states = seg_hidden_states[:1]
+
+ seg_hidden_states = projector_text2vision(seg_hidden_states)
+ batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
+ dtype=torch.int64).to(seg_hidden_states.device)
+ pred_masks_list = model.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
+ pred_masks = pred_masks_list[-1]
+ w, h = ori_size
+ masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)),
+ mode='bilinear', align_corners=False)
+ masks = masks[:, 0]
+ # remove padding
+ if w == h:
+ pass
+ elif w > h:
+ n_pad = w - h
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, n_pad_1: w - n_pad_2]
+ else:
+ n_pad = h - w
+ n_pad_1 = n_pad // 2
+ n_pad_2 = n_pad - n_pad_1
+ masks = masks[:, :, n_pad_1: h - n_pad_2]
+ # binary
+ masks = masks.sigmoid() > 0.5
+ masks = masks.int() * 255
+
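+ # Save the predicted mask as a 0/255 RGB PNG under ./work_dirs/rvos/<video>/<expression_id>/.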
+ video_name = data_sample['video_file']
+ expression_id = data_sample['expressions_id'][idx_inp]
+ os.makedirs(os.path.join('./work_dirs/rvos/', video_name, expression_id), exist_ok=True)
+
+ masks = torch.cat([masks, masks, masks], dim=0).permute(1, 2, 0)
+ masks = masks.cpu().numpy().astype(np.uint8)  # PIL expects uint8 for RGB images
+ image = Image.fromarray(masks)
+ image.save(os.path.join('./work_dirs/rvos/', video_name, expression_id, data_sample['frames'][0] + '.png'))
+
+def get_seg_hidden_states(hidden_states, output_ids, seg_id):
+ seg_mask = output_ids == seg_id
+ n_out = len(seg_mask)
+ return hidden_states[-n_out:][seg_mask]
+
+
+if __name__ == '__main__':
+
+ main()
diff --git a/omg_llava/tools/region_cap_mask_omg_seg_llava.py b/omg_llava/tools/region_cap_mask_omg_seg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b3b43f73912001aad2a3d045b44f3aaaec223d
--- /dev/null
+++ b/omg_llava/tools/region_cap_mask_omg_seg_llava.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import math
+import os
+import os.path as osp
+import re
+import torch
+import tqdm
+
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.model.utils import LoadWoInit
+from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts
+from xtuner.tools.utils import get_stop_criteria, is_cn_string
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+from PIL import Image
+import torch.nn.functional as F
+from omg_llava.dataset.utils import expand2square, expand2square_mask
+from pycocotools import mask
+
+from pycocotools.coco import COCO
+import numpy as np
+
+def bbox_to_x1y1x2y2(bbox):
+ x1, y1, w, h = bbox
+ bbox = [x1, y1, x1 + w, y1 + h]
+
+ return bbox
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='RefCocoSeg')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument(
+ '--output-path', type=str, default='./work_dirs/region_cap_pred.json', help='Path to save the region caption predictions')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=300,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+class RegionCap_Inference_Dataset(Dataset):
+ def __init__(self,
+ image_folder,
+ image_processor,
+ pad_image_to_square=True,
+ annotation_file=None,
+ debug=False,
+ ):
+ self.debug = debug
+ self.image_folder = image_folder
+ size = image_processor.crop_size
+ # if isinstance(size, int):
+ # self.image_h, self.image_w = size, size
+ # else:
+ # self.image_w, self.image_h = size
+ self.image_h, self.image_w = 1024, 1024
+
+ if isinstance(image_processor, dict) or isinstance(
+ image_processor, Config) or isinstance(image_processor,
+ ConfigDict):
+ self.image_processor = BUILDER.build(image_processor)
+ else:
+ self.image_processor = image_processor
+ self.pad_image_to_square = pad_image_to_square
+ self.down_ratio = 1
+
+ self.coco = COCO(annotation_file)
+ self.image_dict = self.coco.imgs
+ self.ann_dict = self.coco.anns
+ self.image_dict_keys = list(self.image_dict.keys())
+
+ def __len__(self):
+ return len(self.image_dict_keys)
+
+ def decode_mask(self, annotation, image_info):
+ flag = False
+ masks = []
+
+ for ann_id in range(1):
+ # if isinstance(ann_id, list):
+ # flag = True
+ # if -1 in ann_id:
+ # assert len(ann_id) == 1
+ # m = np.zeros((image_info["height"], image_info["width"])).astype(
+ # np.uint8
+ # )
+ # else:
+ # m_final = np.zeros(
+ # (image_info["height"], image_info["width"])
+ # ).astype(np.uint8)
+ # for ann_id_i in ann_id:
+ # ann = self.annotations[ann_id_i]
+ #
+ # if len(ann["segmentation"]) == 0:
+ # m = np.zeros(
+ # (image_info["height"], image_info["width"])
+ # ).astype(np.uint8)
+ # else:
+ # if type(ann["segmentation"][0]) == list: # polygon
+ # rle = mask.frPyObjects(
+ # ann["segmentation"], image_info["height"], image_info["width"], )
+ # else:
+ # rle = ann["segmentation"]
+ # for i in range(len(rle)):
+ # if not isinstance(rle[i]["counts"], bytes):
+ # rle[i]["counts"] = rle[i]["counts"].encode()
+ # m = mask.decode(rle)
+ # m = np.sum(
+ # m, axis=2
+ # ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ # m = m.astype(np.uint8) # convert to np.uint8
+ # m_final = m_final | m
+ # m = m_final
+ # masks.append(m)
+ # continue
+
+ ann = {"segmentation": annotation}
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
+ np.uint8
+ )
+ masks.append(m)
+ continue
+
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image_info["height"], image_info["width"]
+ )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ masks.append(m)
+ masks = np.stack(masks, axis=0)
+
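+ # Pad the mask to a square and resize it to the 1024x1024 grid used as the visual encoder input.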
+ if self.pad_image_to_square:
+ masks = expand2square_mask(masks)
+ masks = torch.from_numpy(masks)
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
+
+ # print(image_info['file_name'])
+ # print(masks.shape)
+ # save_masks = torch.stack([masks[0], masks[0], masks[0]], dim=-1)
+ # save_masks = save_masks.numpy() * 255
+ # save_masks = Image.fromarray(save_masks.astype(np.uint8))
+ # save_masks.save("/root/mask.png")
+ # print(kkk)
+ return masks
+
+ def get_questions(self):
+ question = "Can you provide me with a detailed description of the region in the picture marked by region1 ?"
+ return question
+
+ def __getitem__(self, index):
+
+ data_dict = {}
+
+ image_id = self.image_dict_keys[index]
+ image_file = self.image_dict[image_id]['file_name']
+
+ questions = self.get_questions()
+
+ data_dict['image_file'] = image_file
+ image_file = os.path.join(self.image_folder, image_file)
+ image = Image.open(image_file).convert('RGB')
+ ori_width, ori_height = image.size
+ if self.pad_image_to_square:
+ image = expand2square(
+ image,
+ tuple(
+ int(x * 255) for x in self.image_processor.image_mean))
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ data_dict['pixel_values'] = image
+ data_dict['ori_size'] = (ori_width, ori_height)
+ data_dict['questions'] = questions
+
+ masks = self.ann_dict[image_id]['segmentation']
+ image_info = self.image_dict[image_id]
+ masks = self.decode_mask(masks, image_info)
+
+ data_dict['regions'] = masks
+ data_dict['image_id'] = image_id
+
+ return data_dict
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name or 'OMG' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ dataset = RegionCap_Inference_Dataset(
+ annotation_file='./data/region_caption/refcocog/finetune_refcocog_val_with_mask.json',
+ image_folder='./data/glamm_data/images/coco2014/train2014/',
+ image_processor=image_processor,
+ pad_image_to_square=True,
+ debug=False,
+ # debug=True,
+ )
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+ results = []
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ # pixel feature
+ data_sample = dataset[i]
+ image = data_sample['pixel_values'] # ()
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+
+ questions = data_sample['questions']
+ regions = data_sample['regions']
+ texts = DEFAULT_IMAGE_TOKEN + '\n' + questions
+
+ if args.prompt_template:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=texts, round=1, bot_name=args.bot_name)
+ else:
+ prompt_text = texts
+
+ batch_inputs = prompt_text
+
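+ # Generate the region caption; the region mask is injected into the prompt through mark-token embeddings inside forward_model.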
+ predict = forward_model(
+ batch_inputs, pixel_values,
+ tokenizer, model, llm,
+ projector_text2vision,
+ gen_config, stop_criteria, points=regions,
+ mark_token_id=model.mark_token_idx,
+ width=image.shape[-1], height=image.shape[-2],
+ visual_encoder=visual_encoder, projector=projector
+ )
+
+ # Clean up the decoded prediction. The '<s>' and double-space literals below are assumptions,
+ # since the original replace() arguments were garbled in the diff.
+ text_output = predict.replace("<s>", "").replace("\n", "")\
+ .replace("region1", '').replace("Region1", '')\
+ .replace(':', '').replace("  ", " ")
+ text_output = text_output.split("ASSISTANT: ")[-1]
+
+ cleaned_str = re.sub(r'<.*?>', '', text_output)
+
+ # Remove the [SEG] token
+ cleaned_str = cleaned_str.replace('[SEG]', '')
+
+ # only select 1 setence for eval
+ # cleaned_str = cleaned_str.split('.')[0]
+
+ # Strip unnecessary spaces
+ cleaned_str = ' '.join(cleaned_str.split()).strip("'")
+ cleaned_str = cleaned_str.strip()
+
+ result_dict = {}
+ result_dict["image_id"] = data_sample['image_id']
+ result_dict["caption"] = cleaned_str
+ result_dict["image_file"] = data_sample['image_file']
+ result_dict["prediction"] = cleaned_str
+ results.append(result_dict)
+ print(cleaned_str)
+
+ results = collect_results(results, n_samples)
+
+ if get_rank() == 0:
+ with open(args.output_path, 'w') as json_file:
+ json.dump(results, json_file, indent=2)
+
+def forward_model(question, pixel_values,
+ tokenizer, model, llm,
+ projector_text2vision,
+ gen_config, stop_criteria,
+ mark_token_id=None,
+ points=None, width=None, height=None,
+ visual_encoder=None, projector=None):
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+ inputs = question
+ # print("Question: ", inputs)
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ points = points.cuda()
+
+ points_mark_embedding = get_points_embeddings(
+ points, ids, width, height,
+ mark_token_id, visual_encoder,
+ projector)
+
+ mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
+ llm=llm, input_ids=ids, pixel_values=pixel_values,
+ mark_id=mark_token_id,
+ mark_feats=points_mark_embedding, region_id=-9999)
+
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria,
+ output_hidden_states=True,
+ return_dict_in_generate=True
+ )
+ predict = tokenizer.decode(
+ # generate_output.sequences[0], skip_special_tokens=True).strip()
+ generate_output.sequences[0], skip_special_tokens=True).strip()
+ return predict
+
+
+def get_points_embeddings(points, input_ids, width, height,
+ mark_token_idx, visual_encoder,
+ projector):
+ if points is None or len(points) == 0:
+ return []
+
+ mark_token_mask = input_ids == mark_token_idx
+ batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
+ input_ids.device)
+ batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number
+
+ # points = points.to(torch.float32)
+ # print(points.dtype, batch_idxs.dtype)
+ # marks_embeddings = visual_encoder.forward_point_sam(
+ # points, batch_idxs, width=width, height=height
+ # )[:, 0] # (N, C)
+
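+ # Encode each mask prompt with the region branch of the visual encoder; keep the first query as the region embedding.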
+ marks_embeddings = visual_encoder.forward_region_sam(
+ points, batch_idxs
+ )[:, 0] # (N, C)
+
+ marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
+ marks_embeddings = projector.model.query_proj(marks_embeddings)
+ marks_embeddings = projector.model.model(marks_embeddings)
+ return marks_embeddings # (N, C)
+
+if __name__ == '__main__':
+
+ main()
diff --git a/omg_llava/tools/region_cap_omg_seg_llava.py b/omg_llava/tools/region_cap_omg_seg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..51315e43b99b53d7191f1f5072992f001907da48
--- /dev/null
+++ b/omg_llava/tools/region_cap_omg_seg_llava.py
@@ -0,0 +1,426 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import math
+import os
+import os.path as osp
+import re
+import torch
+import tqdm
+
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+ master_only)
+from mmengine.utils.dl_utils import set_multi_processing
+from torch.utils.data import Dataset
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+
+from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal, prepare_inputs_labels_for_multimodal_with_region
+from xtuner.tools.utils import get_stop_criteria, is_cn_string
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE)
+
+from xtuner.registry import BUILDER
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from mmengine.config import Config
+from mmengine.fileio import PetrelBackend, get_file_backend
+from mmengine.config import ConfigDict
+
+from PIL import Image
+import torch.nn.functional as F
+from xtuner.dataset.utils import expand2square, expand2square_points
+from pycocotools import mask as mask_utils
+
+from pycocotools.coco import COCO
+import numpy as np
+
+def bbox_to_x1y1x2y2(bbox):
+ x1, y1, w, h = bbox
+ bbox = [x1, y1, x1 + w, y1 + h]
+
+ return bbox
+
+def convert_dict2config_dict(input):
+ input = ConfigDict(**input)
+ for key in input.keys():
+ if isinstance(input[key], dict):
+ input[key] = convert_dict2config_dict(input[key])
+ return input
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='RefCocoSeg')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default='internlm2_chat',
+ help='Specify a prompt template')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=100,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+
+@master_only
+def master_print(msg):
+ print(msg)
+
+class RegionCap_Inference_Dataset(Dataset):
+ def __init__(self,
+ image_folder,
+ image_processor,
+ pad_image_to_square=True,
+ annotation_file=None,
+ debug=False,
+ ):
+ self.debug = debug
+ self.image_folder = image_folder
+ size = image_processor.crop_size
+ # if isinstance(size, int):
+ # self.image_h, self.image_w = size, size
+ # else:
+ # self.image_w, self.image_h = size
+ self.image_h, self.image_w = 1024, 1024
+
+ if isinstance(image_processor, dict) or isinstance(
+ image_processor, Config) or isinstance(image_processor,
+ ConfigDict):
+ self.image_processor = BUILDER.build(image_processor)
+ else:
+ self.image_processor = image_processor
+ self.pad_image_to_square = pad_image_to_square
+ self.down_ratio = 1
+
+ self.coco = COCO(annotation_file)
+ self.image_dict = self.coco.imgs
+ self.ann_dict = self.coco.anns
+ self.image_dict_keys = list(self.image_dict.keys())
+
+ def __len__(self):
+ return len(self.image_dict_keys)
+
+ def get_questions(self):
+ question = "Can you provide me with a detailed description of the region in the picture marked by region1 ?"
+ return question
+
+ def __getitem__(self, index):
+
+ data_dict = {}
+
+ image_id = self.image_dict_keys[index]
+ image_file = self.image_dict[image_id]['file_name']
+ gt = self.ann_dict[image_id]['caption']
+
+ questions = self.get_questions()
+
+ data_dict['image_file'] = image_file
+ image_file = os.path.join(self.image_folder, image_file)
+ # print(image_file)
+ image = Image.open(image_file).convert('RGB')
+ ori_width, ori_height = image.size
+ if self.pad_image_to_square:
+ image = expand2square(
+ image,
+ tuple(
+ int(x * 255) for x in self.image_processor.image_mean))
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ data_dict['pixel_values'] = image
+ data_dict['ori_size'] = (ori_width, ori_height)
+ data_dict['questions'] = questions
+
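+ # Use the center of the annotation bbox as a point prompt and map it into the padded square input frame.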
+ bbox = bbox_to_x1y1x2y2(self.ann_dict[image_id]['bbox'])
+ bbox = np.array([bbox])
+ points = (bbox[:, :2] + bbox[:, 2:]) / 2.
+ if self.pad_image_to_square:
+ points = expand2square_points(points, height=ori_height, width=ori_width)
+ points[:, 0] = points[:, 0] / max(ori_height, ori_width) * self.image_w
+ points[:, 1] = points[:, 1] / max(ori_height, ori_width) * self.image_h
+ else:
+ points[:, 0] = points[:, 0] / ori_width * self.image_w
+ points[:, 1] = points[:, 1] / ori_height * self.image_h
+ data_dict['points'] = torch.from_numpy(points)
+ # print(data_dict['points'])
+ data_dict['image_id'] = image_id
+
+ return data_dict
+
+def main():
+ args = parse_args()
+
+ torch.manual_seed(args.seed)
+
+ if args.launcher != 'none':
+ set_multi_processing(distributed=True)
+ init_dist(args.launcher)
+
+ rank, world_size = get_dist_info()
+ torch.cuda.set_device(rank)
+ else:
+ rank = 0
+ world_size = 1
+
+ # build model
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ backend = get_file_backend(args.pth_model)
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ llm = model.llm
+ tokenizer = model.tokenizer
+
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
+ projector.cuda()
+ projector.eval()
+
+ visual_encoder.cuda()
+ visual_encoder.eval()
+
+ stop_words = args.stop_words
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
+ dataset = RegionCap_Inference_Dataset(
+ annotation_file='./data/region_caption/mdetr_annotations/finetune_refcocog_val_captions.json',
+ image_folder='./data/glamm_data/images/coco2014/train2014/',
+ image_processor=image_processor,
+ pad_image_to_square=True,
+ debug=False,
+ # debug=True,
+ )
+ n_samples = len(dataset)
+ per_rank_samples = math.ceil(n_samples / world_size)
+
+ per_rank_ids = range(per_rank_samples * rank,
+ min(n_samples, per_rank_samples * (rank + 1)))
+ results = []
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
+ # pixel feature
+ data_sample = dataset[i]
+ image = data_sample['pixel_values'] # ()
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
+ pixel_values = projector(visual_outputs)
+ else:
+ pixel_values = projector(
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+
+ questions = data_sample['questions']
+ points = data_sample['points']
+ texts = DEFAULT_IMAGE_TOKEN + '\n' + questions
+
+ if args.prompt_template:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ prompt_text += template['INSTRUCTION'].format(
+ input=texts, round=1, bot_name=args.bot_name)
+ else:
+ prompt_text = texts
+
+ batch_inputs = prompt_text
+
+ predict = forward_model(
+ batch_inputs, pixel_values,
+ tokenizer, model, llm,
+ projector_text2vision,
+ gen_config, stop_criteria, points=points,
+ mark_token_id=model.mark_token_idx,
+ width=image.shape[-1], height=image.shape[-2],
+ visual_encoder=visual_encoder, projector=projector
+ )
+
+ # The '<s>' and double-space literals below are assumptions; the original replace() arguments were garbled in the diff.
+ text_output = predict.replace("<s>", "").replace("\n", "")\
+ .replace("region1", '').replace(':', '').replace("  ", " ")
+ text_output = text_output.split("ASSISTANT: ")[-1]
+
+ cleaned_str = re.sub(r'<.*?>', '', text_output)
+
+ # Remove the [SEG] token
+ cleaned_str = cleaned_str.replace('[SEG]', '')
+
+ # only select 1 setence for eval
+ # cleaned_str = cleaned_str.split('.')[0]
+
+ # Strip unnecessary spaces
+ cleaned_str = ' '.join(cleaned_str.split()).strip("'")
+ cleaned_str = cleaned_str.strip()
+
+ result_dict = {}
+ result_dict["image_id"] = data_sample['image_id']
+ result_dict["caption"] = cleaned_str
+ results.append(result_dict)
+ print(cleaned_str)
+
+ results = collect_results(results, n_samples)
+
+ if get_rank() == 0:
+ with open('./work_dirs/region_cap_pred.json', 'w') as json_file:
+ json.dump(results, json_file, indent=2)
+
+def forward_model(question, pixel_values,
+ tokenizer, model, llm,
+ projector_text2vision,
+ gen_config, stop_criteria,
+ mark_token_id=None,
+ points=None, width=None, height=None,
+ visual_encoder=None, projector=None):
+ # pixel_values = projector(
+ # visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
+
+ inputs = question
+ # print("Question: ", inputs)
+ chunk_encode = []
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
+ if idx == 0:
+ cur_encode = tokenizer.encode(chunk)
+ else:
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
+ chunk_encode.append(cur_encode)
+ assert len(chunk_encode) == 2
+ ids = []
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
+ ids.extend(cur_chunk_encode)
+ if idx != len(chunk_encode) - 1:
+ ids.append(IMAGE_TOKEN_INDEX)
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
+ points = points.cuda()
+
+ points_mark_embedding = get_points_embeddings(
+ points, ids, width, height,
+ mark_token_id, visual_encoder,
+ projector)
+
+ mm_inputs = prepare_inputs_labels_for_multimodal_with_region(
+ llm=llm, input_ids=ids, pixel_values=pixel_values,
+ mark_id=mark_token_id,
+ mark_feats=points_mark_embedding, region_id=-9999)
+
+ generate_output = llm.generate(
+ **mm_inputs,
+ generation_config=gen_config,
+ streamer=None,
+ bos_token_id=tokenizer.bos_token_id,
+ stopping_criteria=stop_criteria,
+ output_hidden_states=True,
+ return_dict_in_generate=True
+ )
+ predict = tokenizer.decode(
+ # generate_output.sequences[0], skip_special_tokens=True).strip()
+ generate_output.sequences[0]).strip()
+ return predict
+
+
+def get_points_embeddings(points, input_ids, width, height,
+ mark_token_idx, visual_encoder,
+ projector):
+ if points is None or len(points) == 0:
+ return []
+
+ mark_token_mask = input_ids == mark_token_idx
+ batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
+ input_ids.device)
+ batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number
+
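+ # Encode each point prompt with the point branch of the visual encoder, then project it into the LLM embedding space.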
+ points = points.to(torch.float32)
+ # print(points.dtype, batch_idxs.dtype)
+ marks_embeddings = visual_encoder.forward_point_sam(
+ points, batch_idxs, width=width, height=height
+ )[:, 0] # (N, C)
+
+ marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
+ marks_embeddings = projector.model.query_proj(marks_embeddings)
+ marks_embeddings = projector.model.model(marks_embeddings)
+ return marks_embeddings # (N, C)
+
+if __name__ == '__main__':
+
+ main()
diff --git a/omg_llava/tools/seg_cap_omg_llava.py b/omg_llava/tools/seg_cap_omg_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..951f7048fe51820cc7cdfe39b34c46b569fda1f3
--- /dev/null
+++ b/omg_llava/tools/seg_cap_omg_llava.py
@@ -0,0 +1,654 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import os
+import os.path as osp
+import re
+import sys
+
+import torch
+from huggingface_hub import snapshot_download
+from peft import PeftModel
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel, GenerationConfig)
+from transformers.generation.streamers import TextStreamer
+
+from xtuner.dataset.utils import expand2square, load_image, expand2square_points
+from xtuner.model.utils import prepare_inputs_labels_for_multimodal, prepare_inputs_labels_for_multimodal_with_region
+
+from xtuner.tools.utils import get_stop_criteria
+from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+ PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
+
+from mmengine.config import Config, DictAction
+from mmengine.fileio import PetrelBackend, get_file_backend
+
+from xtuner.configs import cfgs_name_path
+from xtuner.model.utils import guess_load_checkpoint
+from xtuner.registry import BUILDER
+import numpy as np
+
+TORCH_DTYPE_MAP = dict(
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+from xtuner.engine.hooks.evaluate_chat_hook import EvaluateChatHook
+
+# def get_points_embeddings(points, input_ids, width, height,
+# mark_token_idx, visual_encoder,
+# projector):
+# if points is None or len(points) == 0:
+# return []
+#
+# mark_token_mask = input_ids == mark_token_idx
+# batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
+# input_ids.device)
+# batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number
+#
+# points = points.to(torch.float32)
+# # print(points.dtype, batch_idxs.dtype)
+# marks_embeddings = visual_encoder.forward_point_sam(
+# points, batch_idxs, width=width, height=height
+# )[:, 0] # (N, C)
+#
+# marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
+# marks_embeddings = projector.model.query_proj(marks_embeddings)
+# marks_embeddings = projector.model.model(marks_embeddings)
+# return marks_embeddings # (N, C)
+
+def get_points_embeddings(points, input_ids, width, height,
+ mark_token_idx, visual_encoder,
+ projector):
+ if points is None or len(points) == 0:
+ return []
+
+ mark_token_mask = input_ids == mark_token_idx
+ batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
+ input_ids.device)
+ batch_idxs = batch_idxs[mark_token_mask] # (N, ) batch_size number
+
+ # points = points.to(torch.float32)
+ # print(points.dtype, batch_idxs.dtype)
+ # marks_embeddings = visual_encoder.forward_point_sam(
+ # points, batch_idxs, width=width, height=height
+ # )[:, 0] # (N, C)
+
+ marks_embeddings = visual_encoder.forward_region_sam(
+ points, batch_idxs
+ )[:, 0] # (N, C)
+
+ marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
+ marks_embeddings = projector.model.query_proj(marks_embeddings)
+ marks_embeddings = projector.model.model(marks_embeddings)
+ return marks_embeddings # (N, C)
+
+def remove_prefix(state_dict, prefix):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if key.startswith(prefix):
+ new_key = key[len(prefix):]
+ new_state_dict[new_key] = value
+ else:
+ new_state_dict[key] = value
+ return new_state_dict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Chat with a HF model')
+ parser.add_argument('config', help='config file name or path.')
+ parser.add_argument('pth_model', help='pth model file')
+
+ parser.add_argument('--image', default=None, help='image')
+ parser.add_argument(
+ '--torch-dtype',
+ default='fp16',
+ choices=TORCH_DTYPE_MAP.keys(),
+ help='Override the default `torch.dtype` and load the model under '
+ 'a specific `dtype`.')
+ parser.add_argument(
+ '--prompt-template',
+ choices=PROMPT_TEMPLATE.keys(),
+ default="internlm2_chat",
+ help='Specify a prompt template')
+ system_group = parser.add_mutually_exclusive_group()
+ system_group.add_argument(
+ '--system', default=None, help='Specify the system text')
+ system_group.add_argument(
+ '--system-template',
+ choices=SYSTEM_TEMPLATE.keys(),
+ default=None,
+ help='Specify a system template')
+ parser.add_argument(
+ '--bits',
+ type=int,
+ choices=[4, 8, None],
+ default=None,
+ help='LLM bits')
+ parser.add_argument(
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
+ parser.add_argument(
+ '--with-plugins',
+ nargs='+',
+ choices=['calculate', 'solve', 'search'],
+ help='Specify plugins to use')
+ parser.add_argument(
+ '--no-streamer', action='store_true', help='Disable the streaming text output')
+ parser.add_argument(
+ '--lagent', action='store_true', help='Whether to use lagent')
+ parser.add_argument(
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+ parser.add_argument(
+ '--offload-folder',
+ default=None,
+ help='The folder in which to offload the model weights (or where the '
+ 'model weights are already offloaded).')
+ parser.add_argument(
+ '--max-new-tokens',
+ type=int,
+ default=2048,
+ help='Maximum number of new tokens allowed in generated text')
+ parser.add_argument(
+ '--temperature',
+ type=float,
+ default=0.1,
+ help='The value used to modulate the next token probabilities.')
+ parser.add_argument(
+ '--top-k',
+ type=int,
+ default=40,
+ help='The number of highest probability vocabulary tokens to '
+ 'keep for top-k-filtering.')
+ parser.add_argument(
+ '--top-p',
+ type=float,
+ default=0.75,
+ help='If set to float < 1, only the smallest set of most probable '
+ 'tokens with probabilities that add up to top_p or higher are '
+ 'kept for generation.')
+ parser.add_argument(
+ '--repetition-penalty',
+ type=float,
+ default=1.0,
+ help='The parameter for repetition penalty. 1.0 means no penalty.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Random seed for reproducible text generation')
+ args = parser.parse_args()
+ return args
+
+
+def get_input():
+ """Helper function for getting input from users."""
+ sentinel = '' # ends when this string is seen
+ result = None
+ while result is None:
+ print(('\ndouble enter to end input (EXIT: exit chat, '
+ 'RESET: reset history) >>> '),
+ end='')
+ try:
+ result = '\n'.join(iter(input, sentinel))
+ except UnicodeDecodeError:
+ print('Invalid characters detected. Please enter again.')
+ return result
+
+
+def main():
+ args = parse_args()
+ torch.manual_seed(args.seed)
+
+ # parse config
+ if not osp.isfile(args.config):
+ try:
+ args.config = cfgs_name_path[args.config]
+ except KeyError:
+ raise FileNotFoundError(f'Cannot find {args.config}')
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ # if args.cfg_options is not None:
+ # cfg.merge_from_dict(args.cfg_options)
+
+ model_name = cfg.model.type if isinstance(cfg.model.type,
+ str) else cfg.model.type.__name__
+ if 'LLaVAModel' in model_name:
+ cfg.model.pretrained_pth = None
+
+ model = BUILDER.build(cfg.model)
+ print(model.state_dict().keys())
+
+ # pre_state_dict = torch.load("/root/omg-llava.pth")
+ # model.load_state_dict(pre_state_dict)
+
+ backend = get_file_backend(args.pth_model)
+ if isinstance(backend, PetrelBackend):
+ from xtuner.utils.fileio import patch_fileio
+ with patch_fileio():
+ state_dict = guess_load_checkpoint(args.pth_model)
+ else:
+ state_dict = guess_load_checkpoint(args.pth_model)
+
+ print(state_dict.keys())
+ # del state_dict['llm.base_model.model.model.tok_embeddings.weight']
+ model.load_state_dict(state_dict, strict=False)
+ print(f'Load PTH model from {args.pth_model}')
+
+ # image_processor_cfg = copy.deepcopy(cfg.image_processor)
+ image_processor = cfg.image_processor
+ image_processor_type = image_processor['type']
+ del image_processor['type']
+ image_processor = image_processor_type(**image_processor)
+
+ # chat_hook = EvaluateChatHook(
+ # tokenizer=cfg.tokenizer,
+ # image_processor=image_processor_cfg,
+ # every_n_iters=100,
+ # evaluation_inputs=cfg.evaluation_inputs,
+ # evaluation_images=cfg.evaluation_images,
+ # system='',
+ # prompt_template=PROMPT_TEMPLATE.internlm2_chat
+ # )
+ # model.cuda()
+ # model.eval()
+ # chat_hook._eval_images_(model, model.device, max_new_tokens=200)
+
+ # build llm
+ quantization_config = None
+ load_in_8bit = False
+ if args.bits == 4:
+ quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ load_in_8bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4')
+ elif args.bits == 8:
+ load_in_8bit = True
+ model_kwargs = {
+ 'quantization_config': quantization_config,
+ 'load_in_8bit': load_in_8bit,
+ 'device_map': 'auto',
+ 'offload_folder': args.offload_folder,
+ 'trust_remote_code': True,
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
+ }
+ if False:
+ pass
+ else:
+ if args.with_plugins is None:
+ inner_thoughts_open = False
+ calculate_open = False
+ solve_open = False
+ search_open = False
+ else:
+ assert args.prompt_template == args.system_template == 'moss_sft'
+ from plugins import plugins_api
+ inner_thoughts_open = True
+ calculate_open = 'calculate' in args.with_plugins
+ solve_open = 'solve' in args.with_plugins
+ search_open = 'search' in args.with_plugins
+ # pre-import for api and model preparation
+ if calculate_open:
+ from plugins import calculate # noqa: F401
+ if solve_open:
+ from plugins import solve # noqa: F401
+ if search_open:
+ from plugins import search # noqa: F401
+ # build llm
+ llm = model.llm
+ tokenizer = model.tokenizer
+
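+ # Move the model to GPU in eval mode and grab the vision-side modules:
+ # `visual_encoder` is the OMG-Seg based encoder, `projector` maps its
+ # features into the LLM embedding space, and `projector_text2vision` is
+ # assumed to map LLM hidden states back into the visual space for decoding.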
+ model.cuda()
+ model.eval()
+ llm.eval()
+ visual_encoder = model.visual_encoder
+ projector = model.projector
+ projector_text2vision = model.projector_text2vision
+
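+ # Image preprocessing: pad the image to a square filled with the processor's
+ # mean color, convert it to pixel values, then run the visual encoder
+ # (keeping hidden states) and project its outputs into LLM-space visual tokens.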
+ if args.image is not None:
+ image = load_image(args.image)
+ ori_width, ori_height = image.size
+ image = expand2square(
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image_for_show = image
+ image = image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+ image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
+ print([item.shape for item in visual_outputs])
+ pixel_values = projector(visual_outputs)
+
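+ # `valid_queries_embeddings` is assumed to hold the object queries kept for
+ # this image, so `n_obj` is the number of regions the loop below walks over.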
+ n_obj = len(projector.model.valid_queries_embeddings[0])
+ stop_words = args.stop_words
+ sep = ''
+ if args.prompt_template:
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ stop_words += template.get('STOP_WORDS', [])
+ sep = template.get('SEP', '')
+ stop_criteria = get_stop_criteria(
+ tokenizer=tokenizer, stop_words=stop_words)
+
+ if args.no_streamer:
+ streamer = None
+ else:
+ streamer = TextStreamer(tokenizer, skip_prompt=True)
+
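+ # Sampling is only enabled when temperature > 0; if the tokenizer has no pad
+ # token, fall back to the EOS token id.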
+ gen_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.temperature > 0,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+ )
+
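+ # Automated chat loop: instead of reading user input, each turn builds a
+ # region-description prompt, the history is reset after every answer, and
+ # the script exits once all `n_obj` regions have been processed.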
+ n_turn = 0
+ inputs = ''
+ obj_index = 0
+ while True:
+ # text = get_input()
+ if n_turn == 0:
+ text = 'There are some Regions:'
+ text = text + ' region{} '.format(1)
+ text = text + '.\n'
+ text = text + 'Please detailed describe the regions.'
+ else:
+ text = 'RESET'
+ obj_index += 1
+ if obj_index == n_obj:
+ text = 'EXIT'
+ print('Log: Exit!')
+ exit(0)
+ n_turn = 0
+ inputs = ''
+ continue
+
+ while text.strip() == 'RESET':
+ print('Log: History responses have been removed!')
+ n_turn = 0
+ inputs = ''
+ text = get_input()
+ if text.strip() == 'EXIT':
+ print('Log: Exit!')
+ exit(0)
+
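+ # On the first turn, prepend the image placeholder token so the projected
+ # visual tokens can be spliced in at that position.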
+ if args.image is not None and n_turn == 0:
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
+
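+ # Wrap the raw text with the selected prompt template; the system prompt is
+ # only added on the first turn of a conversation.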
+ if args.prompt_template:
+ prompt_text = ''
+ template = PROMPT_TEMPLATE[args.prompt_template]
+ if 'SYSTEM' in template and n_turn == 0:
+ system_text = None
+ if args.system_template is not None:
+ system_text = SYSTEM_TEMPLATE[
+ args.system_template].format(
+ round=n_turn + 1, bot_name=args.bot_name)
+ elif args.system is not None:
+ system_text = args.system
+ if system_text is not None:
+ prompt_text += template['SYSTEM'].format(
+ system=system_text,
+ round=n_turn + 1,
+ bot_name=args.bot_name)
+ prompt_text += template['INSTRUCTION'].format(
+ input=text, round=n_turn + 1, bot_name=args.bot_name)
+ if args.prompt_template == args.system_template == 'moss_sft':
+ # str.replace returns a new string, so the result must be assigned back;
+ # otherwise the plugin descriptions in the MOSS system prompt never change.
+ if not inner_thoughts_open:
+ prompt_text = prompt_text.replace(
+ '- Inner thoughts: enabled.',
+ '- Inner thoughts: disabled.')
+ if not calculate_open:
+ prompt_text = prompt_text.replace(
+ ('- Calculator: enabled. API: '
+ 'Calculate(expression)'),
+ '- Calculator: disabled.')
+ if not solve_open:
+ prompt_text = prompt_text.replace(
+ '- Equation solver: enabled. API: Solve(equation)',
+ '- Equation solver: disabled.')
+ if not search_open:
+ prompt_text = prompt_text.replace(
+ '- Web search: enabled. API: Search(query)',
+ '- Web search: disabled.')
+ else:
+ prompt_text = text
+ print("prompt_text: ", prompt_text)
+ inputs += prompt_text
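+ # Text-only path: tokenize the accumulated conversation; special tokens
+ # (e.g. BOS) are only added on the first turn so the history is not
+ # re-prefixed.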
+ if args.image is None:
+ if n_turn == 0:
+ ids = tokenizer.encode(inputs, return_tensors='pt')
+ else:
+ ids = tokenizer.encode(
+ inputs, return_tensors='pt', add_special_tokens=False)
+
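+ # With plugins enabled, run a first generation pass and then (in the code
+ # that follows) scan the output for plugin command markers such as
+ # `<|Commands|>` to decide which plugin API to call.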
+ if args.with_plugins is not None:
+ generate_output = llm.generate(
+ inputs=ids.cuda(),
+ generation_config=gen_config,
+ streamer=streamer,
+ stopping_criteria=stop_criteria).cpu()
+ generate_output_text = tokenizer.decode(
+ generate_output[0][len(ids[0]):])
+ if streamer is None:
+ end = '' if generate_output_text.endswith('\n') else '\n'
+ print(generate_output_text, end=end)
+ pattern = r'<\|Commands\|>:(.*?)