zhangtao-whu committed on
Commit f6d075a · verified · 1 Parent(s): a30d3e3

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. omg_llava/__init__.py +0 -0
  2. omg_llava/__pycache__/__init__.cpython-310.pyc +0 -0
  3. omg_llava/configs/__init__.py +0 -0
  4. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py +951 -0
  5. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py +954 -0
  6. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py +927 -0
  7. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py +954 -0
  8. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py +927 -0
  9. omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py +954 -0
  10. omg_llava/configs/finetune/ablation_multi_seg_states/debug.py +924 -0
  11. omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py +953 -0
  12. omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py +953 -0
  13. omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py +953 -0
  14. omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py +926 -0
  15. omg_llava/configs/finetune/debug.py +967 -0
  16. omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py +951 -0
  17. omg_llava/configs/finetune/hf_app.py +951 -0
  18. omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py +993 -0
  19. omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py +952 -0
  20. omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py +993 -0
  21. omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py +1007 -0
  22. omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py +1028 -0
  23. omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py +1000 -0
  24. omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py +951 -0
  25. omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py +994 -0
  26. omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py +925 -0
  27. omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py +929 -0
  28. omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py +377 -0
  29. omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py +377 -0
  30. omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py +377 -0
  31. omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py +379 -0
  32. omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py +375 -0
  33. omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py +376 -0
  34. omg_llava/dataset/CombineDataset.py +81 -0
  35. omg_llava/dataset/DecoupledGCGDataset.py +381 -0
  36. omg_llava/dataset/GCGDataset.py +364 -0
  37. omg_llava/dataset/LlavaDataset.py +134 -0
  38. omg_llava/dataset/MDPVPointsDataset.py +220 -0
  39. omg_llava/dataset/ReferringSegDataset.py +380 -0
  40. omg_llava/dataset/RegionCaptionDataset.py +356 -0
  41. omg_llava/dataset/SemanticSegDataset.py +725 -0
  42. omg_llava/dataset/__init__.py +29 -0
  43. omg_llava/dataset/__pycache__/CombineDataset.cpython-310.pyc +0 -0
  44. omg_llava/dataset/__pycache__/DecoupledGCGDataset.cpython-310.pyc +0 -0
  45. omg_llava/dataset/__pycache__/GCGDataset.cpython-310.pyc +0 -0
  46. omg_llava/dataset/__pycache__/LlavaDataset.cpython-310.pyc +0 -0
  47. omg_llava/dataset/__pycache__/MDPVPointsDataset.cpython-310.pyc +0 -0
  48. omg_llava/dataset/__pycache__/ReferringSegDataset.cpython-310.pyc +0 -0
  49. omg_llava/dataset/__pycache__/RegionCaptionDataset.cpython-310.pyc +0 -0
  50. omg_llava/dataset/__pycache__/SemanticSegDataset.cpython-310.pyc +0 -0
omg_llava/__init__.py ADDED
File without changes
omg_llava/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
 
omg_llava/configs/__init__.py ADDED
File without changes
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_baseline.py ADDED
@@ -0,0 +1,951 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter out mask areas where the score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
816
+ glamm_grandf_dataset, glamm_psg_dataset,
817
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
818
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
820
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
821
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
827
+ mdpv_detailed_description_ade20k_dataset,
828
+ mdpv_detailed_description_cocostuff_10k_dataset,
829
+ mdpv_detailed_description_cocostuff_164k_dataset,
830
+ mdpv_detailed_description_vg_dataset,
831
+ mdpv_brief_description_lvis_dataset,
832
+ mdpv_brief_description_vg_dataset,
833
+ mdpv_brief_description_ade20k_dataset,
834
+ mdpv_brief_description_cocostuff10k_dataset,
835
+ mdpv_brief_description_cocostuff164k_dataset,
836
+ mdpv_qa_vg_dataset,
837
+ mdpv_qa_lvis_dataset,
838
+ mdpv_qa_ade20k_dataset,
839
+ mdpv_qa_cocostuff10k_dataset,
840
+ mdpv_qa_cocostuff164k_dataset,
841
+ mdpv_multi_points_flicker30k_dataset,
842
+ mdpv_multi_points_openpsg_dataset,],
843
+ )
844
+
845
+ train_dataloader = dict(
846
+ batch_size=batch_size,
847
+ num_workers=dataloader_num_workers,
848
+ dataset=train_dataset,
849
+ sampler=dict(
850
+ type=LengthGroupedSampler,
851
+ length_property='modality_length',
852
+ per_device_batch_size=batch_size * accumulative_counts),
853
+ collate_fn=dict(type=omg_llava_collate_fn))
854
+
855
+ #######################################################################
856
+ # PART 4 Scheduler & Optimizer #
857
+ #######################################################################
858
+ # optimizer
859
+ optim_wrapper = dict(
860
+ type=AmpOptimWrapper,
861
+ optimizer=dict(
862
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
863
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
864
+ accumulative_counts=accumulative_counts,
865
+ loss_scale='dynamic',
866
+ dtype='float16')
867
+
868
+ # learning policy
869
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
870
+ param_scheduler = [
871
+ dict(
872
+ type=LinearLR,
873
+ start_factor=1e-5,
874
+ by_epoch=True,
875
+ begin=0,
876
+ end=warmup_ratio * max_epochs,
877
+ convert_to_iter_based=True),
878
+ dict(
879
+ type=CosineAnnealingLR,
880
+ eta_min=0.0,
881
+ by_epoch=True,
882
+ begin=warmup_ratio * max_epochs,
883
+ end=max_epochs,
884
+ convert_to_iter_based=True)
885
+ ]
886
+
887
+ # train, val, test setting
888
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
889
+
890
+ #######################################################################
891
+ # PART 5 Runtime #
892
+ #######################################################################
893
+ # Log the dialogue periodically during the training process, optional
894
+ custom_hooks = [
895
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
896
+ dict(
897
+ type=EvaluateChatHook_withSpecialTokens,
898
+ tokenizer=tokenizer,
899
+ image_processor=image_processor,
900
+ every_n_iters=evaluation_freq,
901
+ evaluation_inputs=evaluation_inputs,
902
+ evaluation_images=evaluation_images,
903
+ system=SYSTEM,
904
+ prompt_template=prompt_template)
905
+ ]
906
+
907
+ # configure default hooks
908
+ default_hooks = dict(
909
+ # record the time of every iteration.
910
+ timer=dict(type=IterTimerHook),
911
+ # print log every 10 iterations.
912
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
913
+ # enable the parameter scheduler.
914
+ param_scheduler=dict(type=ParamSchedulerHook),
915
+ # save checkpoint per `save_steps`.
916
+ checkpoint=dict(
917
+ type=CheckpointHook,
918
+ by_epoch=False,
919
+ interval=save_steps,
920
+ max_keep_ckpts=save_total_limit),
921
+ # set sampler seed in distributed environment.
922
+ sampler_seed=dict(type=DistSamplerSeedHook),
923
+ )
924
+
925
+ # configure environment
926
+ env_cfg = dict(
927
+ # whether to enable cudnn benchmark
928
+ cudnn_benchmark=False,
929
+ # set multi process parameters
930
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
931
+ # set distributed parameters
932
+ dist_cfg=dict(backend='nccl'),
933
+ )
934
+
935
+ # set visualizer
936
+ visualizer = None
937
+
938
+ # set log level
939
+ log_level = 'INFO'
940
+
941
+ # load from which checkpoint
942
+ load_from = None
943
+
944
+ # whether to resume training from the loaded checkpoint
945
+ resume = False
946
+
947
+ # Defaults to use random seed and disable `deterministic`
948
+ randomness = dict(seed=None, deterministic=False)
949
+
950
+ # set log processor
951
+ log_processor = dict(by_epoch=False)
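
A minimal usage sketch (not part of the commit), assuming mmengine is installed and the config's own dependencies (omg_llava, xtuner, mmdet, peft, transformers) are importable: it spells out the token-budget arithmetic behind max_length in the config above and shows how an xtuner/mmengine-style Python config like this one is typically loaded.

# Illustrative sketch only; the file path is the config added in this commit.
from mmengine.config import Config

# Token budget behind max_length above: a 1024px image downsampled by a
# factor of 64 gives a 16 x 16 grid of visual tokens; ~100 tokens are
# reserved for special tokens, and the rest of the 2048-token context is text.
visual_tokens = int((1024 / 64) ** 2)           # 256
max_length = int(2048 - visual_tokens - 100)    # 1692

# Loading the config (executes the file, so its imports must resolve).
cfg = Config.fromfile(
    'omg_llava/configs/finetune/ablation_multi_seg_states/'
    'ablation_multi_seg_states_baseline.py')
print(cfg.max_length, cfg.batch_size, cfg.lr)
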
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat.py ADDED
@@ -0,0 +1,954 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # masks whose score is below 0.5 are filtered out.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ using_multilayer_states=True,
350
+ seg_token_merge_type='cat',
351
+ selected_layers=32,
352
+ llm=dict(
353
+ type=AutoModelForCausalLM.from_pretrained,
354
+ pretrained_model_name_or_path=llm_name_or_path,
355
+ trust_remote_code=True,
356
+ torch_dtype=torch.float16,
357
+ quantization_config=dict(
358
+ type=BitsAndBytesConfig,
359
+ load_in_4bit=True,
360
+ load_in_8bit=False,
361
+ llm_int8_threshold=6.0,
362
+ llm_int8_has_fp16_weight=False,
363
+ bnb_4bit_compute_dtype=torch.float16,
364
+ bnb_4bit_use_double_quant=True,
365
+ bnb_4bit_quant_type='nf4')),
366
+ llm_lora=dict(
367
+ type=LoraConfig,
368
+ r=512,
369
+ lora_alpha=256,
370
+ lora_dropout=0.05,
371
+ bias='none',
372
+ task_type='CAUSAL_LM'),
373
+ visual_encoder=omgseg_model,
374
+ tokenizer=tokenizer,
375
+ )
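+ # Ablation note (assumption): with using_multilayer_states=True and
+ # seg_token_merge_type='cat', the [SEG]-token hidden states taken from the LLM
+ # layers selected by `selected_layers` are presumably concatenated before the
+ # text-to-vision projector; the sibling configs in this folder swap in other
+ # merge strategies. The LLM itself is tuned QLoRA-style: 4-bit NF4 quantization
+ # with a LoRA adapter (r=512, alpha=256) on top of the frozen base weights.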
376
+
377
+ #######################################################################
378
+ # PART 3 Dataset & Dataloader #
379
+ #######################################################################
380
+ debug=False
381
+ llava_dataset = dict(
382
+ type=LLaVADataset,
383
+ data_path=data_path,
384
+ image_folder=image_folder,
385
+ tokenizer=tokenizer,
386
+ image_processor=image_processor,
387
+ dataset_map_fn=llava_map_fn,
388
+ template_map_fn=dict(
389
+ type=template_map_fn_factory, template=prompt_template),
390
+ max_length=max_length,
391
+ pad_image_to_square=True)
392
+
393
+ glamm_refcocog_dataset = dict(
394
+ type=RefCOCOgGCGDataset,
395
+ data_path=refcocog_ann_file,
396
+ image_folder=refcocog_image_path,
397
+ tokenizer=tokenizer,
398
+ image_processor=image_processor,
399
+ dataset_map_fn=glamm_refcocog_map_fn,
400
+ template_map_fn=dict(
401
+ type=template_map_fn_factory, template=prompt_template),
402
+ max_length=max_length,
403
+ pad_image_to_square=True,
404
+ debug=False,
405
+ repeats=1,
406
+ )
407
+
408
+ glamm_grandf_dataset = dict(
409
+ type=GranDfGCGDataset,
410
+ data_path=grandf_ann_file,
411
+ image_folder=grandf_image_path,
412
+ tokenizer=tokenizer,
413
+ image_processor=image_processor,
414
+ dataset_map_fn=glamm_granf_map_fn,
415
+ template_map_fn=dict(
416
+ type=template_map_fn_factory, template=prompt_template),
417
+ max_length=max_length,
418
+ pad_image_to_square=True,
419
+ debug=debug,
420
+ repeats=10,
421
+ )
422
+
423
+ glamm_psg_dataset = dict(
424
+ type=OpenPsgGCGDataset,
425
+ data_path=psg_ann_file,
426
+ image_folder=psg_image_path,
427
+ tokenizer=tokenizer,
428
+ image_processor=image_processor,
429
+ dataset_map_fn=glamm_openpsg_map_fn,
430
+ template_map_fn=dict(
431
+ type=template_map_fn_factory, template=prompt_template),
432
+ max_length=max_length,
433
+ pad_image_to_square=True,
434
+ debug=debug,
435
+ repeats=1,
436
+ )
437
+
438
+ glamm_flickr_dataset = dict(
439
+ type=FlickrGCGDataset,
440
+ data_path=flickr_ann_file,
441
+ image_folder=flickr_image_path,
442
+ tokenizer=tokenizer,
443
+ image_processor=image_processor,
444
+ dataset_map_fn=glamm_flickr_map_fn,
445
+ template_map_fn=dict(
446
+ type=template_map_fn_factory, template=prompt_template),
447
+ max_length=max_length,
448
+ pad_image_to_square=True,
449
+ debug=debug,
450
+ repeats=1,
451
+ )
452
+
453
+ semantic_seg_ade20k_dataset = dict(
454
+ type=ADE20kSemanticSegDataset,
455
+ data_path=ade20k_class_file,
456
+ image_folder=ade20k_image_path,
457
+ tokenizer=tokenizer,
458
+ image_processor=image_processor,
459
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
460
+ template_map_fn=dict(
461
+ type=template_map_fn_factory, template=prompt_template),
462
+ max_length=max_length,
463
+ pad_image_to_square=True,
464
+ debug=False,
465
+ repeats=1,
466
+ gcg_format=True,
467
+ )
468
+
469
+ semantic_seg_cocostuff_dataset = dict(
470
+ type=COCOStuffSemanticSegDataset,
471
+ data_path=cocostuff_class_file,
472
+ image_folder=cocostuff_image_path,
473
+ label_path=cocostuff_label_path,
474
+ tokenizer=tokenizer,
475
+ image_processor=image_processor,
476
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
477
+ template_map_fn=dict(
478
+ type=template_map_fn_factory, template=prompt_template),
479
+ max_length=max_length,
480
+ pad_image_to_square=True,
481
+ debug=False,
482
+ repeats=1,
483
+ gcg_format=True,
484
+ )
485
+
486
+ referring_seg_refcoco_dataset = dict(
487
+ type=RefcocoReferringSegDataset,
488
+ data_path=referring_refcoco_data_path,
489
+ image_folder=referring_refcoco_image_path,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
493
+ template_map_fn=dict(
494
+ type=template_map_fn_factory, template=prompt_template),
495
+ max_length=max_length,
496
+ pad_image_to_square=True,
497
+ debug=False,
498
+ repeats=1,
499
+ )
500
+
501
+ referring_seg_refcoco_plus_dataset = dict(
502
+ type=Refcoco_plus_ReferringSegDataset,
503
+ data_path=referring_refcoco_plus_data_path,
504
+ image_folder=referring_refcoco_plus_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ )
515
+
516
+ referring_seg_refcocog_dataset = dict(
517
+ type=Refcocog_ReferringSegDataset,
518
+ data_path=referring_refcocog_data_path,
519
+ image_folder=referring_refcocog_image_path,
520
+ tokenizer=tokenizer,
521
+ image_processor=image_processor,
522
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
523
+ template_map_fn=dict(
524
+ type=template_map_fn_factory, template=prompt_template),
525
+ max_length=max_length,
526
+ pad_image_to_square=True,
527
+ debug=False,
528
+ repeats=1,
529
+ )
530
+
531
+ referring_seg_refclef_dataset = dict(
532
+ type=Refclef_ReferringSegDataset,
533
+ data_path=referring_refclef_data_path,
534
+ image_folder=referring_refclef_image_path,
535
+ tokenizer=tokenizer,
536
+ image_processor=image_processor,
537
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
538
+ template_map_fn=dict(
539
+ type=template_map_fn_factory, template=prompt_template),
540
+ max_length=max_length,
541
+ pad_image_to_square=True,
542
+ debug=False,
543
+ repeats=1,
544
+ )
545
+
546
+ region_cap_osprey_dataset = dict(
547
+ type=OspreyRegionCaptionDataset,
548
+ data_path=region_cap_osprey_data_path,
549
+ image_folder=region_cap_osprey_image_path,
550
+ tokenizer=tokenizer,
551
+ image_processor=image_processor,
552
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
553
+ template_map_fn=dict(
554
+ type=template_map_fn_factory, template=prompt_template),
555
+ max_length=max_length,
556
+ pad_image_to_square=True,
557
+ debug=False,
558
+ repeats=1,
559
+ )
560
+
561
+ region_conversation_osprey_dataset = dict(
562
+ type=OspreyRegionConversationDataset,
563
+ data_path=region_conversation_osprey_data_path,
564
+ image_folder=region_conversation_osprey_image_path,
565
+ tokenizer=tokenizer,
566
+ image_processor=image_processor,
567
+ dataset_map_fn=osprey_region_conversation_map_fn,
568
+ template_map_fn=dict(
569
+ type=template_map_fn_factory, template=prompt_template),
570
+ max_length=max_length,
571
+ pad_image_to_square=True,
572
+ debug=False,
573
+ repeats=1,
574
+ )
575
+
576
+ mdpv_detailed_description_ade20k_dataset = dict(
577
+ type=MDPVPointDetailedCaptionDataset,
578
+ data_path=mdpv_detailed_caption_ade20k_data_path,
579
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
580
+ tokenizer=tokenizer,
581
+ image_processor=image_processor,
582
+ dataset_map_fn=mdpv_points_map_fn,
583
+ template_map_fn=dict(
584
+ type=template_map_fn_factory, template=prompt_template),
585
+ max_length=max_length,
586
+ pad_image_to_square=True,
587
+ debug=False,
588
+ repeats=1,
589
+ )
590
+
591
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
592
+ type=MDPVPointDetailedCaptionDataset,
593
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
594
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
595
+ tokenizer=tokenizer,
596
+ image_processor=image_processor,
597
+ dataset_map_fn=mdpv_points_map_fn,
598
+ template_map_fn=dict(
599
+ type=template_map_fn_factory, template=prompt_template),
600
+ max_length=max_length,
601
+ pad_image_to_square=True,
602
+ debug=False,
603
+ repeats=1,
604
+ )
605
+
606
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
607
+ type=MDPVPointDetailedCaptionDataset,
608
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
609
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
610
+ tokenizer=tokenizer,
611
+ image_processor=image_processor,
612
+ dataset_map_fn=mdpv_points_map_fn,
613
+ template_map_fn=dict(
614
+ type=template_map_fn_factory, template=prompt_template),
615
+ max_length=max_length,
616
+ pad_image_to_square=True,
617
+ debug=False,
618
+ repeats=1,
619
+ )
620
+
621
+ mdpv_detailed_description_vg_dataset = dict(
622
+ type=MDPVPointDetailedCaptionDataset,
623
+ data_path=mdpv_detailed_caption_vg_data_path,
624
+ image_folder=mdpv_detailed_caption_vg_image_path,
625
+ tokenizer=tokenizer,
626
+ image_processor=image_processor,
627
+ dataset_map_fn=mdpv_points_map_fn,
628
+ template_map_fn=dict(
629
+ type=template_map_fn_factory, template=prompt_template),
630
+ max_length=max_length,
631
+ pad_image_to_square=True,
632
+ debug=False,
633
+ repeats=1,
634
+ )
635
+
636
+ mdpv_brief_description_vg_dataset = dict(
637
+ type=MDPVPointBriefCaptionDataset,
638
+ data_path=mdpv_brief_caption_vg_data_path,
639
+ image_folder=mdpv_brief_caption_vg_image_path,
640
+ tokenizer=tokenizer,
641
+ image_processor=image_processor,
642
+ dataset_map_fn=mdpv_points_map_fn,
643
+ template_map_fn=dict(
644
+ type=template_map_fn_factory, template=prompt_template),
645
+ max_length=max_length,
646
+ pad_image_to_square=True,
647
+ debug=False,
648
+ repeats=1,
649
+ )
650
+
651
+ mdpv_brief_description_cocostuff10k_dataset = dict(
652
+ type=MDPVPointBriefCaptionDataset,
653
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
654
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
655
+ tokenizer=tokenizer,
656
+ image_processor=image_processor,
657
+ dataset_map_fn=mdpv_points_map_fn,
658
+ template_map_fn=dict(
659
+ type=template_map_fn_factory, template=prompt_template),
660
+ max_length=max_length,
661
+ pad_image_to_square=True,
662
+ debug=False,
663
+ repeats=1,
664
+ )
665
+
666
+ mdpv_brief_description_cocostuff164k_dataset = dict(
667
+ type=MDPVPointBriefCaptionDataset,
668
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
669
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
670
+ tokenizer=tokenizer,
671
+ image_processor=image_processor,
672
+ dataset_map_fn=mdpv_points_map_fn,
673
+ template_map_fn=dict(
674
+ type=template_map_fn_factory, template=prompt_template),
675
+ max_length=max_length,
676
+ pad_image_to_square=True,
677
+ debug=False,
678
+ repeats=1,
679
+ )
680
+
681
+ mdpv_brief_description_ade20k_dataset = dict(
682
+ type=MDPVPointBriefCaptionDataset,
683
+ data_path=mdpv_brief_caption_ade20k_data_path,
684
+ image_folder=mdpv_brief_caption_ade20k_image_path,
685
+ tokenizer=tokenizer,
686
+ image_processor=image_processor,
687
+ dataset_map_fn=mdpv_points_map_fn,
688
+ template_map_fn=dict(
689
+ type=template_map_fn_factory, template=prompt_template),
690
+ max_length=max_length,
691
+ pad_image_to_square=True,
692
+ debug=False,
693
+ repeats=1,
694
+ )
695
+
696
+ mdpv_brief_description_lvis_dataset = dict(
697
+ type=MDPVPointBriefCaptionDataset,
698
+ data_path=mdpv_brief_caption_lvis_data_path,
699
+ image_folder=mdpv_brief_caption_lvis_image_path,
700
+ tokenizer=tokenizer,
701
+ image_processor=image_processor,
702
+ dataset_map_fn=mdpv_points_map_fn,
703
+ template_map_fn=dict(
704
+ type=template_map_fn_factory, template=prompt_template),
705
+ max_length=max_length,
706
+ pad_image_to_square=True,
707
+ debug=False,
708
+ repeats=1,
709
+ )
710
+
711
+ mdpv_qa_vg_dataset = dict(
712
+ type=MDPVPointBriefCaptionDataset,
713
+ data_path=mdpv_qa_vg_data_path,
714
+ image_folder=mdpv_qa_vg_image_path,
715
+ tokenizer=tokenizer,
716
+ image_processor=image_processor,
717
+ dataset_map_fn=mdpv_points_map_fn,
718
+ template_map_fn=dict(
719
+ type=template_map_fn_factory, template=prompt_template),
720
+ max_length=max_length,
721
+ pad_image_to_square=True,
722
+ debug=False,
723
+ repeats=1,
724
+ )
725
+
726
+ mdpv_qa_ade20k_dataset = dict(
727
+ type=MDPVPointBriefCaptionDataset,
728
+ data_path=mdpv_qa_ade20k_data_path,
729
+ image_folder=mdpv_qa_ade20k_image_path,
730
+ tokenizer=tokenizer,
731
+ image_processor=image_processor,
732
+ dataset_map_fn=mdpv_points_map_fn,
733
+ template_map_fn=dict(
734
+ type=template_map_fn_factory, template=prompt_template),
735
+ max_length=max_length,
736
+ pad_image_to_square=True,
737
+ debug=False,
738
+ repeats=1,
739
+ )
740
+
741
+ mdpv_qa_lvis_dataset = dict(
742
+ type=MDPVPointBriefCaptionDataset,
743
+ data_path=mdpv_qa_lvis_data_path,
744
+ image_folder=mdpv_qa_lvis_image_path,
745
+ tokenizer=tokenizer,
746
+ image_processor=image_processor,
747
+ dataset_map_fn=mdpv_points_map_fn,
748
+ template_map_fn=dict(
749
+ type=template_map_fn_factory, template=prompt_template),
750
+ max_length=max_length,
751
+ pad_image_to_square=True,
752
+ debug=False,
753
+ repeats=1,
754
+ )
755
+
756
+ mdpv_qa_cocostuff10k_dataset = dict(
757
+ type=MDPVPointBriefCaptionDataset,
758
+ data_path=mdpv_qa_cocostuff10k_data_path,
759
+ image_folder=mdpv_qa_cocostuff10k_image_path,
760
+ tokenizer=tokenizer,
761
+ image_processor=image_processor,
762
+ dataset_map_fn=mdpv_points_map_fn,
763
+ template_map_fn=dict(
764
+ type=template_map_fn_factory, template=prompt_template),
765
+ max_length=max_length,
766
+ pad_image_to_square=True,
767
+ debug=False,
768
+ repeats=1,
769
+ )
770
+
771
+ mdpv_qa_cocostuff164k_dataset = dict(
772
+ type=MDPVPointBriefCaptionDataset,
773
+ data_path=mdpv_qa_cocostuff164k_data_path,
774
+ image_folder=mdpv_qa_cocostuff164k_image_path,
775
+ tokenizer=tokenizer,
776
+ image_processor=image_processor,
777
+ dataset_map_fn=mdpv_points_map_fn,
778
+ template_map_fn=dict(
779
+ type=template_map_fn_factory, template=prompt_template),
780
+ max_length=max_length,
781
+ pad_image_to_square=True,
782
+ debug=False,
783
+ repeats=1,
784
+ )
785
+
786
+ mdpv_multi_points_openpsg_dataset = dict(
787
+ type=MDPVPointBriefCaptionDataset,
788
+ data_path=mdpv_multi_points_openpsg_data_path,
789
+ image_folder=mdpv_multi_points_openpsg_image_path,
790
+ tokenizer=tokenizer,
791
+ image_processor=image_processor,
792
+ dataset_map_fn=mdpv_points_map_fn,
793
+ template_map_fn=dict(
794
+ type=template_map_fn_factory, template=prompt_template),
795
+ max_length=max_length,
796
+ pad_image_to_square=True,
797
+ debug=False,
798
+ repeats=1,
799
+ )
800
+
801
+ mdpv_multi_points_flicker30k_dataset = dict(
802
+ type=MDPVPointBriefCaptionDataset,
803
+ data_path=mdpv_multi_points_flicker30k_data_path,
804
+ image_folder=mdpv_multi_points_flicker30k_image_path,
805
+ tokenizer=tokenizer,
806
+ image_processor=image_processor,
807
+ dataset_map_fn=mdpv_points_map_fn,
808
+ template_map_fn=dict(
809
+ type=template_map_fn_factory, template=prompt_template),
810
+ max_length=max_length,
811
+ pad_image_to_square=True,
812
+ debug=False,
813
+ repeats=1,
814
+ )
815
+
816
+ train_dataset = dict(
817
+ type=CombineDataset,
818
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
819
+ glamm_grandf_dataset, glamm_psg_dataset,
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
823
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
824
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
825
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
826
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
827
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
828
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
829
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
830
+ mdpv_detailed_description_ade20k_dataset,
831
+ mdpv_detailed_description_cocostuff_10k_dataset,
832
+ mdpv_detailed_description_cocostuff_164k_dataset,
833
+ mdpv_detailed_description_vg_dataset,
834
+ mdpv_brief_description_lvis_dataset,
835
+ mdpv_brief_description_vg_dataset,
836
+ mdpv_brief_description_ade20k_dataset,
837
+ mdpv_brief_description_cocostuff10k_dataset,
838
+ mdpv_brief_description_cocostuff164k_dataset,
839
+ mdpv_qa_vg_dataset,
840
+ mdpv_qa_lvis_dataset,
841
+ mdpv_qa_ade20k_dataset,
842
+ mdpv_qa_cocostuff10k_dataset,
843
+ mdpv_qa_cocostuff164k_dataset,
844
+ mdpv_multi_points_flicker30k_dataset,
845
+ mdpv_multi_points_openpsg_dataset,],
846
+ )
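+ # Note (assumption): CombineDataset concatenates the configs listed above, so
+ # repeating the semantic-seg and referring-seg entries three times (see the
+ # "repeat 3x" comments) roughly triples their sampling weight.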
847
+
848
+ train_dataloader = dict(
849
+ batch_size=batch_size,
850
+ num_workers=dataloader_num_workers,
851
+ dataset=train_dataset,
852
+ sampler=dict(
853
+ type=LengthGroupedSampler,
854
+ length_property='modality_length',
855
+ per_device_batch_size=batch_size * accumulative_counts),
856
+ collate_fn=dict(type=omg_llava_collate_fn))
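+ # Note (assumption): LengthGroupedSampler buckets samples with similar
+ # 'modality_length' together to reduce padding; per_device_batch_size is scaled
+ # by accumulative_counts so grouping matches one full gradient-accumulation step.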
857
+
858
+ #######################################################################
859
+ # PART 4 Scheduler & Optimizer #
860
+ #######################################################################
861
+ # optimizer
862
+ optim_wrapper = dict(
863
+ type=AmpOptimWrapper,
864
+ optimizer=dict(
865
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
866
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
867
+ accumulative_counts=accumulative_counts,
868
+ loss_scale='dynamic',
869
+ dtype='float16')
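+ # Note (assumption): with batch_size=8 per device, accumulative_counts=2 and the
+ # 8-GPU setup implied by the pretrained checkpoint above, the effective global
+ # batch size is 8 * 2 * 8 = 128 samples per optimizer step.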
870
+
871
+ # learning policy
872
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
873
+ param_scheduler = [
874
+ dict(
875
+ type=LinearLR,
876
+ start_factor=1e-5,
877
+ by_epoch=True,
878
+ begin=0,
879
+ end=warmup_ratio * max_epochs,
880
+ convert_to_iter_based=True),
881
+ dict(
882
+ type=CosineAnnealingLR,
883
+ eta_min=0.0,
884
+ by_epoch=True,
885
+ begin=warmup_ratio * max_epochs,
886
+ end=max_epochs,
887
+ convert_to_iter_based=True)
888
+ ]
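+ # Note: the LR warms up linearly over the first warmup_ratio * max_epochs
+ # (= 0.03 epochs, converted to iterations) and then follows a cosine decay to 0
+ # for the rest of training.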
889
+
890
+ # train, val, test setting
891
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
892
+
893
+ #######################################################################
894
+ # PART 5 Runtime #
895
+ #######################################################################
896
+ # Periodically log sample dialogues during training (optional)
897
+ custom_hooks = [
898
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
899
+ dict(
900
+ type=EvaluateChatHook_withSpecialTokens,
901
+ tokenizer=tokenizer,
902
+ image_processor=image_processor,
903
+ every_n_iters=evaluation_freq,
904
+ evaluation_inputs=evaluation_inputs,
905
+ evaluation_images=evaluation_images,
906
+ system=SYSTEM,
907
+ prompt_template=prompt_template)
908
+ ]
909
+
910
+ # configure default hooks
911
+ default_hooks = dict(
912
+ # record the time of every iteration.
913
+ timer=dict(type=IterTimerHook),
914
+ # print log every 10 iterations.
915
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
916
+ # enable the parameter scheduler.
917
+ param_scheduler=dict(type=ParamSchedulerHook),
918
+ # save a checkpoint every `save_steps` iterations.
919
+ checkpoint=dict(
920
+ type=CheckpointHook,
921
+ by_epoch=False,
922
+ interval=save_steps,
923
+ max_keep_ckpts=save_total_limit),
924
+ # set the sampler seed in the distributed environment.
925
+ sampler_seed=dict(type=DistSamplerSeedHook),
926
+ )
927
+
928
+ # configure environment
929
+ env_cfg = dict(
930
+ # whether to enable cudnn benchmark
931
+ cudnn_benchmark=False,
932
+ # set multi process parameters
933
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
934
+ # set distributed parameters
935
+ dist_cfg=dict(backend='nccl'),
936
+ )
937
+
938
+ # set visualizer
939
+ visualizer = None
940
+
941
+ # set log level
942
+ log_level = 'INFO'
943
+
944
+ # load from which checkpoint
945
+ load_from = None
946
+
947
+ # whether to resume training from the loaded checkpoint
948
+ resume = False
949
+
950
+ # Default to a random seed and leave `deterministic` disabled
951
+ randomness = dict(seed=None, deterministic=False)
952
+
953
+ # set log processor
954
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_cat_debug.py ADDED
@@ -0,0 +1,927 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
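+ # Note (assumption): (1024 / 64)**2 = 256 slots are reserved for visual tokens of
+ # a 1024px image and ~100 more for special tokens, leaving
+ # max_length = 2048 - 256 - 100 = 1692 tokens for text.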
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate generation performance periodically during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # masks whose score is below 0.5 are filtered out.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ using_multilayer_states=True,
350
+ seg_token_merge_type='cat',
351
+ selected_layers=32,
352
+ llm=dict(
353
+ type=AutoModelForCausalLM.from_pretrained,
354
+ pretrained_model_name_or_path=llm_name_or_path,
355
+ trust_remote_code=True,
356
+ torch_dtype=torch.float16,
357
+ quantization_config=dict(
358
+ type=BitsAndBytesConfig,
359
+ load_in_4bit=True,
360
+ load_in_8bit=False,
361
+ llm_int8_threshold=6.0,
362
+ llm_int8_has_fp16_weight=False,
363
+ bnb_4bit_compute_dtype=torch.float16,
364
+ bnb_4bit_use_double_quant=True,
365
+ bnb_4bit_quant_type='nf4')),
366
+ llm_lora=dict(
367
+ type=LoraConfig,
368
+ r=512,
369
+ lora_alpha=256,
370
+ lora_dropout=0.05,
371
+ bias='none',
372
+ task_type='CAUSAL_LM'),
373
+ visual_encoder=omgseg_model,
374
+ tokenizer=tokenizer,
375
+ )
376
+
377
+ #######################################################################
378
+ # PART 3 Dataset & Dataloader #
379
+ #######################################################################
380
+ debug=False
381
+ llava_dataset = dict(
382
+ type=LLaVADataset,
383
+ data_path=data_path,
384
+ image_folder=image_folder,
385
+ tokenizer=tokenizer,
386
+ image_processor=image_processor,
387
+ dataset_map_fn=llava_map_fn,
388
+ template_map_fn=dict(
389
+ type=template_map_fn_factory, template=prompt_template),
390
+ max_length=max_length,
391
+ pad_image_to_square=True)
392
+
393
+ glamm_refcocog_dataset = dict(
394
+ type=RefCOCOgGCGDataset,
395
+ data_path=refcocog_ann_file,
396
+ image_folder=refcocog_image_path,
397
+ tokenizer=tokenizer,
398
+ image_processor=image_processor,
399
+ dataset_map_fn=glamm_refcocog_map_fn,
400
+ template_map_fn=dict(
401
+ type=template_map_fn_factory, template=prompt_template),
402
+ max_length=max_length,
403
+ pad_image_to_square=True,
404
+ debug=False,
405
+ repeats=1,
406
+ )
407
+
408
+ glamm_grandf_dataset = dict(
409
+ type=GranDfGCGDataset,
410
+ data_path=grandf_ann_file,
411
+ image_folder=grandf_image_path,
412
+ tokenizer=tokenizer,
413
+ image_processor=image_processor,
414
+ dataset_map_fn=glamm_granf_map_fn,
415
+ template_map_fn=dict(
416
+ type=template_map_fn_factory, template=prompt_template),
417
+ max_length=max_length,
418
+ pad_image_to_square=True,
419
+ debug=debug,
420
+ repeats=10,
421
+ )
422
+
423
+ glamm_psg_dataset = dict(
424
+ type=OpenPsgGCGDataset,
425
+ data_path=psg_ann_file,
426
+ image_folder=psg_image_path,
427
+ tokenizer=tokenizer,
428
+ image_processor=image_processor,
429
+ dataset_map_fn=glamm_openpsg_map_fn,
430
+ template_map_fn=dict(
431
+ type=template_map_fn_factory, template=prompt_template),
432
+ max_length=max_length,
433
+ pad_image_to_square=True,
434
+ debug=debug,
435
+ repeats=1,
436
+ )
437
+
438
+ glamm_flickr_dataset = dict(
439
+ type=FlickrGCGDataset,
440
+ data_path=flickr_ann_file,
441
+ image_folder=flickr_image_path,
442
+ tokenizer=tokenizer,
443
+ image_processor=image_processor,
444
+ dataset_map_fn=glamm_flickr_map_fn,
445
+ template_map_fn=dict(
446
+ type=template_map_fn_factory, template=prompt_template),
447
+ max_length=max_length,
448
+ pad_image_to_square=True,
449
+ debug=debug,
450
+ repeats=1,
451
+ )
452
+
453
+ semantic_seg_ade20k_dataset = dict(
454
+ type=ADE20kSemanticSegDataset,
455
+ data_path=ade20k_class_file,
456
+ image_folder=ade20k_image_path,
457
+ tokenizer=tokenizer,
458
+ image_processor=image_processor,
459
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
460
+ template_map_fn=dict(
461
+ type=template_map_fn_factory, template=prompt_template),
462
+ max_length=max_length,
463
+ pad_image_to_square=True,
464
+ debug=False,
465
+ repeats=1,
466
+ gcg_format=True,
467
+ )
468
+
469
+ semantic_seg_cocostuff_dataset = dict(
470
+ type=COCOStuffSemanticSegDataset,
471
+ data_path=cocostuff_class_file,
472
+ image_folder=cocostuff_image_path,
473
+ label_path=cocostuff_label_path,
474
+ tokenizer=tokenizer,
475
+ image_processor=image_processor,
476
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
477
+ template_map_fn=dict(
478
+ type=template_map_fn_factory, template=prompt_template),
479
+ max_length=max_length,
480
+ pad_image_to_square=True,
481
+ debug=False,
482
+ repeats=1,
483
+ gcg_format=True,
484
+ )
485
+
486
+ referring_seg_refcoco_dataset = dict(
487
+ type=RefcocoReferringSegDataset,
488
+ data_path=referring_refcoco_data_path,
489
+ image_folder=referring_refcoco_image_path,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
493
+ template_map_fn=dict(
494
+ type=template_map_fn_factory, template=prompt_template),
495
+ max_length=max_length,
496
+ pad_image_to_square=True,
497
+ debug=False,
498
+ repeats=1,
499
+ )
500
+
501
+ referring_seg_refcoco_plus_dataset = dict(
502
+ type=Refcoco_plus_ReferringSegDataset,
503
+ data_path=referring_refcoco_plus_data_path,
504
+ image_folder=referring_refcoco_plus_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ )
515
+
516
+ referring_seg_refcocog_dataset = dict(
517
+ type=Refcocog_ReferringSegDataset,
518
+ data_path=referring_refcocog_data_path,
519
+ image_folder=referring_refcocog_image_path,
520
+ tokenizer=tokenizer,
521
+ image_processor=image_processor,
522
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
523
+ template_map_fn=dict(
524
+ type=template_map_fn_factory, template=prompt_template),
525
+ max_length=max_length,
526
+ pad_image_to_square=True,
527
+ debug=False,
528
+ repeats=1,
529
+ )
530
+
531
+ referring_seg_refclef_dataset = dict(
532
+ type=Refclef_ReferringSegDataset,
533
+ data_path=referring_refclef_data_path,
534
+ image_folder=referring_refclef_image_path,
535
+ tokenizer=tokenizer,
536
+ image_processor=image_processor,
537
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
538
+ template_map_fn=dict(
539
+ type=template_map_fn_factory, template=prompt_template),
540
+ max_length=max_length,
541
+ pad_image_to_square=True,
542
+ debug=False,
543
+ repeats=1,
544
+ )
545
+
546
+ region_cap_osprey_dataset = dict(
547
+ type=OspreyRegionCaptionDataset,
548
+ data_path=region_cap_osprey_data_path,
549
+ image_folder=region_cap_osprey_image_path,
550
+ tokenizer=tokenizer,
551
+ image_processor=image_processor,
552
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
553
+ template_map_fn=dict(
554
+ type=template_map_fn_factory, template=prompt_template),
555
+ max_length=max_length,
556
+ pad_image_to_square=True,
557
+ debug=False,
558
+ repeats=1,
559
+ )
560
+
561
+ region_conversation_osprey_dataset = dict(
562
+ type=OspreyRegionConversationDataset,
563
+ data_path=region_conversation_osprey_data_path,
564
+ image_folder=region_conversation_osprey_image_path,
565
+ tokenizer=tokenizer,
566
+ image_processor=image_processor,
567
+ dataset_map_fn=osprey_region_conversation_map_fn,
568
+ template_map_fn=dict(
569
+ type=template_map_fn_factory, template=prompt_template),
570
+ max_length=max_length,
571
+ pad_image_to_square=True,
572
+ debug=False,
573
+ repeats=1,
574
+ )
575
+
576
+ mdpv_detailed_description_ade20k_dataset = dict(
577
+ type=MDPVPointDetailedCaptionDataset,
578
+ data_path=mdpv_detailed_caption_ade20k_data_path,
579
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
580
+ tokenizer=tokenizer,
581
+ image_processor=image_processor,
582
+ dataset_map_fn=mdpv_points_map_fn,
583
+ template_map_fn=dict(
584
+ type=template_map_fn_factory, template=prompt_template),
585
+ max_length=max_length,
586
+ pad_image_to_square=True,
587
+ debug=False,
588
+ repeats=1,
589
+ )
590
+
591
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
592
+ type=MDPVPointDetailedCaptionDataset,
593
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
594
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
595
+ tokenizer=tokenizer,
596
+ image_processor=image_processor,
597
+ dataset_map_fn=mdpv_points_map_fn,
598
+ template_map_fn=dict(
599
+ type=template_map_fn_factory, template=prompt_template),
600
+ max_length=max_length,
601
+ pad_image_to_square=True,
602
+ debug=False,
603
+ repeats=1,
604
+ )
605
+
606
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
607
+ type=MDPVPointDetailedCaptionDataset,
608
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
609
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
610
+ tokenizer=tokenizer,
611
+ image_processor=image_processor,
612
+ dataset_map_fn=mdpv_points_map_fn,
613
+ template_map_fn=dict(
614
+ type=template_map_fn_factory, template=prompt_template),
615
+ max_length=max_length,
616
+ pad_image_to_square=True,
617
+ debug=False,
618
+ repeats=1,
619
+ )
620
+
621
+ mdpv_detailed_description_vg_dataset = dict(
622
+ type=MDPVPointDetailedCaptionDataset,
623
+ data_path=mdpv_detailed_caption_vg_data_path,
624
+ image_folder=mdpv_detailed_caption_vg_image_path,
625
+ tokenizer=tokenizer,
626
+ image_processor=image_processor,
627
+ dataset_map_fn=mdpv_points_map_fn,
628
+ template_map_fn=dict(
629
+ type=template_map_fn_factory, template=prompt_template),
630
+ max_length=max_length,
631
+ pad_image_to_square=True,
632
+ debug=False,
633
+ repeats=1,
634
+ )
635
+
636
+ mdpv_brief_description_vg_dataset = dict(
637
+ type=MDPVPointBriefCaptionDataset,
638
+ data_path=mdpv_brief_caption_vg_data_path,
639
+ image_folder=mdpv_brief_caption_vg_image_path,
640
+ tokenizer=tokenizer,
641
+ image_processor=image_processor,
642
+ dataset_map_fn=mdpv_points_map_fn,
643
+ template_map_fn=dict(
644
+ type=template_map_fn_factory, template=prompt_template),
645
+ max_length=max_length,
646
+ pad_image_to_square=True,
647
+ debug=False,
648
+ repeats=1,
649
+ )
650
+
651
+ mdpv_brief_description_cocostuff10k_dataset = dict(
652
+ type=MDPVPointBriefCaptionDataset,
653
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
654
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
655
+ tokenizer=tokenizer,
656
+ image_processor=image_processor,
657
+ dataset_map_fn=mdpv_points_map_fn,
658
+ template_map_fn=dict(
659
+ type=template_map_fn_factory, template=prompt_template),
660
+ max_length=max_length,
661
+ pad_image_to_square=True,
662
+ debug=False,
663
+ repeats=1,
664
+ )
665
+
666
+ mdpv_brief_description_cocostuff164k_dataset = dict(
667
+ type=MDPVPointBriefCaptionDataset,
668
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
669
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
670
+ tokenizer=tokenizer,
671
+ image_processor=image_processor,
672
+ dataset_map_fn=mdpv_points_map_fn,
673
+ template_map_fn=dict(
674
+ type=template_map_fn_factory, template=prompt_template),
675
+ max_length=max_length,
676
+ pad_image_to_square=True,
677
+ debug=False,
678
+ repeats=1,
679
+ )
680
+
681
+ mdpv_brief_description_ade20k_dataset = dict(
682
+ type=MDPVPointBriefCaptionDataset,
683
+ data_path=mdpv_brief_caption_ade20k_data_path,
684
+ image_folder=mdpv_brief_caption_ade20k_image_path,
685
+ tokenizer=tokenizer,
686
+ image_processor=image_processor,
687
+ dataset_map_fn=mdpv_points_map_fn,
688
+ template_map_fn=dict(
689
+ type=template_map_fn_factory, template=prompt_template),
690
+ max_length=max_length,
691
+ pad_image_to_square=True,
692
+ debug=False,
693
+ repeats=1,
694
+ )
695
+
696
+ mdpv_brief_description_lvis_dataset = dict(
697
+ type=MDPVPointBriefCaptionDataset,
698
+ data_path=mdpv_brief_caption_lvis_data_path,
699
+ image_folder=mdpv_brief_caption_lvis_image_path,
700
+ tokenizer=tokenizer,
701
+ image_processor=image_processor,
702
+ dataset_map_fn=mdpv_points_map_fn,
703
+ template_map_fn=dict(
704
+ type=template_map_fn_factory, template=prompt_template),
705
+ max_length=max_length,
706
+ pad_image_to_square=True,
707
+ debug=False,
708
+ repeats=1,
709
+ )
710
+
711
+ mdpv_qa_vg_dataset = dict(
712
+ type=MDPVPointBriefCaptionDataset,
713
+ data_path=mdpv_qa_vg_data_path,
714
+ image_folder=mdpv_qa_vg_image_path,
715
+ tokenizer=tokenizer,
716
+ image_processor=image_processor,
717
+ dataset_map_fn=mdpv_points_map_fn,
718
+ template_map_fn=dict(
719
+ type=template_map_fn_factory, template=prompt_template),
720
+ max_length=max_length,
721
+ pad_image_to_square=True,
722
+ debug=False,
723
+ repeats=1,
724
+ )
725
+
726
+ mdpv_qa_ade20k_dataset = dict(
727
+ type=MDPVPointBriefCaptionDataset,
728
+ data_path=mdpv_qa_ade20k_data_path,
729
+ image_folder=mdpv_qa_ade20k_image_path,
730
+ tokenizer=tokenizer,
731
+ image_processor=image_processor,
732
+ dataset_map_fn=mdpv_points_map_fn,
733
+ template_map_fn=dict(
734
+ type=template_map_fn_factory, template=prompt_template),
735
+ max_length=max_length,
736
+ pad_image_to_square=True,
737
+ debug=False,
738
+ repeats=1,
739
+ )
740
+
741
+ mdpv_qa_lvis_dataset = dict(
742
+ type=MDPVPointBriefCaptionDataset,
743
+ data_path=mdpv_qa_lvis_data_path,
744
+ image_folder=mdpv_qa_lvis_image_path,
745
+ tokenizer=tokenizer,
746
+ image_processor=image_processor,
747
+ dataset_map_fn=mdpv_points_map_fn,
748
+ template_map_fn=dict(
749
+ type=template_map_fn_factory, template=prompt_template),
750
+ max_length=max_length,
751
+ pad_image_to_square=True,
752
+ debug=False,
753
+ repeats=1,
754
+ )
755
+
756
+ mdpv_qa_cocostuff10k_dataset = dict(
757
+ type=MDPVPointBriefCaptionDataset,
758
+ data_path=mdpv_qa_cocostuff10k_data_path,
759
+ image_folder=mdpv_qa_cocostuff10k_image_path,
760
+ tokenizer=tokenizer,
761
+ image_processor=image_processor,
762
+ dataset_map_fn=mdpv_points_map_fn,
763
+ template_map_fn=dict(
764
+ type=template_map_fn_factory, template=prompt_template),
765
+ max_length=max_length,
766
+ pad_image_to_square=True,
767
+ debug=False,
768
+ repeats=1,
769
+ )
770
+
771
+ mdpv_qa_cocostuff164k_dataset = dict(
772
+ type=MDPVPointBriefCaptionDataset,
773
+ data_path=mdpv_qa_cocostuff164k_data_path,
774
+ image_folder=mdpv_qa_cocostuff164k_image_path,
775
+ tokenizer=tokenizer,
776
+ image_processor=image_processor,
777
+ dataset_map_fn=mdpv_points_map_fn,
778
+ template_map_fn=dict(
779
+ type=template_map_fn_factory, template=prompt_template),
780
+ max_length=max_length,
781
+ pad_image_to_square=True,
782
+ debug=False,
783
+ repeats=1,
784
+ )
785
+
786
+ mdpv_multi_points_openpsg_dataset = dict(
787
+ type=MDPVPointBriefCaptionDataset,
788
+ data_path=mdpv_multi_points_openpsg_data_path,
789
+ image_folder=mdpv_multi_points_openpsg_image_path,
790
+ tokenizer=tokenizer,
791
+ image_processor=image_processor,
792
+ dataset_map_fn=mdpv_points_map_fn,
793
+ template_map_fn=dict(
794
+ type=template_map_fn_factory, template=prompt_template),
795
+ max_length=max_length,
796
+ pad_image_to_square=True,
797
+ debug=False,
798
+ repeats=1,
799
+ )
800
+
801
+ mdpv_multi_points_flicker30k_dataset = dict(
802
+ type=MDPVPointBriefCaptionDataset,
803
+ data_path=mdpv_multi_points_flicker30k_data_path,
804
+ image_folder=mdpv_multi_points_flicker30k_image_path,
805
+ tokenizer=tokenizer,
806
+ image_processor=image_processor,
807
+ dataset_map_fn=mdpv_points_map_fn,
808
+ template_map_fn=dict(
809
+ type=template_map_fn_factory, template=prompt_template),
810
+ max_length=max_length,
811
+ pad_image_to_square=True,
812
+ debug=False,
813
+ repeats=1,
814
+ )
815
+
816
+ train_dataset = dict(
817
+ type=CombineDataset,
818
+ datasets_cfgs=[glamm_refcocog_dataset, ],
819
+ )
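+ # Note: unlike the full ablation config, this debug variant trains on the
+ # RefCOCOg GCG dataset only.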
820
+
821
+ train_dataloader = dict(
822
+ batch_size=batch_size,
823
+ num_workers=dataloader_num_workers,
824
+ dataset=train_dataset,
825
+ sampler=dict(
826
+ type=LengthGroupedSampler,
827
+ length_property='modality_length',
828
+ per_device_batch_size=batch_size * accumulative_counts),
829
+ collate_fn=dict(type=omg_llava_collate_fn))
830
+
831
+ #######################################################################
832
+ # PART 4 Scheduler & Optimizer #
833
+ #######################################################################
834
+ # optimizer
835
+ optim_wrapper = dict(
836
+ type=AmpOptimWrapper,
837
+ optimizer=dict(
838
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
839
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
840
+ accumulative_counts=accumulative_counts,
841
+ loss_scale='dynamic',
842
+ dtype='float16')
843
+
844
+ # learning policy
845
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
846
+ param_scheduler = [
847
+ dict(
848
+ type=LinearLR,
849
+ start_factor=1e-5,
850
+ by_epoch=True,
851
+ begin=0,
852
+ end=warmup_ratio * max_epochs,
853
+ convert_to_iter_based=True),
854
+ dict(
855
+ type=CosineAnnealingLR,
856
+ eta_min=0.0,
857
+ by_epoch=True,
858
+ begin=warmup_ratio * max_epochs,
859
+ end=max_epochs,
860
+ convert_to_iter_based=True)
861
+ ]
862
+
863
+ # train, val, test setting
864
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
865
+
866
+ #######################################################################
867
+ # PART 5 Runtime #
868
+ #######################################################################
869
+ # Log the dialogue periodically during the training process, optional
870
+ custom_hooks = [
871
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
872
+ dict(
873
+ type=EvaluateChatHook_withSpecialTokens,
874
+ tokenizer=tokenizer,
875
+ image_processor=image_processor,
876
+ every_n_iters=evaluation_freq,
877
+ evaluation_inputs=evaluation_inputs,
878
+ evaluation_images=evaluation_images,
879
+ system=SYSTEM,
880
+ prompt_template=prompt_template)
881
+ ]
882
+
883
+ # configure default hooks
884
+ default_hooks = dict(
885
+ # record the time of every iteration.
886
+ timer=dict(type=IterTimerHook),
887
+ # print log every 10 iterations.
888
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
889
+ # enable the parameter scheduler.
890
+ param_scheduler=dict(type=ParamSchedulerHook),
891
+ # save checkpoint per `save_steps`.
892
+ checkpoint=dict(
893
+ type=CheckpointHook,
894
+ by_epoch=False,
895
+ interval=save_steps,
896
+ max_keep_ckpts=save_total_limit),
897
+ # set sampler seed in distributed environment.
898
+ sampler_seed=dict(type=DistSamplerSeedHook),
899
+ )
900
+
901
+ # configure environment
902
+ env_cfg = dict(
903
+ # whether to enable cudnn benchmark
904
+ cudnn_benchmark=False,
905
+ # set multi process parameters
906
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
907
+ # set distributed parameters
908
+ dist_cfg=dict(backend='nccl'),
909
+ )
910
+
911
+ # set visualizer
912
+ visualizer = None
913
+
914
+ # set log level
915
+ log_level = 'INFO'
916
+
917
+ # load from which checkpoint
918
+ load_from = None
919
+
920
+ # whether to resume training from the loaded checkpoint
921
+ resume = False
922
+
923
+ # Defaults to use random seed and disable `deterministic`
924
+ randomness = dict(seed=None, deterministic=False)
925
+
926
+ # set log processor
927
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linear_cat.py ADDED
@@ -0,0 +1,954 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
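+ # (1024 / 64)**2 = 256, so max_length = 2048 - 256 - 100 = 1692: this appears to reserve
+ # room in the 2048-token context for the 256 visual tokens of a 1024px image (backbone
+ # stride 32 x pixel-shuffle ratio 2 = effective stride 64) plus a ~100-token margin.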
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
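+ # image_mean / image_std above are the standard OpenAI CLIP normalization statistics
+ # (rounded), and resample=3 is PIL's bicubic filter; images are resized and
+ # center-cropped to 1024x1024 before entering the OMG-Seg visual encoder.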
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
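+ # 80 "thing" + 53 "stuff" categories = 133 classes, matching the COCO panoptic label set.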
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter out masks whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ using_multilayer_states=True,
350
+ seg_token_merge_type='linear_cat',
351
+ selected_layers=32,
352
+ llm=dict(
353
+ type=AutoModelForCausalLM.from_pretrained,
354
+ pretrained_model_name_or_path=llm_name_or_path,
355
+ trust_remote_code=True,
356
+ torch_dtype=torch.float16,
357
+ quantization_config=dict(
358
+ type=BitsAndBytesConfig,
359
+ load_in_4bit=True,
360
+ load_in_8bit=False,
361
+ llm_int8_threshold=6.0,
362
+ llm_int8_has_fp16_weight=False,
363
+ bnb_4bit_compute_dtype=torch.float16,
364
+ bnb_4bit_use_double_quant=True,
365
+ bnb_4bit_quant_type='nf4')),
366
+ llm_lora=dict(
367
+ type=LoraConfig,
368
+ r=512,
369
+ lora_alpha=256,
370
+ lora_dropout=0.05,
371
+ bias='none',
372
+ task_type='CAUSAL_LM'),
373
+ visual_encoder=omgseg_model,
374
+ tokenizer=tokenizer,
375
+ )
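+ # Ablation note: with using_multilayer_states=True and seg_token_merge_type='linear_cat',
+ # the [SEG] token hidden states taken from the selected LLM layers (selected_layers=32)
+ # are concatenated and projected back to the hidden size by a linear layer before they
+ # reach the mask decoder; sibling configs use 'mean' or plain 'cat' instead. A minimal
+ # sketch of the idea, assuming InternLM2-7B's hidden size of 4096 and 32 selected layers:
+ #
+ #   import torch, torch.nn as nn
+ #   merge = nn.Linear(32 * 4096, 4096)
+ #   # per_layer_seg_states: list of 32 tensors, each of shape [4096]
+ #   seg_state = merge(torch.cat(per_layer_seg_states, dim=-1))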
376
+
377
+ #######################################################################
378
+ # PART 3 Dataset & Dataloader #
379
+ #######################################################################
380
+ debug=False
381
+ llava_dataset = dict(
382
+ type=LLaVADataset,
383
+ data_path=data_path,
384
+ image_folder=image_folder,
385
+ tokenizer=tokenizer,
386
+ image_processor=image_processor,
387
+ dataset_map_fn=llava_map_fn,
388
+ template_map_fn=dict(
389
+ type=template_map_fn_factory, template=prompt_template),
390
+ max_length=max_length,
391
+ pad_image_to_square=True)
392
+
393
+ glamm_refcocog_dataset = dict(
394
+ type=RefCOCOgGCGDataset,
395
+ data_path=refcocog_ann_file,
396
+ image_folder=refcocog_image_path,
397
+ tokenizer=tokenizer,
398
+ image_processor=image_processor,
399
+ dataset_map_fn=glamm_refcocog_map_fn,
400
+ template_map_fn=dict(
401
+ type=template_map_fn_factory, template=prompt_template),
402
+ max_length=max_length,
403
+ pad_image_to_square=True,
404
+ debug=False,
405
+ repeats=1,
406
+ )
407
+
408
+ glamm_grandf_dataset = dict(
409
+ type=GranDfGCGDataset,
410
+ data_path=grandf_ann_file,
411
+ image_folder=grandf_image_path,
412
+ tokenizer=tokenizer,
413
+ image_processor=image_processor,
414
+ dataset_map_fn=glamm_granf_map_fn,
415
+ template_map_fn=dict(
416
+ type=template_map_fn_factory, template=prompt_template),
417
+ max_length=max_length,
418
+ pad_image_to_square=True,
419
+ debug=debug,
420
+ repeats=10,
421
+ )
422
+
423
+ glamm_psg_dataset = dict(
424
+ type=OpenPsgGCGDataset,
425
+ data_path=psg_ann_file,
426
+ image_folder=psg_image_path,
427
+ tokenizer=tokenizer,
428
+ image_processor=image_processor,
429
+ dataset_map_fn=glamm_openpsg_map_fn,
430
+ template_map_fn=dict(
431
+ type=template_map_fn_factory, template=prompt_template),
432
+ max_length=max_length,
433
+ pad_image_to_square=True,
434
+ debug=debug,
435
+ repeats=1,
436
+ )
437
+
438
+ glamm_flickr_dataset = dict(
439
+ type=FlickrGCGDataset,
440
+ data_path=flickr_ann_file,
441
+ image_folder=flickr_image_path,
442
+ tokenizer=tokenizer,
443
+ image_processor=image_processor,
444
+ dataset_map_fn=glamm_flickr_map_fn,
445
+ template_map_fn=dict(
446
+ type=template_map_fn_factory, template=prompt_template),
447
+ max_length=max_length,
448
+ pad_image_to_square=True,
449
+ debug=debug,
450
+ repeats=1,
451
+ )
452
+
453
+ semantic_seg_ade20k_dataset = dict(
454
+ type=ADE20kSemanticSegDataset,
455
+ data_path=ade20k_class_file,
456
+ image_folder=ade20k_image_path,
457
+ tokenizer=tokenizer,
458
+ image_processor=image_processor,
459
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
460
+ template_map_fn=dict(
461
+ type=template_map_fn_factory, template=prompt_template),
462
+ max_length=max_length,
463
+ pad_image_to_square=True,
464
+ debug=False,
465
+ repeats=1,
466
+ gcg_format=True,
467
+ )
468
+
469
+ semantic_seg_cocostuff_dataset = dict(
470
+ type=COCOStuffSemanticSegDataset,
471
+ data_path=cocostuff_class_file,
472
+ image_folder=cocostuff_image_path,
473
+ label_path=cocostuff_label_path,
474
+ tokenizer=tokenizer,
475
+ image_processor=image_processor,
476
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
477
+ template_map_fn=dict(
478
+ type=template_map_fn_factory, template=prompt_template),
479
+ max_length=max_length,
480
+ pad_image_to_square=True,
481
+ debug=False,
482
+ repeats=1,
483
+ gcg_format=True,
484
+ )
485
+
486
+ referring_seg_refcoco_dataset = dict(
487
+ type=RefcocoReferringSegDataset,
488
+ data_path=referring_refcoco_data_path,
489
+ image_folder=referring_refcoco_image_path,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
493
+ template_map_fn=dict(
494
+ type=template_map_fn_factory, template=prompt_template),
495
+ max_length=max_length,
496
+ pad_image_to_square=True,
497
+ debug=False,
498
+ repeats=1,
499
+ )
500
+
501
+ referring_seg_refcoco_plus_dataset = dict(
502
+ type=Refcoco_plus_ReferringSegDataset,
503
+ data_path=referring_refcoco_plus_data_path,
504
+ image_folder=referring_refcoco_plus_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ )
515
+
516
+ referring_seg_refcocog_dataset = dict(
517
+ type=Refcocog_ReferringSegDataset,
518
+ data_path=referring_refcocog_data_path,
519
+ image_folder=referring_refcocog_image_path,
520
+ tokenizer=tokenizer,
521
+ image_processor=image_processor,
522
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
523
+ template_map_fn=dict(
524
+ type=template_map_fn_factory, template=prompt_template),
525
+ max_length=max_length,
526
+ pad_image_to_square=True,
527
+ debug=False,
528
+ repeats=1,
529
+ )
530
+
531
+ referring_seg_refclef_dataset = dict(
532
+ type=Refclef_ReferringSegDataset,
533
+ data_path=referring_refclef_data_path,
534
+ image_folder=referring_refclef_image_path,
535
+ tokenizer=tokenizer,
536
+ image_processor=image_processor,
537
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
538
+ template_map_fn=dict(
539
+ type=template_map_fn_factory, template=prompt_template),
540
+ max_length=max_length,
541
+ pad_image_to_square=True,
542
+ debug=False,
543
+ repeats=1,
544
+ )
545
+
546
+ region_cap_osprey_dataset = dict(
547
+ type=OspreyRegionCaptionDataset,
548
+ data_path=region_cap_osprey_data_path,
549
+ image_folder=region_cap_osprey_image_path,
550
+ tokenizer=tokenizer,
551
+ image_processor=image_processor,
552
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
553
+ template_map_fn=dict(
554
+ type=template_map_fn_factory, template=prompt_template),
555
+ max_length=max_length,
556
+ pad_image_to_square=True,
557
+ debug=False,
558
+ repeats=1,
559
+ )
560
+
561
+ region_conversation_osprey_dataset = dict(
562
+ type=OspreyRegionConversationDataset,
563
+ data_path=region_conversation_osprey_data_path,
564
+ image_folder=region_conversation_osprey_image_path,
565
+ tokenizer=tokenizer,
566
+ image_processor=image_processor,
567
+ dataset_map_fn=osprey_region_conversation_map_fn,
568
+ template_map_fn=dict(
569
+ type=template_map_fn_factory, template=prompt_template),
570
+ max_length=max_length,
571
+ pad_image_to_square=True,
572
+ debug=False,
573
+ repeats=1,
574
+ )
575
+
576
+ mdpv_detailed_description_ade20k_dataset = dict(
577
+ type=MDPVPointDetailedCaptionDataset,
578
+ data_path=mdpv_detailed_caption_ade20k_data_path,
579
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
580
+ tokenizer=tokenizer,
581
+ image_processor=image_processor,
582
+ dataset_map_fn=mdpv_points_map_fn,
583
+ template_map_fn=dict(
584
+ type=template_map_fn_factory, template=prompt_template),
585
+ max_length=max_length,
586
+ pad_image_to_square=True,
587
+ debug=False,
588
+ repeats=1,
589
+ )
590
+
591
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
592
+ type=MDPVPointDetailedCaptionDataset,
593
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
594
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
595
+ tokenizer=tokenizer,
596
+ image_processor=image_processor,
597
+ dataset_map_fn=mdpv_points_map_fn,
598
+ template_map_fn=dict(
599
+ type=template_map_fn_factory, template=prompt_template),
600
+ max_length=max_length,
601
+ pad_image_to_square=True,
602
+ debug=False,
603
+ repeats=1,
604
+ )
605
+
606
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
607
+ type=MDPVPointDetailedCaptionDataset,
608
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
609
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
610
+ tokenizer=tokenizer,
611
+ image_processor=image_processor,
612
+ dataset_map_fn=mdpv_points_map_fn,
613
+ template_map_fn=dict(
614
+ type=template_map_fn_factory, template=prompt_template),
615
+ max_length=max_length,
616
+ pad_image_to_square=True,
617
+ debug=False,
618
+ repeats=1,
619
+ )
620
+
621
+ mdpv_detailed_description_vg_dataset = dict(
622
+ type=MDPVPointDetailedCaptionDataset,
623
+ data_path=mdpv_detailed_caption_vg_data_path,
624
+ image_folder=mdpv_detailed_caption_vg_image_path,
625
+ tokenizer=tokenizer,
626
+ image_processor=image_processor,
627
+ dataset_map_fn=mdpv_points_map_fn,
628
+ template_map_fn=dict(
629
+ type=template_map_fn_factory, template=prompt_template),
630
+ max_length=max_length,
631
+ pad_image_to_square=True,
632
+ debug=False,
633
+ repeats=1,
634
+ )
635
+
636
+ mdpv_brief_description_vg_dataset = dict(
637
+ type=MDPVPointBriefCaptionDataset,
638
+ data_path=mdpv_brief_caption_vg_data_path,
639
+ image_folder=mdpv_brief_caption_vg_image_path,
640
+ tokenizer=tokenizer,
641
+ image_processor=image_processor,
642
+ dataset_map_fn=mdpv_points_map_fn,
643
+ template_map_fn=dict(
644
+ type=template_map_fn_factory, template=prompt_template),
645
+ max_length=max_length,
646
+ pad_image_to_square=True,
647
+ debug=False,
648
+ repeats=1,
649
+ )
650
+
651
+ mdpv_brief_description_cocostuff10k_dataset = dict(
652
+ type=MDPVPointBriefCaptionDataset,
653
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
654
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
655
+ tokenizer=tokenizer,
656
+ image_processor=image_processor,
657
+ dataset_map_fn=mdpv_points_map_fn,
658
+ template_map_fn=dict(
659
+ type=template_map_fn_factory, template=prompt_template),
660
+ max_length=max_length,
661
+ pad_image_to_square=True,
662
+ debug=False,
663
+ repeats=1,
664
+ )
665
+
666
+ mdpv_brief_description_cocostuff164k_dataset = dict(
667
+ type=MDPVPointBriefCaptionDataset,
668
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
669
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
670
+ tokenizer=tokenizer,
671
+ image_processor=image_processor,
672
+ dataset_map_fn=mdpv_points_map_fn,
673
+ template_map_fn=dict(
674
+ type=template_map_fn_factory, template=prompt_template),
675
+ max_length=max_length,
676
+ pad_image_to_square=True,
677
+ debug=False,
678
+ repeats=1,
679
+ )
680
+
681
+ mdpv_brief_description_ade20k_dataset = dict(
682
+ type=MDPVPointBriefCaptionDataset,
683
+ data_path=mdpv_brief_caption_ade20k_data_path,
684
+ image_folder=mdpv_brief_caption_ade20k_image_path,
685
+ tokenizer=tokenizer,
686
+ image_processor=image_processor,
687
+ dataset_map_fn=mdpv_points_map_fn,
688
+ template_map_fn=dict(
689
+ type=template_map_fn_factory, template=prompt_template),
690
+ max_length=max_length,
691
+ pad_image_to_square=True,
692
+ debug=False,
693
+ repeats=1,
694
+ )
695
+
696
+ mdpv_brief_description_lvis_dataset = dict(
697
+ type=MDPVPointBriefCaptionDataset,
698
+ data_path=mdpv_brief_caption_lvis_data_path,
699
+ image_folder=mdpv_brief_caption_lvis_image_path,
700
+ tokenizer=tokenizer,
701
+ image_processor=image_processor,
702
+ dataset_map_fn=mdpv_points_map_fn,
703
+ template_map_fn=dict(
704
+ type=template_map_fn_factory, template=prompt_template),
705
+ max_length=max_length,
706
+ pad_image_to_square=True,
707
+ debug=False,
708
+ repeats=1,
709
+ )
710
+
711
+ mdpv_qa_vg_dataset = dict(
712
+ type=MDPVPointBriefCaptionDataset,
713
+ data_path=mdpv_qa_vg_data_path,
714
+ image_folder=mdpv_qa_vg_image_path,
715
+ tokenizer=tokenizer,
716
+ image_processor=image_processor,
717
+ dataset_map_fn=mdpv_points_map_fn,
718
+ template_map_fn=dict(
719
+ type=template_map_fn_factory, template=prompt_template),
720
+ max_length=max_length,
721
+ pad_image_to_square=True,
722
+ debug=False,
723
+ repeats=1,
724
+ )
725
+
726
+ mdpv_qa_ade20k_dataset = dict(
727
+ type=MDPVPointBriefCaptionDataset,
728
+ data_path=mdpv_qa_ade20k_data_path,
729
+ image_folder=mdpv_qa_ade20k_image_path,
730
+ tokenizer=tokenizer,
731
+ image_processor=image_processor,
732
+ dataset_map_fn=mdpv_points_map_fn,
733
+ template_map_fn=dict(
734
+ type=template_map_fn_factory, template=prompt_template),
735
+ max_length=max_length,
736
+ pad_image_to_square=True,
737
+ debug=False,
738
+ repeats=1,
739
+ )
740
+
741
+ mdpv_qa_lvis_dataset = dict(
742
+ type=MDPVPointBriefCaptionDataset,
743
+ data_path=mdpv_qa_lvis_data_path,
744
+ image_folder=mdpv_qa_lvis_image_path,
745
+ tokenizer=tokenizer,
746
+ image_processor=image_processor,
747
+ dataset_map_fn=mdpv_points_map_fn,
748
+ template_map_fn=dict(
749
+ type=template_map_fn_factory, template=prompt_template),
750
+ max_length=max_length,
751
+ pad_image_to_square=True,
752
+ debug=False,
753
+ repeats=1,
754
+ )
755
+
756
+ mdpv_qa_cocostuff10k_dataset = dict(
757
+ type=MDPVPointBriefCaptionDataset,
758
+ data_path=mdpv_qa_cocostuff10k_data_path,
759
+ image_folder=mdpv_qa_cocostuff10k_image_path,
760
+ tokenizer=tokenizer,
761
+ image_processor=image_processor,
762
+ dataset_map_fn=mdpv_points_map_fn,
763
+ template_map_fn=dict(
764
+ type=template_map_fn_factory, template=prompt_template),
765
+ max_length=max_length,
766
+ pad_image_to_square=True,
767
+ debug=False,
768
+ repeats=1,
769
+ )
770
+
771
+ mdpv_qa_cocostuff164k_dataset = dict(
772
+ type=MDPVPointBriefCaptionDataset,
773
+ data_path=mdpv_qa_cocostuff164k_data_path,
774
+ image_folder=mdpv_qa_cocostuff164k_image_path,
775
+ tokenizer=tokenizer,
776
+ image_processor=image_processor,
777
+ dataset_map_fn=mdpv_points_map_fn,
778
+ template_map_fn=dict(
779
+ type=template_map_fn_factory, template=prompt_template),
780
+ max_length=max_length,
781
+ pad_image_to_square=True,
782
+ debug=False,
783
+ repeats=1,
784
+ )
785
+
786
+ mdpv_multi_points_openpsg_dataset = dict(
787
+ type=MDPVPointBriefCaptionDataset,
788
+ data_path=mdpv_multi_points_openpsg_data_path,
789
+ image_folder=mdpv_multi_points_openpsg_image_path,
790
+ tokenizer=tokenizer,
791
+ image_processor=image_processor,
792
+ dataset_map_fn=mdpv_points_map_fn,
793
+ template_map_fn=dict(
794
+ type=template_map_fn_factory, template=prompt_template),
795
+ max_length=max_length,
796
+ pad_image_to_square=True,
797
+ debug=False,
798
+ repeats=1,
799
+ )
800
+
801
+ mdpv_multi_points_flicker30k_dataset = dict(
802
+ type=MDPVPointBriefCaptionDataset,
803
+ data_path=mdpv_multi_points_flicker30k_data_path,
804
+ image_folder=mdpv_multi_points_flicker30k_image_path,
805
+ tokenizer=tokenizer,
806
+ image_processor=image_processor,
807
+ dataset_map_fn=mdpv_points_map_fn,
808
+ template_map_fn=dict(
809
+ type=template_map_fn_factory, template=prompt_template),
810
+ max_length=max_length,
811
+ pad_image_to_square=True,
812
+ debug=False,
813
+ repeats=1,
814
+ )
815
+
816
+ train_dataset = dict(
817
+ type=CombineDataset,
818
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
819
+ glamm_grandf_dataset, glamm_psg_dataset,
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
823
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
824
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
825
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
826
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
827
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
828
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
829
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
830
+ mdpv_detailed_description_ade20k_dataset,
831
+ mdpv_detailed_description_cocostuff_10k_dataset,
832
+ mdpv_detailed_description_cocostuff_164k_dataset,
833
+ mdpv_detailed_description_vg_dataset,
834
+ mdpv_brief_description_lvis_dataset,
835
+ mdpv_brief_description_vg_dataset,
836
+ mdpv_brief_description_ade20k_dataset,
837
+ mdpv_brief_description_cocostuff10k_dataset,
838
+ mdpv_brief_description_cocostuff164k_dataset,
839
+ mdpv_qa_vg_dataset,
840
+ mdpv_qa_lvis_dataset,
841
+ mdpv_qa_ade20k_dataset,
842
+ mdpv_qa_cocostuff10k_dataset,
843
+ mdpv_qa_cocostuff164k_dataset,
844
+ mdpv_multi_points_flicker30k_dataset,
845
+ mdpv_multi_points_openpsg_dataset,],
846
+ )
847
+
848
+ train_dataloader = dict(
849
+ batch_size=batch_size,
850
+ num_workers=dataloader_num_workers,
851
+ dataset=train_dataset,
852
+ sampler=dict(
853
+ type=LengthGroupedSampler,
854
+ length_property='modality_length',
855
+ per_device_batch_size=batch_size * accumulative_counts),
856
+ collate_fn=dict(type=omg_llava_collate_fn))
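+ # LengthGroupedSampler batches samples with similar 'modality_length' together to cut
+ # padding waste; per_device_batch_size is set to batch_size * accumulative_counts = 16
+ # so that length grouping also accounts for gradient accumulation.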
857
+
858
+ #######################################################################
859
+ # PART 4 Scheduler & Optimizer #
860
+ #######################################################################
861
+ # optimizer
862
+ optim_wrapper = dict(
863
+ type=AmpOptimWrapper,
864
+ optimizer=dict(
865
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
866
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
867
+ accumulative_counts=accumulative_counts,
868
+ loss_scale='dynamic',
869
+ dtype='float16')
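+ # With batch_size = 8 per device, accumulative_counts = 2 and the 8 GPUs implied by the
+ # pretrained checkpoint name, the effective optimizer batch size is 8 * 2 * 8 = 128;
+ # AmpOptimWrapper runs the updates in float16 with dynamic loss scaling.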
870
+
871
+ # learning policy
872
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
873
+ param_scheduler = [
874
+ dict(
875
+ type=LinearLR,
876
+ start_factor=1e-5,
877
+ by_epoch=True,
878
+ begin=0,
879
+ end=warmup_ratio * max_epochs,
880
+ convert_to_iter_based=True),
881
+ dict(
882
+ type=CosineAnnealingLR,
883
+ eta_min=0.0,
884
+ by_epoch=True,
885
+ begin=warmup_ratio * max_epochs,
886
+ end=max_epochs,
887
+ convert_to_iter_based=True)
888
+ ]
889
+
890
+ # train, val, test setting
891
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
892
+
893
+ #######################################################################
894
+ # PART 5 Runtime #
895
+ #######################################################################
896
+ # Log the dialogue periodically during the training process, optional
897
+ custom_hooks = [
898
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
899
+ dict(
900
+ type=EvaluateChatHook_withSpecialTokens,
901
+ tokenizer=tokenizer,
902
+ image_processor=image_processor,
903
+ every_n_iters=evaluation_freq,
904
+ evaluation_inputs=evaluation_inputs,
905
+ evaluation_images=evaluation_images,
906
+ system=SYSTEM,
907
+ prompt_template=prompt_template)
908
+ ]
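+ # EvaluateChatHook_withSpecialTokens generates answers for `evaluation_inputs` on
+ # `evaluation_images` every `evaluation_freq` (2000) iterations, so segmentation-aware
+ # generations can be spot-checked during training.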
909
+
910
+ # configure default hooks
911
+ default_hooks = dict(
912
+ # record the time of every iteration.
913
+ timer=dict(type=IterTimerHook),
914
+ # print log every 10 iterations.
915
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
916
+ # enable the parameter scheduler.
917
+ param_scheduler=dict(type=ParamSchedulerHook),
918
+ # save checkpoint per `save_steps`.
919
+ checkpoint=dict(
920
+ type=CheckpointHook,
921
+ by_epoch=False,
922
+ interval=save_steps,
923
+ max_keep_ckpts=save_total_limit),
924
+ # set sampler seed in distributed evrionment.
925
+ sampler_seed=dict(type=DistSamplerSeedHook),
926
+ )
927
+
928
+ # configure environment
929
+ env_cfg = dict(
930
+ # whether to enable cudnn benchmark
931
+ cudnn_benchmark=False,
932
+ # set multi process parameters
933
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
934
+ # set distributed parameters
935
+ dist_cfg=dict(backend='nccl'),
936
+ )
937
+
938
+ # set visualizer
939
+ visualizer = None
940
+
941
+ # set log level
942
+ log_level = 'INFO'
943
+
944
+ # load from which checkpoint
945
+ load_from = None
946
+
947
+ # whether to resume training from the loaded checkpoint
948
+ resume = False
949
+
950
+ # Defaults to use random seed and disable `deterministic`
951
+ randomness = dict(seed=None, deterministic=False)
952
+
953
+ # set log processor
954
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_linearcat_debug.py ADDED
@@ -0,0 +1,927 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter out masks whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ using_multilayer_states=True,
350
+ seg_token_merge_type='linear_cat',
351
+ selected_layers=32,
352
+ llm=dict(
353
+ type=AutoModelForCausalLM.from_pretrained,
354
+ pretrained_model_name_or_path=llm_name_or_path,
355
+ trust_remote_code=True,
356
+ torch_dtype=torch.float16,
357
+ quantization_config=dict(
358
+ type=BitsAndBytesConfig,
359
+ load_in_4bit=True,
360
+ load_in_8bit=False,
361
+ llm_int8_threshold=6.0,
362
+ llm_int8_has_fp16_weight=False,
363
+ bnb_4bit_compute_dtype=torch.float16,
364
+ bnb_4bit_use_double_quant=True,
365
+ bnb_4bit_quant_type='nf4')),
366
+ llm_lora=dict(
367
+ type=LoraConfig,
368
+ r=512,
369
+ lora_alpha=256,
370
+ lora_dropout=0.05,
371
+ bias='none',
372
+ task_type='CAUSAL_LM'),
373
+ visual_encoder=omgseg_model,
374
+ tokenizer=tokenizer,
375
+ )
376
+
377
+ #######################################################################
378
+ # PART 3 Dataset & Dataloader #
379
+ #######################################################################
380
+ debug=False
381
+ llava_dataset = dict(
382
+ type=LLaVADataset,
383
+ data_path=data_path,
384
+ image_folder=image_folder,
385
+ tokenizer=tokenizer,
386
+ image_processor=image_processor,
387
+ dataset_map_fn=llava_map_fn,
388
+ template_map_fn=dict(
389
+ type=template_map_fn_factory, template=prompt_template),
390
+ max_length=max_length,
391
+ pad_image_to_square=True)
392
+
393
+ glamm_refcocog_dataset = dict(
394
+ type=RefCOCOgGCGDataset,
395
+ data_path=refcocog_ann_file,
396
+ image_folder=refcocog_image_path,
397
+ tokenizer=tokenizer,
398
+ image_processor=image_processor,
399
+ dataset_map_fn=glamm_refcocog_map_fn,
400
+ template_map_fn=dict(
401
+ type=template_map_fn_factory, template=prompt_template),
402
+ max_length=max_length,
403
+ pad_image_to_square=True,
404
+ debug=False,
405
+ repeats=1,
406
+ )
407
+
408
+ glamm_grandf_dataset = dict(
409
+ type=GranDfGCGDataset,
410
+ data_path=grandf_ann_file,
411
+ image_folder=grandf_image_path,
412
+ tokenizer=tokenizer,
413
+ image_processor=image_processor,
414
+ dataset_map_fn=glamm_granf_map_fn,
415
+ template_map_fn=dict(
416
+ type=template_map_fn_factory, template=prompt_template),
417
+ max_length=max_length,
418
+ pad_image_to_square=True,
419
+ debug=debug,
420
+ repeats=10,
421
+ )
422
+
423
+ glamm_psg_dataset = dict(
424
+ type=OpenPsgGCGDataset,
425
+ data_path=psg_ann_file,
426
+ image_folder=psg_image_path,
427
+ tokenizer=tokenizer,
428
+ image_processor=image_processor,
429
+ dataset_map_fn=glamm_openpsg_map_fn,
430
+ template_map_fn=dict(
431
+ type=template_map_fn_factory, template=prompt_template),
432
+ max_length=max_length,
433
+ pad_image_to_square=True,
434
+ debug=debug,
435
+ repeats=1,
436
+ )
437
+
438
+ glamm_flickr_dataset = dict(
439
+ type=FlickrGCGDataset,
440
+ data_path=flickr_ann_file,
441
+ image_folder=flickr_image_path,
442
+ tokenizer=tokenizer,
443
+ image_processor=image_processor,
444
+ dataset_map_fn=glamm_flickr_map_fn,
445
+ template_map_fn=dict(
446
+ type=template_map_fn_factory, template=prompt_template),
447
+ max_length=max_length,
448
+ pad_image_to_square=True,
449
+ debug=debug,
450
+ repeats=1,
451
+ )
452
+
453
+ semantic_seg_ade20k_dataset = dict(
454
+ type=ADE20kSemanticSegDataset,
455
+ data_path=ade20k_class_file,
456
+ image_folder=ade20k_image_path,
457
+ tokenizer=tokenizer,
458
+ image_processor=image_processor,
459
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
460
+ template_map_fn=dict(
461
+ type=template_map_fn_factory, template=prompt_template),
462
+ max_length=max_length,
463
+ pad_image_to_square=True,
464
+ debug=False,
465
+ repeats=1,
466
+ gcg_format=True,
467
+ )
468
+
469
+ semantic_seg_cocostuff_dataset = dict(
470
+ type=COCOStuffSemanticSegDataset,
471
+ data_path=cocostuff_class_file,
472
+ image_folder=cocostuff_image_path,
473
+ label_path=cocostuff_label_path,
474
+ tokenizer=tokenizer,
475
+ image_processor=image_processor,
476
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
477
+ template_map_fn=dict(
478
+ type=template_map_fn_factory, template=prompt_template),
479
+ max_length=max_length,
480
+ pad_image_to_square=True,
481
+ debug=False,
482
+ repeats=1,
483
+ gcg_format=True,
484
+ )
485
+
486
+ referring_seg_refcoco_dataset = dict(
487
+ type=RefcocoReferringSegDataset,
488
+ data_path=referring_refcoco_data_path,
489
+ image_folder=referring_refcoco_image_path,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
493
+ template_map_fn=dict(
494
+ type=template_map_fn_factory, template=prompt_template),
495
+ max_length=max_length,
496
+ pad_image_to_square=True,
497
+ debug=False,
498
+ repeats=1,
499
+ )
500
+
501
+ referring_seg_refcoco_plus_dataset = dict(
502
+ type=Refcoco_plus_ReferringSegDataset,
503
+ data_path=referring_refcoco_plus_data_path,
504
+ image_folder=referring_refcoco_plus_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ )
515
+
516
+ referring_seg_refcocog_dataset = dict(
517
+ type=Refcocog_ReferringSegDataset,
518
+ data_path=referring_refcocog_data_path,
519
+ image_folder=referring_refcocog_image_path,
520
+ tokenizer=tokenizer,
521
+ image_processor=image_processor,
522
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
523
+ template_map_fn=dict(
524
+ type=template_map_fn_factory, template=prompt_template),
525
+ max_length=max_length,
526
+ pad_image_to_square=True,
527
+ debug=False,
528
+ repeats=1,
529
+ )
530
+
531
+ referring_seg_refclef_dataset = dict(
532
+ type=Refclef_ReferringSegDataset,
533
+ data_path=referring_refclef_data_path,
534
+ image_folder=referring_refclef_image_path,
535
+ tokenizer=tokenizer,
536
+ image_processor=image_processor,
537
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
538
+ template_map_fn=dict(
539
+ type=template_map_fn_factory, template=prompt_template),
540
+ max_length=max_length,
541
+ pad_image_to_square=True,
542
+ debug=False,
543
+ repeats=1,
544
+ )
545
+
546
+ region_cap_osprey_dataset = dict(
547
+ type=OspreyRegionCaptionDataset,
548
+ data_path=region_cap_osprey_data_path,
549
+ image_folder=region_cap_osprey_image_path,
550
+ tokenizer=tokenizer,
551
+ image_processor=image_processor,
552
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
553
+ template_map_fn=dict(
554
+ type=template_map_fn_factory, template=prompt_template),
555
+ max_length=max_length,
556
+ pad_image_to_square=True,
557
+ debug=False,
558
+ repeats=1,
559
+ )
560
+
561
+ region_conversation_osprey_dataset = dict(
562
+ type=OspreyRegionConversationDataset,
563
+ data_path=region_conversation_osprey_data_path,
564
+ image_folder=region_conversation_osprey_image_path,
565
+ tokenizer=tokenizer,
566
+ image_processor=image_processor,
567
+ dataset_map_fn=osprey_region_conversation_map_fn,
568
+ template_map_fn=dict(
569
+ type=template_map_fn_factory, template=prompt_template),
570
+ max_length=max_length,
571
+ pad_image_to_square=True,
572
+ debug=False,
573
+ repeats=1,
574
+ )
575
+
576
+ mdpv_detailed_description_ade20k_dataset = dict(
577
+ type=MDPVPointDetailedCaptionDataset,
578
+ data_path=mdpv_detailed_caption_ade20k_data_path,
579
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
580
+ tokenizer=tokenizer,
581
+ image_processor=image_processor,
582
+ dataset_map_fn=mdpv_points_map_fn,
583
+ template_map_fn=dict(
584
+ type=template_map_fn_factory, template=prompt_template),
585
+ max_length=max_length,
586
+ pad_image_to_square=True,
587
+ debug=False,
588
+ repeats=1,
589
+ )
590
+
591
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
592
+ type=MDPVPointDetailedCaptionDataset,
593
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
594
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
595
+ tokenizer=tokenizer,
596
+ image_processor=image_processor,
597
+ dataset_map_fn=mdpv_points_map_fn,
598
+ template_map_fn=dict(
599
+ type=template_map_fn_factory, template=prompt_template),
600
+ max_length=max_length,
601
+ pad_image_to_square=True,
602
+ debug=False,
603
+ repeats=1,
604
+ )
605
+
606
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
607
+ type=MDPVPointDetailedCaptionDataset,
608
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
609
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
610
+ tokenizer=tokenizer,
611
+ image_processor=image_processor,
612
+ dataset_map_fn=mdpv_points_map_fn,
613
+ template_map_fn=dict(
614
+ type=template_map_fn_factory, template=prompt_template),
615
+ max_length=max_length,
616
+ pad_image_to_square=True,
617
+ debug=False,
618
+ repeats=1,
619
+ )
620
+
621
+ mdpv_detailed_description_vg_dataset = dict(
622
+ type=MDPVPointDetailedCaptionDataset,
623
+ data_path=mdpv_detailed_caption_vg_data_path,
624
+ image_folder=mdpv_detailed_caption_vg_image_path,
625
+ tokenizer=tokenizer,
626
+ image_processor=image_processor,
627
+ dataset_map_fn=mdpv_points_map_fn,
628
+ template_map_fn=dict(
629
+ type=template_map_fn_factory, template=prompt_template),
630
+ max_length=max_length,
631
+ pad_image_to_square=True,
632
+ debug=False,
633
+ repeats=1,
634
+ )
635
+
636
+ mdpv_brief_description_vg_dataset = dict(
637
+ type=MDPVPointBriefCaptionDataset,
638
+ data_path=mdpv_brief_caption_vg_data_path,
639
+ image_folder=mdpv_brief_caption_vg_image_path,
640
+ tokenizer=tokenizer,
641
+ image_processor=image_processor,
642
+ dataset_map_fn=mdpv_points_map_fn,
643
+ template_map_fn=dict(
644
+ type=template_map_fn_factory, template=prompt_template),
645
+ max_length=max_length,
646
+ pad_image_to_square=True,
647
+ debug=False,
648
+ repeats=1,
649
+ )
650
+
651
+ mdpv_brief_description_cocostuff10k_dataset = dict(
652
+ type=MDPVPointBriefCaptionDataset,
653
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
654
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
655
+ tokenizer=tokenizer,
656
+ image_processor=image_processor,
657
+ dataset_map_fn=mdpv_points_map_fn,
658
+ template_map_fn=dict(
659
+ type=template_map_fn_factory, template=prompt_template),
660
+ max_length=max_length,
661
+ pad_image_to_square=True,
662
+ debug=False,
663
+ repeats=1,
664
+ )
665
+
666
+ mdpv_brief_description_cocostuff164k_dataset = dict(
667
+ type=MDPVPointBriefCaptionDataset,
668
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
669
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
670
+ tokenizer=tokenizer,
671
+ image_processor=image_processor,
672
+ dataset_map_fn=mdpv_points_map_fn,
673
+ template_map_fn=dict(
674
+ type=template_map_fn_factory, template=prompt_template),
675
+ max_length=max_length,
676
+ pad_image_to_square=True,
677
+ debug=False,
678
+ repeats=1,
679
+ )
680
+
681
+ mdpv_brief_description_ade20k_dataset = dict(
682
+ type=MDPVPointBriefCaptionDataset,
683
+ data_path=mdpv_brief_caption_ade20k_data_path,
684
+ image_folder=mdpv_brief_caption_ade20k_image_path,
685
+ tokenizer=tokenizer,
686
+ image_processor=image_processor,
687
+ dataset_map_fn=mdpv_points_map_fn,
688
+ template_map_fn=dict(
689
+ type=template_map_fn_factory, template=prompt_template),
690
+ max_length=max_length,
691
+ pad_image_to_square=True,
692
+ debug=False,
693
+ repeats=1,
694
+ )
695
+
696
+ mdpv_brief_description_lvis_dataset = dict(
697
+ type=MDPVPointBriefCaptionDataset,
698
+ data_path=mdpv_brief_caption_lvis_data_path,
699
+ image_folder=mdpv_brief_caption_lvis_image_path,
700
+ tokenizer=tokenizer,
701
+ image_processor=image_processor,
702
+ dataset_map_fn=mdpv_points_map_fn,
703
+ template_map_fn=dict(
704
+ type=template_map_fn_factory, template=prompt_template),
705
+ max_length=max_length,
706
+ pad_image_to_square=True,
707
+ debug=False,
708
+ repeats=1,
709
+ )
710
+
711
+ mdpv_qa_vg_dataset = dict(
712
+ type=MDPVPointBriefCaptionDataset,
713
+ data_path=mdpv_qa_vg_data_path,
714
+ image_folder=mdpv_qa_vg_image_path,
715
+ tokenizer=tokenizer,
716
+ image_processor=image_processor,
717
+ dataset_map_fn=mdpv_points_map_fn,
718
+ template_map_fn=dict(
719
+ type=template_map_fn_factory, template=prompt_template),
720
+ max_length=max_length,
721
+ pad_image_to_square=True,
722
+ debug=False,
723
+ repeats=1,
724
+ )
725
+
726
+ mdpv_qa_ade20k_dataset = dict(
727
+ type=MDPVPointBriefCaptionDataset,
728
+ data_path=mdpv_qa_ade20k_data_path,
729
+ image_folder=mdpv_qa_ade20k_image_path,
730
+ tokenizer=tokenizer,
731
+ image_processor=image_processor,
732
+ dataset_map_fn=mdpv_points_map_fn,
733
+ template_map_fn=dict(
734
+ type=template_map_fn_factory, template=prompt_template),
735
+ max_length=max_length,
736
+ pad_image_to_square=True,
737
+ debug=False,
738
+ repeats=1,
739
+ )
740
+
741
+ mdpv_qa_lvis_dataset = dict(
742
+ type=MDPVPointBriefCaptionDataset,
743
+ data_path=mdpv_qa_lvis_data_path,
744
+ image_folder=mdpv_qa_lvis_image_path,
745
+ tokenizer=tokenizer,
746
+ image_processor=image_processor,
747
+ dataset_map_fn=mdpv_points_map_fn,
748
+ template_map_fn=dict(
749
+ type=template_map_fn_factory, template=prompt_template),
750
+ max_length=max_length,
751
+ pad_image_to_square=True,
752
+ debug=False,
753
+ repeats=1,
754
+ )
755
+
756
+ mdpv_qa_cocostuff10k_dataset = dict(
757
+ type=MDPVPointBriefCaptionDataset,
758
+ data_path=mdpv_qa_cocostuff10k_data_path,
759
+ image_folder=mdpv_qa_cocostuff10k_image_path,
760
+ tokenizer=tokenizer,
761
+ image_processor=image_processor,
762
+ dataset_map_fn=mdpv_points_map_fn,
763
+ template_map_fn=dict(
764
+ type=template_map_fn_factory, template=prompt_template),
765
+ max_length=max_length,
766
+ pad_image_to_square=True,
767
+ debug=False,
768
+ repeats=1,
769
+ )
770
+
771
+ mdpv_qa_cocostuff164k_dataset = dict(
772
+ type=MDPVPointBriefCaptionDataset,
773
+ data_path=mdpv_qa_cocostuff164k_data_path,
774
+ image_folder=mdpv_qa_cocostuff164k_image_path,
775
+ tokenizer=tokenizer,
776
+ image_processor=image_processor,
777
+ dataset_map_fn=mdpv_points_map_fn,
778
+ template_map_fn=dict(
779
+ type=template_map_fn_factory, template=prompt_template),
780
+ max_length=max_length,
781
+ pad_image_to_square=True,
782
+ debug=False,
783
+ repeats=1,
784
+ )
785
+
786
+ mdpv_multi_points_openpsg_dataset = dict(
787
+ type=MDPVPointBriefCaptionDataset,
788
+ data_path=mdpv_multi_points_openpsg_data_path,
789
+ image_folder=mdpv_multi_points_openpsg_image_path,
790
+ tokenizer=tokenizer,
791
+ image_processor=image_processor,
792
+ dataset_map_fn=mdpv_points_map_fn,
793
+ template_map_fn=dict(
794
+ type=template_map_fn_factory, template=prompt_template),
795
+ max_length=max_length,
796
+ pad_image_to_square=True,
797
+ debug=False,
798
+ repeats=1,
799
+ )
800
+
801
+ mdpv_multi_points_flicker30k_dataset = dict(
802
+ type=MDPVPointBriefCaptionDataset,
803
+ data_path=mdpv_multi_points_flicker30k_data_path,
804
+ image_folder=mdpv_multi_points_flicker30k_image_path,
805
+ tokenizer=tokenizer,
806
+ image_processor=image_processor,
807
+ dataset_map_fn=mdpv_points_map_fn,
808
+ template_map_fn=dict(
809
+ type=template_map_fn_factory, template=prompt_template),
810
+ max_length=max_length,
811
+ pad_image_to_square=True,
812
+ debug=False,
813
+ repeats=1,
814
+ )
815
+
816
+ train_dataset = dict(
817
+ type=CombineDataset,
818
+ datasets_cfgs=[glamm_refcocog_dataset, ],
819
+ )
820
+
821
+ train_dataloader = dict(
822
+ batch_size=batch_size,
823
+ num_workers=dataloader_num_workers,
824
+ dataset=train_dataset,
825
+ sampler=dict(
826
+ type=LengthGroupedSampler,
827
+ length_property='modality_length',
828
+ per_device_batch_size=batch_size * accumulative_counts),
829
+ collate_fn=dict(type=omg_llava_collate_fn))
830
+
831
+ #######################################################################
832
+ # PART 4 Scheduler & Optimizer #
833
+ #######################################################################
834
+ # optimizer
835
+ optim_wrapper = dict(
836
+ type=AmpOptimWrapper,
837
+ optimizer=dict(
838
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
839
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
840
+ accumulative_counts=accumulative_counts,
841
+ loss_scale='dynamic',
842
+ dtype='float16')
843
+
844
+ # learning policy
845
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
846
+ param_scheduler = [
847
+ dict(
848
+ type=LinearLR,
849
+ start_factor=1e-5,
850
+ by_epoch=True,
851
+ begin=0,
852
+ end=warmup_ratio * max_epochs,
853
+ convert_to_iter_based=True),
854
+ dict(
855
+ type=CosineAnnealingLR,
856
+ eta_min=0.0,
857
+ by_epoch=True,
858
+ begin=warmup_ratio * max_epochs,
859
+ end=max_epochs,
860
+ convert_to_iter_based=True)
861
+ ]
862
+
863
+ # train, val, test setting
864
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
865
+
866
+ #######################################################################
867
+ # PART 5 Runtime #
868
+ #######################################################################
869
+ # Log the dialogue periodically during the training process, optional
870
+ custom_hooks = [
871
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
872
+ dict(
873
+ type=EvaluateChatHook_withSpecialTokens,
874
+ tokenizer=tokenizer,
875
+ image_processor=image_processor,
876
+ every_n_iters=evaluation_freq,
877
+ evaluation_inputs=evaluation_inputs,
878
+ evaluation_images=evaluation_images,
879
+ system=SYSTEM,
880
+ prompt_template=prompt_template)
881
+ ]
882
+
883
+ # configure default hooks
884
+ default_hooks = dict(
885
+ # record the time of every iteration.
886
+ timer=dict(type=IterTimerHook),
887
+ # print log every 10 iterations.
888
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
889
+ # enable the parameter scheduler.
890
+ param_scheduler=dict(type=ParamSchedulerHook),
891
+ # save a checkpoint every `save_steps` iterations.
892
+ checkpoint=dict(
893
+ type=CheckpointHook,
894
+ by_epoch=False,
895
+ interval=save_steps,
896
+ max_keep_ckpts=save_total_limit),
897
+ # set the sampler seed in distributed environments.
898
+ sampler_seed=dict(type=DistSamplerSeedHook),
899
+ )
900
+
901
+ # configure environment
902
+ env_cfg = dict(
903
+ # whether to enable cudnn benchmark
904
+ cudnn_benchmark=False,
905
+ # set multi process parameters
906
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
907
+ # set distributed parameters
908
+ dist_cfg=dict(backend='nccl'),
909
+ )
910
+
911
+ # set visualizer
912
+ visualizer = None
913
+
914
+ # set log level
915
+ log_level = 'INFO'
916
+
917
+ # load from which checkpoint
918
+ load_from = None
919
+
920
+ # whether to resume training from the loaded checkpoint
921
+ resume = False
922
+
923
+ # Default to a random seed and leave `deterministic` disabled
924
+ randomness = dict(seed=None, deterministic=False)
925
+
926
+ # set log processor
927
+ log_processor = dict(by_epoch=False)
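The config above enables `using_multilayer_states=True` with `seg_token_merge_type='linear_cat'` and `selected_layers=32`: the [SEG] token's hidden states from the selected LLM layers are concatenated and projected back to a single embedding before being handed to the mask decoder. A minimal sketch of that merge, with class and variable names that are illustrative assumptions rather than the OMG-LLaVA source:

import torch
import torch.nn as nn

class LinearCatSegMerge(nn.Module):
    # Concatenate per-layer [SEG] hidden states and project back to one embedding.
    def __init__(self, hidden_size: int, selected_layers: int):
        super().__init__()
        self.selected_layers = selected_layers
        self.proj = nn.Linear(hidden_size * selected_layers, hidden_size)

    def forward(self, layer_states):
        # layer_states: list of (batch, hidden_size) tensors, one per LLM layer
        cat = torch.cat(layer_states[-self.selected_layers:], dim=-1)
        return self.proj(cat)  # (batch, hidden_size)

# Toy usage with small sizes; the config above corresponds to hidden_size=4096, selected_layers=32.
states = [torch.randn(2, 64) for _ in range(8)]
merged = LinearCatSegMerge(hidden_size=64, selected_layers=8)(states)
print(merged.shape)  # torch.Size([2, 64])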
omg_llava/configs/finetune/ablation_multi_seg_states/ablation_multi_seg_states_mean.py ADDED
@@ -0,0 +1,954 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
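+ # = 2048 - 256 - 100 = 1692 tokens left for text: a 1024-px image at an effective
+ # stride of 64 (32x backbone stride x 2x pixel shuffle) yields (1024/64)**2 = 256
+ # visual tokens, with ~100 tokens kept as headroom (interpretation, not stated in the source).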
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask regions whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ using_multilayer_states=True,
350
+ seg_token_merge_type='mean',
351
+ selected_layers=32,
352
+ llm=dict(
353
+ type=AutoModelForCausalLM.from_pretrained,
354
+ pretrained_model_name_or_path=llm_name_or_path,
355
+ trust_remote_code=True,
356
+ torch_dtype=torch.float16,
357
+ quantization_config=dict(
358
+ type=BitsAndBytesConfig,
359
+ load_in_4bit=True,
360
+ load_in_8bit=False,
361
+ llm_int8_threshold=6.0,
362
+ llm_int8_has_fp16_weight=False,
363
+ bnb_4bit_compute_dtype=torch.float16,
364
+ bnb_4bit_use_double_quant=True,
365
+ bnb_4bit_quant_type='nf4')),
366
+ llm_lora=dict(
367
+ type=LoraConfig,
368
+ r=512,
369
+ lora_alpha=256,
370
+ lora_dropout=0.05,
371
+ bias='none',
372
+ task_type='CAUSAL_LM'),
373
+ visual_encoder=omgseg_model,
374
+ tokenizer=tokenizer,
375
+ )
376
+
377
+ #######################################################################
378
+ # PART 3 Dataset & Dataloader #
379
+ #######################################################################
380
+ debug=False
381
+ llava_dataset = dict(
382
+ type=LLaVADataset,
383
+ data_path=data_path,
384
+ image_folder=image_folder,
385
+ tokenizer=tokenizer,
386
+ image_processor=image_processor,
387
+ dataset_map_fn=llava_map_fn,
388
+ template_map_fn=dict(
389
+ type=template_map_fn_factory, template=prompt_template),
390
+ max_length=max_length,
391
+ pad_image_to_square=True)
392
+
393
+ glamm_refcocog_dataset = dict(
394
+ type=RefCOCOgGCGDataset,
395
+ data_path=refcocog_ann_file,
396
+ image_folder=refcocog_image_path,
397
+ tokenizer=tokenizer,
398
+ image_processor=image_processor,
399
+ dataset_map_fn=glamm_refcocog_map_fn,
400
+ template_map_fn=dict(
401
+ type=template_map_fn_factory, template=prompt_template),
402
+ max_length=max_length,
403
+ pad_image_to_square=True,
404
+ debug=False,
405
+ repeats=1,
406
+ )
407
+
408
+ glamm_grandf_dataset = dict(
409
+ type=GranDfGCGDataset,
410
+ data_path=grandf_ann_file,
411
+ image_folder=grandf_image_path,
412
+ tokenizer=tokenizer,
413
+ image_processor=image_processor,
414
+ dataset_map_fn=glamm_granf_map_fn,
415
+ template_map_fn=dict(
416
+ type=template_map_fn_factory, template=prompt_template),
417
+ max_length=max_length,
418
+ pad_image_to_square=True,
419
+ debug=debug,
420
+ repeats=10,
421
+ )
422
+
423
+ glamm_psg_dataset = dict(
424
+ type=OpenPsgGCGDataset,
425
+ data_path=psg_ann_file,
426
+ image_folder=psg_image_path,
427
+ tokenizer=tokenizer,
428
+ image_processor=image_processor,
429
+ dataset_map_fn=glamm_openpsg_map_fn,
430
+ template_map_fn=dict(
431
+ type=template_map_fn_factory, template=prompt_template),
432
+ max_length=max_length,
433
+ pad_image_to_square=True,
434
+ debug=debug,
435
+ repeats=1,
436
+ )
437
+
438
+ glamm_flickr_dataset = dict(
439
+ type=FlickrGCGDataset,
440
+ data_path=flickr_ann_file,
441
+ image_folder=flickr_image_path,
442
+ tokenizer=tokenizer,
443
+ image_processor=image_processor,
444
+ dataset_map_fn=glamm_flickr_map_fn,
445
+ template_map_fn=dict(
446
+ type=template_map_fn_factory, template=prompt_template),
447
+ max_length=max_length,
448
+ pad_image_to_square=True,
449
+ debug=debug,
450
+ repeats=1,
451
+ )
452
+
453
+ semantic_seg_ade20k_dataset = dict(
454
+ type=ADE20kSemanticSegDataset,
455
+ data_path=ade20k_class_file,
456
+ image_folder=ade20k_image_path,
457
+ tokenizer=tokenizer,
458
+ image_processor=image_processor,
459
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
460
+ template_map_fn=dict(
461
+ type=template_map_fn_factory, template=prompt_template),
462
+ max_length=max_length,
463
+ pad_image_to_square=True,
464
+ debug=False,
465
+ repeats=1,
466
+ gcg_format=True,
467
+ )
468
+
469
+ semantic_seg_cocostuff_dataset = dict(
470
+ type=COCOStuffSemanticSegDataset,
471
+ data_path=cocostuff_class_file,
472
+ image_folder=cocostuff_image_path,
473
+ label_path=cocostuff_label_path,
474
+ tokenizer=tokenizer,
475
+ image_processor=image_processor,
476
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
477
+ template_map_fn=dict(
478
+ type=template_map_fn_factory, template=prompt_template),
479
+ max_length=max_length,
480
+ pad_image_to_square=True,
481
+ debug=False,
482
+ repeats=1,
483
+ gcg_format=True,
484
+ )
485
+
486
+ referring_seg_refcoco_dataset = dict(
487
+ type=RefcocoReferringSegDataset,
488
+ data_path=referring_refcoco_data_path,
489
+ image_folder=referring_refcoco_image_path,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
493
+ template_map_fn=dict(
494
+ type=template_map_fn_factory, template=prompt_template),
495
+ max_length=max_length,
496
+ pad_image_to_square=True,
497
+ debug=False,
498
+ repeats=1,
499
+ )
500
+
501
+ referring_seg_refcoco_plus_dataset = dict(
502
+ type=Refcoco_plus_ReferringSegDataset,
503
+ data_path=referring_refcoco_plus_data_path,
504
+ image_folder=referring_refcoco_plus_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ )
515
+
516
+ referring_seg_refcocog_dataset = dict(
517
+ type=Refcocog_ReferringSegDataset,
518
+ data_path=referring_refcocog_data_path,
519
+ image_folder=referring_refcocog_image_path,
520
+ tokenizer=tokenizer,
521
+ image_processor=image_processor,
522
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
523
+ template_map_fn=dict(
524
+ type=template_map_fn_factory, template=prompt_template),
525
+ max_length=max_length,
526
+ pad_image_to_square=True,
527
+ debug=False,
528
+ repeats=1,
529
+ )
530
+
531
+ referring_seg_refclef_dataset = dict(
532
+ type=Refclef_ReferringSegDataset,
533
+ data_path=referring_refclef_data_path,
534
+ image_folder=referring_refclef_image_path,
535
+ tokenizer=tokenizer,
536
+ image_processor=image_processor,
537
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
538
+ template_map_fn=dict(
539
+ type=template_map_fn_factory, template=prompt_template),
540
+ max_length=max_length,
541
+ pad_image_to_square=True,
542
+ debug=False,
543
+ repeats=1,
544
+ )
545
+
546
+ region_cap_osprey_dataset = dict(
547
+ type=OspreyRegionCaptionDataset,
548
+ data_path=region_cap_osprey_data_path,
549
+ image_folder=region_cap_osprey_image_path,
550
+ tokenizer=tokenizer,
551
+ image_processor=image_processor,
552
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
553
+ template_map_fn=dict(
554
+ type=template_map_fn_factory, template=prompt_template),
555
+ max_length=max_length,
556
+ pad_image_to_square=True,
557
+ debug=False,
558
+ repeats=1,
559
+ )
560
+
561
+ region_conversation_osprey_dataset = dict(
562
+ type=OspreyRegionConversationDataset,
563
+ data_path=region_conversation_osprey_data_path,
564
+ image_folder=region_conversation_osprey_image_path,
565
+ tokenizer=tokenizer,
566
+ image_processor=image_processor,
567
+ dataset_map_fn=osprey_region_conversation_map_fn,
568
+ template_map_fn=dict(
569
+ type=template_map_fn_factory, template=prompt_template),
570
+ max_length=max_length,
571
+ pad_image_to_square=True,
572
+ debug=False,
573
+ repeats=1,
574
+ )
575
+
576
+ mdpv_detailed_description_ade20k_dataset = dict(
577
+ type=MDPVPointDetailedCaptionDataset,
578
+ data_path=mdpv_detailed_caption_ade20k_data_path,
579
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
580
+ tokenizer=tokenizer,
581
+ image_processor=image_processor,
582
+ dataset_map_fn=mdpv_points_map_fn,
583
+ template_map_fn=dict(
584
+ type=template_map_fn_factory, template=prompt_template),
585
+ max_length=max_length,
586
+ pad_image_to_square=True,
587
+ debug=False,
588
+ repeats=1,
589
+ )
590
+
591
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
592
+ type=MDPVPointDetailedCaptionDataset,
593
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
594
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
595
+ tokenizer=tokenizer,
596
+ image_processor=image_processor,
597
+ dataset_map_fn=mdpv_points_map_fn,
598
+ template_map_fn=dict(
599
+ type=template_map_fn_factory, template=prompt_template),
600
+ max_length=max_length,
601
+ pad_image_to_square=True,
602
+ debug=False,
603
+ repeats=1,
604
+ )
605
+
606
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
607
+ type=MDPVPointDetailedCaptionDataset,
608
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
609
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
610
+ tokenizer=tokenizer,
611
+ image_processor=image_processor,
612
+ dataset_map_fn=mdpv_points_map_fn,
613
+ template_map_fn=dict(
614
+ type=template_map_fn_factory, template=prompt_template),
615
+ max_length=max_length,
616
+ pad_image_to_square=True,
617
+ debug=False,
618
+ repeats=1,
619
+ )
620
+
621
+ mdpv_detailed_description_vg_dataset = dict(
622
+ type=MDPVPointDetailedCaptionDataset,
623
+ data_path=mdpv_detailed_caption_vg_data_path,
624
+ image_folder=mdpv_detailed_caption_vg_image_path,
625
+ tokenizer=tokenizer,
626
+ image_processor=image_processor,
627
+ dataset_map_fn=mdpv_points_map_fn,
628
+ template_map_fn=dict(
629
+ type=template_map_fn_factory, template=prompt_template),
630
+ max_length=max_length,
631
+ pad_image_to_square=True,
632
+ debug=False,
633
+ repeats=1,
634
+ )
635
+
636
+ mdpv_brief_description_vg_dataset = dict(
637
+ type=MDPVPointBriefCaptionDataset,
638
+ data_path=mdpv_brief_caption_vg_data_path,
639
+ image_folder=mdpv_brief_caption_vg_image_path,
640
+ tokenizer=tokenizer,
641
+ image_processor=image_processor,
642
+ dataset_map_fn=mdpv_points_map_fn,
643
+ template_map_fn=dict(
644
+ type=template_map_fn_factory, template=prompt_template),
645
+ max_length=max_length,
646
+ pad_image_to_square=True,
647
+ debug=False,
648
+ repeats=1,
649
+ )
650
+
651
+ mdpv_brief_description_cocostuff10k_dataset = dict(
652
+ type=MDPVPointBriefCaptionDataset,
653
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
654
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
655
+ tokenizer=tokenizer,
656
+ image_processor=image_processor,
657
+ dataset_map_fn=mdpv_points_map_fn,
658
+ template_map_fn=dict(
659
+ type=template_map_fn_factory, template=prompt_template),
660
+ max_length=max_length,
661
+ pad_image_to_square=True,
662
+ debug=False,
663
+ repeats=1,
664
+ )
665
+
666
+ mdpv_brief_description_cocostuff164k_dataset = dict(
667
+ type=MDPVPointBriefCaptionDataset,
668
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
669
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
670
+ tokenizer=tokenizer,
671
+ image_processor=image_processor,
672
+ dataset_map_fn=mdpv_points_map_fn,
673
+ template_map_fn=dict(
674
+ type=template_map_fn_factory, template=prompt_template),
675
+ max_length=max_length,
676
+ pad_image_to_square=True,
677
+ debug=False,
678
+ repeats=1,
679
+ )
680
+
681
+ mdpv_brief_description_ade20k_dataset = dict(
682
+ type=MDPVPointBriefCaptionDataset,
683
+ data_path=mdpv_brief_caption_ade20k_data_path,
684
+ image_folder=mdpv_brief_caption_ade20k_image_path,
685
+ tokenizer=tokenizer,
686
+ image_processor=image_processor,
687
+ dataset_map_fn=mdpv_points_map_fn,
688
+ template_map_fn=dict(
689
+ type=template_map_fn_factory, template=prompt_template),
690
+ max_length=max_length,
691
+ pad_image_to_square=True,
692
+ debug=False,
693
+ repeats=1,
694
+ )
695
+
696
+ mdpv_brief_description_lvis_dataset = dict(
697
+ type=MDPVPointBriefCaptionDataset,
698
+ data_path=mdpv_brief_caption_lvis_data_path,
699
+ image_folder=mdpv_brief_caption_lvis_image_path,
700
+ tokenizer=tokenizer,
701
+ image_processor=image_processor,
702
+ dataset_map_fn=mdpv_points_map_fn,
703
+ template_map_fn=dict(
704
+ type=template_map_fn_factory, template=prompt_template),
705
+ max_length=max_length,
706
+ pad_image_to_square=True,
707
+ debug=False,
708
+ repeats=1,
709
+ )
710
+
711
+ mdpv_qa_vg_dataset = dict(
712
+ type=MDPVPointBriefCaptionDataset,
713
+ data_path=mdpv_qa_vg_data_path,
714
+ image_folder=mdpv_qa_vg_image_path,
715
+ tokenizer=tokenizer,
716
+ image_processor=image_processor,
717
+ dataset_map_fn=mdpv_points_map_fn,
718
+ template_map_fn=dict(
719
+ type=template_map_fn_factory, template=prompt_template),
720
+ max_length=max_length,
721
+ pad_image_to_square=True,
722
+ debug=False,
723
+ repeats=1,
724
+ )
725
+
726
+ mdpv_qa_ade20k_dataset = dict(
727
+ type=MDPVPointBriefCaptionDataset,
728
+ data_path=mdpv_qa_ade20k_data_path,
729
+ image_folder=mdpv_qa_ade20k_image_path,
730
+ tokenizer=tokenizer,
731
+ image_processor=image_processor,
732
+ dataset_map_fn=mdpv_points_map_fn,
733
+ template_map_fn=dict(
734
+ type=template_map_fn_factory, template=prompt_template),
735
+ max_length=max_length,
736
+ pad_image_to_square=True,
737
+ debug=False,
738
+ repeats=1,
739
+ )
740
+
741
+ mdpv_qa_lvis_dataset = dict(
742
+ type=MDPVPointBriefCaptionDataset,
743
+ data_path=mdpv_qa_lvis_data_path,
744
+ image_folder=mdpv_qa_lvis_image_path,
745
+ tokenizer=tokenizer,
746
+ image_processor=image_processor,
747
+ dataset_map_fn=mdpv_points_map_fn,
748
+ template_map_fn=dict(
749
+ type=template_map_fn_factory, template=prompt_template),
750
+ max_length=max_length,
751
+ pad_image_to_square=True,
752
+ debug=False,
753
+ repeats=1,
754
+ )
755
+
756
+ mdpv_qa_cocostuff10k_dataset = dict(
757
+ type=MDPVPointBriefCaptionDataset,
758
+ data_path=mdpv_qa_cocostuff10k_data_path,
759
+ image_folder=mdpv_qa_cocostuff10k_image_path,
760
+ tokenizer=tokenizer,
761
+ image_processor=image_processor,
762
+ dataset_map_fn=mdpv_points_map_fn,
763
+ template_map_fn=dict(
764
+ type=template_map_fn_factory, template=prompt_template),
765
+ max_length=max_length,
766
+ pad_image_to_square=True,
767
+ debug=False,
768
+ repeats=1,
769
+ )
770
+
771
+ mdpv_qa_cocostuff164k_dataset = dict(
772
+ type=MDPVPointBriefCaptionDataset,
773
+ data_path=mdpv_qa_cocostuff164k_data_path,
774
+ image_folder=mdpv_qa_cocostuff164k_image_path,
775
+ tokenizer=tokenizer,
776
+ image_processor=image_processor,
777
+ dataset_map_fn=mdpv_points_map_fn,
778
+ template_map_fn=dict(
779
+ type=template_map_fn_factory, template=prompt_template),
780
+ max_length=max_length,
781
+ pad_image_to_square=True,
782
+ debug=False,
783
+ repeats=1,
784
+ )
785
+
786
+ mdpv_multi_points_openpsg_dataset = dict(
787
+ type=MDPVPointBriefCaptionDataset,
788
+ data_path=mdpv_multi_points_openpsg_data_path,
789
+ image_folder=mdpv_multi_points_openpsg_image_path,
790
+ tokenizer=tokenizer,
791
+ image_processor=image_processor,
792
+ dataset_map_fn=mdpv_points_map_fn,
793
+ template_map_fn=dict(
794
+ type=template_map_fn_factory, template=prompt_template),
795
+ max_length=max_length,
796
+ pad_image_to_square=True,
797
+ debug=False,
798
+ repeats=1,
799
+ )
800
+
801
+ mdpv_multi_points_flicker30k_dataset = dict(
802
+ type=MDPVPointBriefCaptionDataset,
803
+ data_path=mdpv_multi_points_flicker30k_data_path,
804
+ image_folder=mdpv_multi_points_flicker30k_image_path,
805
+ tokenizer=tokenizer,
806
+ image_processor=image_processor,
807
+ dataset_map_fn=mdpv_points_map_fn,
808
+ template_map_fn=dict(
809
+ type=template_map_fn_factory, template=prompt_template),
810
+ max_length=max_length,
811
+ pad_image_to_square=True,
812
+ debug=False,
813
+ repeats=1,
814
+ )
815
+
816
+ train_dataset = dict(
817
+ type=CombineDataset,
818
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
819
+ glamm_grandf_dataset, glamm_psg_dataset,
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
823
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
824
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
825
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
826
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
827
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
828
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
829
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
830
+ mdpv_detailed_description_ade20k_dataset,
831
+ mdpv_detailed_description_cocostuff_10k_dataset,
832
+ mdpv_detailed_description_cocostuff_164k_dataset,
833
+ mdpv_detailed_description_vg_dataset,
834
+ mdpv_brief_description_lvis_dataset,
835
+ mdpv_brief_description_vg_dataset,
836
+ mdpv_brief_description_ade20k_dataset,
837
+ mdpv_brief_description_cocostuff10k_dataset,
838
+ mdpv_brief_description_cocostuff164k_dataset,
839
+ mdpv_qa_vg_dataset,
840
+ mdpv_qa_lvis_dataset,
841
+ mdpv_qa_ade20k_dataset,
842
+ mdpv_qa_cocostuff10k_dataset,
843
+ mdpv_qa_cocostuff164k_dataset,
844
+ mdpv_multi_points_flicker30k_dataset,
845
+ mdpv_multi_points_openpsg_dataset,],
846
+ )
847
+
848
+ train_dataloader = dict(
849
+ batch_size=batch_size,
850
+ num_workers=dataloader_num_workers,
851
+ dataset=train_dataset,
852
+ sampler=dict(
853
+ type=LengthGroupedSampler,
854
+ length_property='modality_length',
855
+ per_device_batch_size=batch_size * accumulative_counts),
856
+ collate_fn=dict(type=omg_llava_collate_fn))
857
+
858
+ #######################################################################
859
+ # PART 4 Scheduler & Optimizer #
860
+ #######################################################################
861
+ # optimizer
862
+ optim_wrapper = dict(
863
+ type=AmpOptimWrapper,
864
+ optimizer=dict(
865
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
866
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
867
+ accumulative_counts=accumulative_counts,
868
+ loss_scale='dynamic',
869
+ dtype='float16')
870
+
871
+ # learning policy
872
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
873
+ param_scheduler = [
874
+ dict(
875
+ type=LinearLR,
876
+ start_factor=1e-5,
877
+ by_epoch=True,
878
+ begin=0,
879
+ end=warmup_ratio * max_epochs,
880
+ convert_to_iter_based=True),
881
+ dict(
882
+ type=CosineAnnealingLR,
883
+ eta_min=0.0,
884
+ by_epoch=True,
885
+ begin=warmup_ratio * max_epochs,
886
+ end=max_epochs,
887
+ convert_to_iter_based=True)
888
+ ]
889
+
890
+ # train, val, test setting
891
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
892
+
893
+ #######################################################################
894
+ # PART 5 Runtime #
895
+ #######################################################################
896
+ # Log the dialogue periodically during the training process, optional
897
+ custom_hooks = [
898
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
899
+ dict(
900
+ type=EvaluateChatHook_withSpecialTokens,
901
+ tokenizer=tokenizer,
902
+ image_processor=image_processor,
903
+ every_n_iters=evaluation_freq,
904
+ evaluation_inputs=evaluation_inputs,
905
+ evaluation_images=evaluation_images,
906
+ system=SYSTEM,
907
+ prompt_template=prompt_template)
908
+ ]
909
+
910
+ # configure default hooks
911
+ default_hooks = dict(
912
+ # record the time of every iteration.
913
+ timer=dict(type=IterTimerHook),
914
+ # print log every 10 iterations.
915
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
916
+ # enable the parameter scheduler.
917
+ param_scheduler=dict(type=ParamSchedulerHook),
918
+ # save a checkpoint every `save_steps` iterations.
919
+ checkpoint=dict(
920
+ type=CheckpointHook,
921
+ by_epoch=False,
922
+ interval=save_steps,
923
+ max_keep_ckpts=save_total_limit),
924
+ # set the sampler seed in distributed environments.
925
+ sampler_seed=dict(type=DistSamplerSeedHook),
926
+ )
927
+
928
+ # configure environment
929
+ env_cfg = dict(
930
+ # whether to enable cudnn benchmark
931
+ cudnn_benchmark=False,
932
+ # set multi process parameters
933
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
934
+ # set distributed parameters
935
+ dist_cfg=dict(backend='nccl'),
936
+ )
937
+
938
+ # set visualizer
939
+ visualizer = None
940
+
941
+ # set log level
942
+ log_level = 'INFO'
943
+
944
+ # load from which checkpoint
945
+ load_from = None
946
+
947
+ # whether to resume training from the loaded checkpoint
948
+ resume = False
949
+
950
+ # Default to a random seed and leave `deterministic` disabled
951
+ randomness = dict(seed=None, deterministic=False)
952
+
953
+ # set log processor
954
+ log_processor = dict(by_epoch=False)
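The `train_dataset` above oversamples smaller sources in two ways: through each config's `repeats` field (e.g. `repeats=10` for the GranDf GCG data) and by listing a config several times in `datasets_cfgs` (the `# repeat 3x` comments). A rough sketch of the second mechanism, assuming the combined dataset behaves like a plain concatenation; the actual CombineDataset in omg_llava/dataset/CombineDataset.py may differ:

from torch.utils.data import ConcatDataset, Dataset

class ToySegDataset(Dataset):
    # Stand-in for one of the dataset configs above.
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx

ade20k = ToySegDataset(20)    # e.g. semantic_seg_ade20k_dataset
refcoco = ToySegDataset(100)  # e.g. referring_seg_refcoco_dataset

# Listing ade20k three times triples how often it is sampled per epoch.
combined = ConcatDataset([ade20k, ade20k, ade20k, refcoco])
print(len(combined))  # 160 == 3 * 20 + 100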
omg_llava/configs/finetune/ablation_multi_seg_states/debug.py ADDED
@@ -0,0 +1,924 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_7b_pretrain_1024image_8gpus/iter_4361.pth'
47
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter mask area where score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[glamm_refcocog_dataset,],
816
+ )
817
+
818
+ train_dataloader = dict(
819
+ batch_size=batch_size,
820
+ num_workers=dataloader_num_workers,
821
+ dataset=train_dataset,
822
+ sampler=dict(
823
+ type=LengthGroupedSampler,
824
+ length_property='modality_length',
825
+ per_device_batch_size=batch_size * accumulative_counts),
826
+ collate_fn=dict(type=omg_llava_collate_fn))
827
+
828
+ #######################################################################
829
+ # PART 4 Scheduler & Optimizer #
830
+ #######################################################################
831
+ # optimizer
832
+ optim_wrapper = dict(
833
+ type=AmpOptimWrapper,
834
+ optimizer=dict(
835
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
836
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
837
+ accumulative_counts=accumulative_counts,
838
+ loss_scale='dynamic',
839
+ dtype='float16')
840
+
841
+ # learning policy
842
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
843
+ param_scheduler = [
844
+ dict(
845
+ type=LinearLR,
846
+ start_factor=1e-5,
847
+ by_epoch=True,
848
+ begin=0,
849
+ end=warmup_ratio * max_epochs,
850
+ convert_to_iter_based=True),
851
+ dict(
852
+ type=CosineAnnealingLR,
853
+ eta_min=0.0,
854
+ by_epoch=True,
855
+ begin=warmup_ratio * max_epochs,
856
+ end=max_epochs,
857
+ convert_to_iter_based=True)
858
+ ]
859
+
860
+ # train, val, test setting
861
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
862
+
863
+ #######################################################################
864
+ # PART 5 Runtime #
865
+ #######################################################################
866
+ # Log sample dialogues periodically during training (optional)
867
+ custom_hooks = [
868
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
869
+ dict(
870
+ type=EvaluateChatHook_withSpecialTokens,
871
+ tokenizer=tokenizer,
872
+ image_processor=image_processor,
873
+ every_n_iters=evaluation_freq,
874
+ evaluation_inputs=evaluation_inputs,
875
+ evaluation_images=evaluation_images,
876
+ system=SYSTEM,
877
+ prompt_template=prompt_template)
878
+ ]
879
+
880
+ # configure default hooks
881
+ default_hooks = dict(
882
+ # record the time of every iteration.
883
+ timer=dict(type=IterTimerHook),
884
+ # print log every 10 iterations.
885
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
886
+ # enable the parameter scheduler.
887
+ param_scheduler=dict(type=ParamSchedulerHook),
888
+ # save checkpoint per `save_steps`.
889
+ checkpoint=dict(
890
+ type=CheckpointHook,
891
+ by_epoch=False,
892
+ interval=save_steps,
893
+ max_keep_ckpts=save_total_limit),
894
+ # set sampler seed in distributed environment.
895
+ sampler_seed=dict(type=DistSamplerSeedHook),
896
+ )
897
+
898
+ # configure environment
899
+ env_cfg = dict(
900
+ # whether to enable cudnn benchmark
901
+ cudnn_benchmark=False,
902
+ # set multi process parameters
903
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
904
+ # set distributed parameters
905
+ dist_cfg=dict(backend='nccl'),
906
+ )
907
+
908
+ # set visualizer
909
+ visualizer = None
910
+
911
+ # set log level
912
+ log_level = 'INFO'
913
+
914
+ # load from which checkpoint
915
+ load_from = None
916
+
917
+ # whether to resume training from the loaded checkpoint
918
+ resume = False
919
+
920
+ # By default, use a random seed and disable `deterministic`
921
+ randomness = dict(seed=None, deterministic=False)
922
+
923
+ # set log processor
924
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_cross.py ADDED
@@ -0,0 +1,953 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter mask area where score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ visual_prompt_proj=False,
350
+ add_cross_attn_layer=True,
351
+ llm=dict(
352
+ type=AutoModelForCausalLM.from_pretrained,
353
+ pretrained_model_name_or_path=llm_name_or_path,
354
+ trust_remote_code=True,
355
+ torch_dtype=torch.float16,
356
+ quantization_config=dict(
357
+ type=BitsAndBytesConfig,
358
+ load_in_4bit=True,
359
+ load_in_8bit=False,
360
+ llm_int8_threshold=6.0,
361
+ llm_int8_has_fp16_weight=False,
362
+ bnb_4bit_compute_dtype=torch.float16,
363
+ bnb_4bit_use_double_quant=True,
364
+ bnb_4bit_quant_type='nf4')),
365
+ llm_lora=dict(
366
+ type=LoraConfig,
367
+ r=512,
368
+ lora_alpha=256,
369
+ lora_dropout=0.05,
370
+ bias='none',
371
+ task_type='CAUSAL_LM'),
372
+ visual_encoder=omgseg_model,
373
+ tokenizer=tokenizer,
374
+ )
375
+
376
+ #######################################################################
377
+ # PART 3 Dataset & Dataloader #
378
+ #######################################################################
379
+ debug=False
380
+ llava_dataset = dict(
381
+ type=LLaVADataset,
382
+ data_path=data_path,
383
+ image_folder=image_folder,
384
+ tokenizer=tokenizer,
385
+ image_processor=image_processor,
386
+ dataset_map_fn=llava_map_fn,
387
+ template_map_fn=dict(
388
+ type=template_map_fn_factory, template=prompt_template),
389
+ max_length=max_length,
390
+ pad_image_to_square=True)
391
+
392
+ glamm_refcocog_dataset = dict(
393
+ type=RefCOCOgGCGDataset,
394
+ data_path=refcocog_ann_file,
395
+ image_folder=refcocog_image_path,
396
+ tokenizer=tokenizer,
397
+ image_processor=image_processor,
398
+ dataset_map_fn=glamm_refcocog_map_fn,
399
+ template_map_fn=dict(
400
+ type=template_map_fn_factory, template=prompt_template),
401
+ max_length=max_length,
402
+ pad_image_to_square=True,
403
+ debug=False,
404
+ repeats=1,
405
+ )
406
+
407
+ glamm_grandf_dataset = dict(
408
+ type=GranDfGCGDataset,
409
+ data_path=grandf_ann_file,
410
+ image_folder=grandf_image_path,
411
+ tokenizer=tokenizer,
412
+ image_processor=image_processor,
413
+ dataset_map_fn=glamm_granf_map_fn,
414
+ template_map_fn=dict(
415
+ type=template_map_fn_factory, template=prompt_template),
416
+ max_length=max_length,
417
+ pad_image_to_square=True,
418
+ debug=debug,
419
+ repeats=10,
420
+ )
421
+
422
+ glamm_psg_dataset = dict(
423
+ type=OpenPsgGCGDataset,
424
+ data_path=psg_ann_file,
425
+ image_folder=psg_image_path,
426
+ tokenizer=tokenizer,
427
+ image_processor=image_processor,
428
+ dataset_map_fn=glamm_openpsg_map_fn,
429
+ template_map_fn=dict(
430
+ type=template_map_fn_factory, template=prompt_template),
431
+ max_length=max_length,
432
+ pad_image_to_square=True,
433
+ debug=debug,
434
+ repeats=1,
435
+ )
436
+
437
+ glamm_flickr_dataset = dict(
438
+ type=FlickrGCGDataset,
439
+ data_path=flickr_ann_file,
440
+ image_folder=flickr_image_path,
441
+ tokenizer=tokenizer,
442
+ image_processor=image_processor,
443
+ dataset_map_fn=glamm_flickr_map_fn,
444
+ template_map_fn=dict(
445
+ type=template_map_fn_factory, template=prompt_template),
446
+ max_length=max_length,
447
+ pad_image_to_square=True,
448
+ debug=debug,
449
+ repeats=1,
450
+ )
451
+
452
+ semantic_seg_ade20k_dataset = dict(
453
+ type=ADE20kSemanticSegDataset,
454
+ data_path=ade20k_class_file,
455
+ image_folder=ade20k_image_path,
456
+ tokenizer=tokenizer,
457
+ image_processor=image_processor,
458
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
459
+ template_map_fn=dict(
460
+ type=template_map_fn_factory, template=prompt_template),
461
+ max_length=max_length,
462
+ pad_image_to_square=True,
463
+ debug=False,
464
+ repeats=1,
465
+ gcg_format=True,
466
+ )
467
+
468
+ semantic_seg_cocostuff_dataset = dict(
469
+ type=COCOStuffSemanticSegDataset,
470
+ data_path=cocostuff_class_file,
471
+ image_folder=cocostuff_image_path,
472
+ label_path=cocostuff_label_path,
473
+ tokenizer=tokenizer,
474
+ image_processor=image_processor,
475
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
476
+ template_map_fn=dict(
477
+ type=template_map_fn_factory, template=prompt_template),
478
+ max_length=max_length,
479
+ pad_image_to_square=True,
480
+ debug=False,
481
+ repeats=1,
482
+ gcg_format=True,
483
+ )
484
+
485
+ referring_seg_refcoco_dataset = dict(
486
+ type=RefcocoReferringSegDataset,
487
+ data_path=referring_refcoco_data_path,
488
+ image_folder=referring_refcoco_image_path,
489
+ tokenizer=tokenizer,
490
+ image_processor=image_processor,
491
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
492
+ template_map_fn=dict(
493
+ type=template_map_fn_factory, template=prompt_template),
494
+ max_length=max_length,
495
+ pad_image_to_square=True,
496
+ debug=False,
497
+ repeats=1,
498
+ )
499
+
500
+ referring_seg_refcoco_plus_dataset = dict(
501
+ type=Refcoco_plus_ReferringSegDataset,
502
+ data_path=referring_refcoco_plus_data_path,
503
+ image_folder=referring_refcoco_plus_image_path,
504
+ tokenizer=tokenizer,
505
+ image_processor=image_processor,
506
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
507
+ template_map_fn=dict(
508
+ type=template_map_fn_factory, template=prompt_template),
509
+ max_length=max_length,
510
+ pad_image_to_square=True,
511
+ debug=False,
512
+ repeats=1,
513
+ )
514
+
515
+ referring_seg_refcocog_dataset = dict(
516
+ type=Refcocog_ReferringSegDataset,
517
+ data_path=referring_refcocog_data_path,
518
+ image_folder=referring_refcocog_image_path,
519
+ tokenizer=tokenizer,
520
+ image_processor=image_processor,
521
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
522
+ template_map_fn=dict(
523
+ type=template_map_fn_factory, template=prompt_template),
524
+ max_length=max_length,
525
+ pad_image_to_square=True,
526
+ debug=False,
527
+ repeats=1,
528
+ )
529
+
530
+ referring_seg_refclef_dataset = dict(
531
+ type=Refclef_ReferringSegDataset,
532
+ data_path=referring_refclef_data_path,
533
+ image_folder=referring_refclef_image_path,
534
+ tokenizer=tokenizer,
535
+ image_processor=image_processor,
536
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
537
+ template_map_fn=dict(
538
+ type=template_map_fn_factory, template=prompt_template),
539
+ max_length=max_length,
540
+ pad_image_to_square=True,
541
+ debug=False,
542
+ repeats=1,
543
+ )
544
+
545
+ region_cap_osprey_dataset = dict(
546
+ type=OspreyRegionCaptionDataset,
547
+ data_path=region_cap_osprey_data_path,
548
+ image_folder=region_cap_osprey_image_path,
549
+ tokenizer=tokenizer,
550
+ image_processor=image_processor,
551
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
552
+ template_map_fn=dict(
553
+ type=template_map_fn_factory, template=prompt_template),
554
+ max_length=max_length,
555
+ pad_image_to_square=True,
556
+ debug=False,
557
+ repeats=1,
558
+ )
559
+
560
+ region_conversation_osprey_dataset = dict(
561
+ type=OspreyRegionConversationDataset,
562
+ data_path=region_conversation_osprey_data_path,
563
+ image_folder=region_conversation_osprey_image_path,
564
+ tokenizer=tokenizer,
565
+ image_processor=image_processor,
566
+ dataset_map_fn=osprey_region_conversation_map_fn,
567
+ template_map_fn=dict(
568
+ type=template_map_fn_factory, template=prompt_template),
569
+ max_length=max_length,
570
+ pad_image_to_square=True,
571
+ debug=False,
572
+ repeats=1,
573
+ )
574
+
575
+ mdpv_detailed_description_ade20k_dataset = dict(
576
+ type=MDPVPointDetailedCaptionDataset,
577
+ data_path=mdpv_detailed_caption_ade20k_data_path,
578
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
579
+ tokenizer=tokenizer,
580
+ image_processor=image_processor,
581
+ dataset_map_fn=mdpv_points_map_fn,
582
+ template_map_fn=dict(
583
+ type=template_map_fn_factory, template=prompt_template),
584
+ max_length=max_length,
585
+ pad_image_to_square=True,
586
+ debug=False,
587
+ repeats=1,
588
+ )
589
+
590
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
591
+ type=MDPVPointDetailedCaptionDataset,
592
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
593
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
594
+ tokenizer=tokenizer,
595
+ image_processor=image_processor,
596
+ dataset_map_fn=mdpv_points_map_fn,
597
+ template_map_fn=dict(
598
+ type=template_map_fn_factory, template=prompt_template),
599
+ max_length=max_length,
600
+ pad_image_to_square=True,
601
+ debug=False,
602
+ repeats=1,
603
+ )
604
+
605
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
606
+ type=MDPVPointDetailedCaptionDataset,
607
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
608
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
609
+ tokenizer=tokenizer,
610
+ image_processor=image_processor,
611
+ dataset_map_fn=mdpv_points_map_fn,
612
+ template_map_fn=dict(
613
+ type=template_map_fn_factory, template=prompt_template),
614
+ max_length=max_length,
615
+ pad_image_to_square=True,
616
+ debug=False,
617
+ repeats=1,
618
+ )
619
+
620
+ mdpv_detailed_description_vg_dataset = dict(
621
+ type=MDPVPointDetailedCaptionDataset,
622
+ data_path=mdpv_detailed_caption_vg_data_path,
623
+ image_folder=mdpv_detailed_caption_vg_image_path,
624
+ tokenizer=tokenizer,
625
+ image_processor=image_processor,
626
+ dataset_map_fn=mdpv_points_map_fn,
627
+ template_map_fn=dict(
628
+ type=template_map_fn_factory, template=prompt_template),
629
+ max_length=max_length,
630
+ pad_image_to_square=True,
631
+ debug=False,
632
+ repeats=1,
633
+ )
634
+
635
+ mdpv_brief_description_vg_dataset = dict(
636
+ type=MDPVPointBriefCaptionDataset,
637
+ data_path=mdpv_brief_caption_vg_data_path,
638
+ image_folder=mdpv_brief_caption_vg_image_path,
639
+ tokenizer=tokenizer,
640
+ image_processor=image_processor,
641
+ dataset_map_fn=mdpv_points_map_fn,
642
+ template_map_fn=dict(
643
+ type=template_map_fn_factory, template=prompt_template),
644
+ max_length=max_length,
645
+ pad_image_to_square=True,
646
+ debug=False,
647
+ repeats=1,
648
+ )
649
+
650
+ mdpv_brief_description_cocostuff10k_dataset = dict(
651
+ type=MDPVPointBriefCaptionDataset,
652
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
653
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
654
+ tokenizer=tokenizer,
655
+ image_processor=image_processor,
656
+ dataset_map_fn=mdpv_points_map_fn,
657
+ template_map_fn=dict(
658
+ type=template_map_fn_factory, template=prompt_template),
659
+ max_length=max_length,
660
+ pad_image_to_square=True,
661
+ debug=False,
662
+ repeats=1,
663
+ )
664
+
665
+ mdpv_brief_description_cocostuff164k_dataset = dict(
666
+ type=MDPVPointBriefCaptionDataset,
667
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
668
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
669
+ tokenizer=tokenizer,
670
+ image_processor=image_processor,
671
+ dataset_map_fn=mdpv_points_map_fn,
672
+ template_map_fn=dict(
673
+ type=template_map_fn_factory, template=prompt_template),
674
+ max_length=max_length,
675
+ pad_image_to_square=True,
676
+ debug=False,
677
+ repeats=1,
678
+ )
679
+
680
+ mdpv_brief_description_ade20k_dataset = dict(
681
+ type=MDPVPointBriefCaptionDataset,
682
+ data_path=mdpv_brief_caption_ade20k_data_path,
683
+ image_folder=mdpv_brief_caption_ade20k_image_path,
684
+ tokenizer=tokenizer,
685
+ image_processor=image_processor,
686
+ dataset_map_fn=mdpv_points_map_fn,
687
+ template_map_fn=dict(
688
+ type=template_map_fn_factory, template=prompt_template),
689
+ max_length=max_length,
690
+ pad_image_to_square=True,
691
+ debug=False,
692
+ repeats=1,
693
+ )
694
+
695
+ mdpv_brief_description_lvis_dataset = dict(
696
+ type=MDPVPointBriefCaptionDataset,
697
+ data_path=mdpv_brief_caption_lvis_data_path,
698
+ image_folder=mdpv_brief_caption_lvis_image_path,
699
+ tokenizer=tokenizer,
700
+ image_processor=image_processor,
701
+ dataset_map_fn=mdpv_points_map_fn,
702
+ template_map_fn=dict(
703
+ type=template_map_fn_factory, template=prompt_template),
704
+ max_length=max_length,
705
+ pad_image_to_square=True,
706
+ debug=False,
707
+ repeats=1,
708
+ )
709
+
710
+ mdpv_qa_vg_dataset = dict(
711
+ type=MDPVPointBriefCaptionDataset,
712
+ data_path=mdpv_qa_vg_data_path,
713
+ image_folder=mdpv_qa_vg_image_path,
714
+ tokenizer=tokenizer,
715
+ image_processor=image_processor,
716
+ dataset_map_fn=mdpv_points_map_fn,
717
+ template_map_fn=dict(
718
+ type=template_map_fn_factory, template=prompt_template),
719
+ max_length=max_length,
720
+ pad_image_to_square=True,
721
+ debug=False,
722
+ repeats=1,
723
+ )
724
+
725
+ mdpv_qa_ade20k_dataset = dict(
726
+ type=MDPVPointBriefCaptionDataset,
727
+ data_path=mdpv_qa_ade20k_data_path,
728
+ image_folder=mdpv_qa_ade20k_image_path,
729
+ tokenizer=tokenizer,
730
+ image_processor=image_processor,
731
+ dataset_map_fn=mdpv_points_map_fn,
732
+ template_map_fn=dict(
733
+ type=template_map_fn_factory, template=prompt_template),
734
+ max_length=max_length,
735
+ pad_image_to_square=True,
736
+ debug=False,
737
+ repeats=1,
738
+ )
739
+
740
+ mdpv_qa_lvis_dataset = dict(
741
+ type=MDPVPointBriefCaptionDataset,
742
+ data_path=mdpv_qa_lvis_data_path,
743
+ image_folder=mdpv_qa_lvis_image_path,
744
+ tokenizer=tokenizer,
745
+ image_processor=image_processor,
746
+ dataset_map_fn=mdpv_points_map_fn,
747
+ template_map_fn=dict(
748
+ type=template_map_fn_factory, template=prompt_template),
749
+ max_length=max_length,
750
+ pad_image_to_square=True,
751
+ debug=False,
752
+ repeats=1,
753
+ )
754
+
755
+ mdpv_qa_cocostuff10k_dataset = dict(
756
+ type=MDPVPointBriefCaptionDataset,
757
+ data_path=mdpv_qa_cocostuff10k_data_path,
758
+ image_folder=mdpv_qa_cocostuff10k_image_path,
759
+ tokenizer=tokenizer,
760
+ image_processor=image_processor,
761
+ dataset_map_fn=mdpv_points_map_fn,
762
+ template_map_fn=dict(
763
+ type=template_map_fn_factory, template=prompt_template),
764
+ max_length=max_length,
765
+ pad_image_to_square=True,
766
+ debug=False,
767
+ repeats=1,
768
+ )
769
+
770
+ mdpv_qa_cocostuff164k_dataset = dict(
771
+ type=MDPVPointBriefCaptionDataset,
772
+ data_path=mdpv_qa_cocostuff164k_data_path,
773
+ image_folder=mdpv_qa_cocostuff164k_image_path,
774
+ tokenizer=tokenizer,
775
+ image_processor=image_processor,
776
+ dataset_map_fn=mdpv_points_map_fn,
777
+ template_map_fn=dict(
778
+ type=template_map_fn_factory, template=prompt_template),
779
+ max_length=max_length,
780
+ pad_image_to_square=True,
781
+ debug=False,
782
+ repeats=1,
783
+ )
784
+
785
+ mdpv_multi_points_openpsg_dataset = dict(
786
+ type=MDPVPointBriefCaptionDataset,
787
+ data_path=mdpv_multi_points_openpsg_data_path,
788
+ image_folder=mdpv_multi_points_openpsg_image_path,
789
+ tokenizer=tokenizer,
790
+ image_processor=image_processor,
791
+ dataset_map_fn=mdpv_points_map_fn,
792
+ template_map_fn=dict(
793
+ type=template_map_fn_factory, template=prompt_template),
794
+ max_length=max_length,
795
+ pad_image_to_square=True,
796
+ debug=False,
797
+ repeats=1,
798
+ )
799
+
800
+ mdpv_multi_points_flicker30k_dataset = dict(
801
+ type=MDPVPointBriefCaptionDataset,
802
+ data_path=mdpv_multi_points_flicker30k_data_path,
803
+ image_folder=mdpv_multi_points_flicker30k_image_path,
804
+ tokenizer=tokenizer,
805
+ image_processor=image_processor,
806
+ dataset_map_fn=mdpv_points_map_fn,
807
+ template_map_fn=dict(
808
+ type=template_map_fn_factory, template=prompt_template),
809
+ max_length=max_length,
810
+ pad_image_to_square=True,
811
+ debug=False,
812
+ repeats=1,
813
+ )
814
+
815
+ train_dataset = dict(
816
+ type=CombineDataset,
817
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
818
+ glamm_grandf_dataset, glamm_psg_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
827
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
828
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
829
+ mdpv_detailed_description_ade20k_dataset,
830
+ mdpv_detailed_description_cocostuff_10k_dataset,
831
+ mdpv_detailed_description_cocostuff_164k_dataset,
832
+ mdpv_detailed_description_vg_dataset,
833
+ mdpv_brief_description_lvis_dataset,
834
+ mdpv_brief_description_vg_dataset,
835
+ mdpv_brief_description_ade20k_dataset,
836
+ mdpv_brief_description_cocostuff10k_dataset,
837
+ mdpv_brief_description_cocostuff164k_dataset,
838
+ mdpv_qa_vg_dataset,
839
+ mdpv_qa_lvis_dataset,
840
+ mdpv_qa_ade20k_dataset,
841
+ mdpv_qa_cocostuff10k_dataset,
842
+ mdpv_qa_cocostuff164k_dataset,
843
+ mdpv_multi_points_flicker30k_dataset,
844
+ mdpv_multi_points_openpsg_dataset,],
845
+ )
846
+
847
+ train_dataloader = dict(
848
+ batch_size=batch_size,
849
+ num_workers=dataloader_num_workers,
850
+ dataset=train_dataset,
851
+ sampler=dict(
852
+ type=LengthGroupedSampler,
853
+ length_property='modality_length',
854
+ per_device_batch_size=batch_size * accumulative_counts),
855
+ collate_fn=dict(type=omg_llava_collate_fn))
856
+
857
+ #######################################################################
858
+ # PART 4 Scheduler & Optimizer #
859
+ #######################################################################
860
+ # optimizer
861
+ optim_wrapper = dict(
862
+ type=AmpOptimWrapper,
863
+ optimizer=dict(
864
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
865
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
866
+ accumulative_counts=accumulative_counts,
867
+ loss_scale='dynamic',
868
+ dtype='float16')
869
+
870
+ # learning policy
871
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
872
+ param_scheduler = [
873
+ dict(
874
+ type=LinearLR,
875
+ start_factor=1e-5,
876
+ by_epoch=True,
877
+ begin=0,
878
+ end=warmup_ratio * max_epochs,
879
+ convert_to_iter_based=True),
880
+ dict(
881
+ type=CosineAnnealingLR,
882
+ eta_min=0.0,
883
+ by_epoch=True,
884
+ begin=warmup_ratio * max_epochs,
885
+ end=max_epochs,
886
+ convert_to_iter_based=True)
887
+ ]
888
+
889
+ # train, val, test setting
890
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
891
+
892
+ #######################################################################
893
+ # PART 5 Runtime #
894
+ #######################################################################
895
+ # Log the dialogue periodically during the training process, optional
896
+ custom_hooks = [
897
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
898
+ dict(
899
+ type=EvaluateChatHook_withSpecialTokens,
900
+ tokenizer=tokenizer,
901
+ image_processor=image_processor,
902
+ every_n_iters=evaluation_freq,
903
+ evaluation_inputs=evaluation_inputs,
904
+ evaluation_images=evaluation_images,
905
+ system=SYSTEM,
906
+ prompt_template=prompt_template)
907
+ ]
908
+
909
+ # configure default hooks
910
+ default_hooks = dict(
911
+ # record the time of every iteration.
912
+ timer=dict(type=IterTimerHook),
913
+ # print log every 10 iterations.
914
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
915
+ # enable the parameter scheduler.
916
+ param_scheduler=dict(type=ParamSchedulerHook),
917
+ # save checkpoint per `save_steps`.
918
+ checkpoint=dict(
919
+ type=CheckpointHook,
920
+ by_epoch=False,
921
+ interval=save_steps,
922
+ max_keep_ckpts=save_total_limit),
923
+ # set sampler seed in distributed environment.
924
+ sampler_seed=dict(type=DistSamplerSeedHook),
925
+ )
926
+
927
+ # configure environment
928
+ env_cfg = dict(
929
+ # whether to enable cudnn benchmark
930
+ cudnn_benchmark=False,
931
+ # set multi process parameters
932
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
933
+ # set distributed parameters
934
+ dist_cfg=dict(backend='nccl'),
935
+ )
936
+
937
+ # set visualizer
938
+ visualizer = None
939
+
940
+ # set log level
941
+ log_level = 'INFO'
942
+
943
+ # load from which checkpoint
944
+ load_from = None
945
+
946
+ # whether to resume training from the loaded checkpoint
947
+ resume = False
948
+
949
+ # Use a random seed by default and disable `deterministic`
950
+ randomness = dict(seed=None, deterministic=False)
951
+
952
+ # set log processor
953
+ log_processor = dict(by_epoch=False)
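A quick, hypothetical back-of-the-envelope check of the batch and warmup settings in the config above (not part of the uploaded file): the `*_8gpus` naming used in this folder suggests 8 GPUs, and the iterations-per-epoch value below is a placeholder that depends on the length of the combined dataset.

# Hypothetical illustration of the schedule settings above; not part of the config.
batch_size = 8            # per-device batch size from the config
accumulative_counts = 2   # gradient accumulation steps from the config
num_gpus = 8              # assumption, based on the "*_8gpus" config naming in this folder
warmup_ratio = 0.03
max_epochs = 1

effective_batch = batch_size * accumulative_counts * num_gpus
print(f'effective global batch size per optimizer step: {effective_batch}')  # 128

# With by_epoch=True and convert_to_iter_based=True, the LinearLR stage covers
# warmup_ratio * max_epochs epochs, i.e. the first 3% of all training iterations.
iters_per_epoch = 10_000  # hypothetical; the real value depends on the combined dataset
warmup_iters = int(warmup_ratio * max_epochs * iters_per_epoch)
print(f'warmup iterations assuming {iters_per_epoch} iters/epoch: {warmup_iters}')  # 300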
omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate.py ADDED
@@ -0,0 +1,953 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter out mask regions where the score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ visual_prompt_proj=True,
350
+ add_cross_attn_layer=False,
351
+ llm=dict(
352
+ type=AutoModelForCausalLM.from_pretrained,
353
+ pretrained_model_name_or_path=llm_name_or_path,
354
+ trust_remote_code=True,
355
+ torch_dtype=torch.float16,
356
+ quantization_config=dict(
357
+ type=BitsAndBytesConfig,
358
+ load_in_4bit=True,
359
+ load_in_8bit=False,
360
+ llm_int8_threshold=6.0,
361
+ llm_int8_has_fp16_weight=False,
362
+ bnb_4bit_compute_dtype=torch.float16,
363
+ bnb_4bit_use_double_quant=True,
364
+ bnb_4bit_quant_type='nf4')),
365
+ llm_lora=dict(
366
+ type=LoraConfig,
367
+ r=512,
368
+ lora_alpha=256,
369
+ lora_dropout=0.05,
370
+ bias='none',
371
+ task_type='CAUSAL_LM'),
372
+ visual_encoder=omgseg_model,
373
+ tokenizer=tokenizer,
374
+ )
375
+
376
+ #######################################################################
377
+ # PART 3 Dataset & Dataloader #
378
+ #######################################################################
379
+ debug=False
380
+ llava_dataset = dict(
381
+ type=LLaVADataset,
382
+ data_path=data_path,
383
+ image_folder=image_folder,
384
+ tokenizer=tokenizer,
385
+ image_processor=image_processor,
386
+ dataset_map_fn=llava_map_fn,
387
+ template_map_fn=dict(
388
+ type=template_map_fn_factory, template=prompt_template),
389
+ max_length=max_length,
390
+ pad_image_to_square=True)
391
+
392
+ glamm_refcocog_dataset = dict(
393
+ type=RefCOCOgGCGDataset,
394
+ data_path=refcocog_ann_file,
395
+ image_folder=refcocog_image_path,
396
+ tokenizer=tokenizer,
397
+ image_processor=image_processor,
398
+ dataset_map_fn=glamm_refcocog_map_fn,
399
+ template_map_fn=dict(
400
+ type=template_map_fn_factory, template=prompt_template),
401
+ max_length=max_length,
402
+ pad_image_to_square=True,
403
+ debug=False,
404
+ repeats=1,
405
+ )
406
+
407
+ glamm_grandf_dataset = dict(
408
+ type=GranDfGCGDataset,
409
+ data_path=grandf_ann_file,
410
+ image_folder=grandf_image_path,
411
+ tokenizer=tokenizer,
412
+ image_processor=image_processor,
413
+ dataset_map_fn=glamm_granf_map_fn,
414
+ template_map_fn=dict(
415
+ type=template_map_fn_factory, template=prompt_template),
416
+ max_length=max_length,
417
+ pad_image_to_square=True,
418
+ debug=debug,
419
+ repeats=10,
420
+ )
421
+
422
+ glamm_psg_dataset = dict(
423
+ type=OpenPsgGCGDataset,
424
+ data_path=psg_ann_file,
425
+ image_folder=psg_image_path,
426
+ tokenizer=tokenizer,
427
+ image_processor=image_processor,
428
+ dataset_map_fn=glamm_openpsg_map_fn,
429
+ template_map_fn=dict(
430
+ type=template_map_fn_factory, template=prompt_template),
431
+ max_length=max_length,
432
+ pad_image_to_square=True,
433
+ debug=debug,
434
+ repeats=1,
435
+ )
436
+
437
+ glamm_flickr_dataset = dict(
438
+ type=FlickrGCGDataset,
439
+ data_path=flickr_ann_file,
440
+ image_folder=flickr_image_path,
441
+ tokenizer=tokenizer,
442
+ image_processor=image_processor,
443
+ dataset_map_fn=glamm_flickr_map_fn,
444
+ template_map_fn=dict(
445
+ type=template_map_fn_factory, template=prompt_template),
446
+ max_length=max_length,
447
+ pad_image_to_square=True,
448
+ debug=debug,
449
+ repeats=1,
450
+ )
451
+
452
+ semantic_seg_ade20k_dataset = dict(
453
+ type=ADE20kSemanticSegDataset,
454
+ data_path=ade20k_class_file,
455
+ image_folder=ade20k_image_path,
456
+ tokenizer=tokenizer,
457
+ image_processor=image_processor,
458
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
459
+ template_map_fn=dict(
460
+ type=template_map_fn_factory, template=prompt_template),
461
+ max_length=max_length,
462
+ pad_image_to_square=True,
463
+ debug=False,
464
+ repeats=1,
465
+ gcg_format=True,
466
+ )
467
+
468
+ semantic_seg_cocostuff_dataset = dict(
469
+ type=COCOStuffSemanticSegDataset,
470
+ data_path=cocostuff_class_file,
471
+ image_folder=cocostuff_image_path,
472
+ label_path=cocostuff_label_path,
473
+ tokenizer=tokenizer,
474
+ image_processor=image_processor,
475
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
476
+ template_map_fn=dict(
477
+ type=template_map_fn_factory, template=prompt_template),
478
+ max_length=max_length,
479
+ pad_image_to_square=True,
480
+ debug=False,
481
+ repeats=1,
482
+ gcg_format=True,
483
+ )
484
+
485
+ referring_seg_refcoco_dataset = dict(
486
+ type=RefcocoReferringSegDataset,
487
+ data_path=referring_refcoco_data_path,
488
+ image_folder=referring_refcoco_image_path,
489
+ tokenizer=tokenizer,
490
+ image_processor=image_processor,
491
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
492
+ template_map_fn=dict(
493
+ type=template_map_fn_factory, template=prompt_template),
494
+ max_length=max_length,
495
+ pad_image_to_square=True,
496
+ debug=False,
497
+ repeats=1,
498
+ )
499
+
500
+ referring_seg_refcoco_plus_dataset = dict(
501
+ type=Refcoco_plus_ReferringSegDataset,
502
+ data_path=referring_refcoco_plus_data_path,
503
+ image_folder=referring_refcoco_plus_image_path,
504
+ tokenizer=tokenizer,
505
+ image_processor=image_processor,
506
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
507
+ template_map_fn=dict(
508
+ type=template_map_fn_factory, template=prompt_template),
509
+ max_length=max_length,
510
+ pad_image_to_square=True,
511
+ debug=False,
512
+ repeats=1,
513
+ )
514
+
515
+ referring_seg_refcocog_dataset = dict(
516
+ type=Refcocog_ReferringSegDataset,
517
+ data_path=referring_refcocog_data_path,
518
+ image_folder=referring_refcocog_image_path,
519
+ tokenizer=tokenizer,
520
+ image_processor=image_processor,
521
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
522
+ template_map_fn=dict(
523
+ type=template_map_fn_factory, template=prompt_template),
524
+ max_length=max_length,
525
+ pad_image_to_square=True,
526
+ debug=False,
527
+ repeats=1,
528
+ )
529
+
530
+ referring_seg_refclef_dataset = dict(
531
+ type=Refclef_ReferringSegDataset,
532
+ data_path=referring_refclef_data_path,
533
+ image_folder=referring_refclef_image_path,
534
+ tokenizer=tokenizer,
535
+ image_processor=image_processor,
536
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
537
+ template_map_fn=dict(
538
+ type=template_map_fn_factory, template=prompt_template),
539
+ max_length=max_length,
540
+ pad_image_to_square=True,
541
+ debug=False,
542
+ repeats=1,
543
+ )
544
+
545
+ region_cap_osprey_dataset = dict(
546
+ type=OspreyRegionCaptionDataset,
547
+ data_path=region_cap_osprey_data_path,
548
+ image_folder=region_cap_osprey_image_path,
549
+ tokenizer=tokenizer,
550
+ image_processor=image_processor,
551
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
552
+ template_map_fn=dict(
553
+ type=template_map_fn_factory, template=prompt_template),
554
+ max_length=max_length,
555
+ pad_image_to_square=True,
556
+ debug=False,
557
+ repeats=1,
558
+ )
559
+
560
+ region_conversation_osprey_dataset = dict(
561
+ type=OspreyRegionConversationDataset,
562
+ data_path=region_conversation_osprey_data_path,
563
+ image_folder=region_conversation_osprey_image_path,
564
+ tokenizer=tokenizer,
565
+ image_processor=image_processor,
566
+ dataset_map_fn=osprey_region_conversation_map_fn,
567
+ template_map_fn=dict(
568
+ type=template_map_fn_factory, template=prompt_template),
569
+ max_length=max_length,
570
+ pad_image_to_square=True,
571
+ debug=False,
572
+ repeats=1,
573
+ )
574
+
575
+ mdpv_detailed_description_ade20k_dataset = dict(
576
+ type=MDPVPointDetailedCaptionDataset,
577
+ data_path=mdpv_detailed_caption_ade20k_data_path,
578
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
579
+ tokenizer=tokenizer,
580
+ image_processor=image_processor,
581
+ dataset_map_fn=mdpv_points_map_fn,
582
+ template_map_fn=dict(
583
+ type=template_map_fn_factory, template=prompt_template),
584
+ max_length=max_length,
585
+ pad_image_to_square=True,
586
+ debug=False,
587
+ repeats=1,
588
+ )
589
+
590
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
591
+ type=MDPVPointDetailedCaptionDataset,
592
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
593
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
594
+ tokenizer=tokenizer,
595
+ image_processor=image_processor,
596
+ dataset_map_fn=mdpv_points_map_fn,
597
+ template_map_fn=dict(
598
+ type=template_map_fn_factory, template=prompt_template),
599
+ max_length=max_length,
600
+ pad_image_to_square=True,
601
+ debug=False,
602
+ repeats=1,
603
+ )
604
+
605
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
606
+ type=MDPVPointDetailedCaptionDataset,
607
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
608
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
609
+ tokenizer=tokenizer,
610
+ image_processor=image_processor,
611
+ dataset_map_fn=mdpv_points_map_fn,
612
+ template_map_fn=dict(
613
+ type=template_map_fn_factory, template=prompt_template),
614
+ max_length=max_length,
615
+ pad_image_to_square=True,
616
+ debug=False,
617
+ repeats=1,
618
+ )
619
+
620
+ mdpv_detailed_description_vg_dataset = dict(
621
+ type=MDPVPointDetailedCaptionDataset,
622
+ data_path=mdpv_detailed_caption_vg_data_path,
623
+ image_folder=mdpv_detailed_caption_vg_image_path,
624
+ tokenizer=tokenizer,
625
+ image_processor=image_processor,
626
+ dataset_map_fn=mdpv_points_map_fn,
627
+ template_map_fn=dict(
628
+ type=template_map_fn_factory, template=prompt_template),
629
+ max_length=max_length,
630
+ pad_image_to_square=True,
631
+ debug=False,
632
+ repeats=1,
633
+ )
634
+
635
+ mdpv_brief_description_vg_dataset = dict(
636
+ type=MDPVPointBriefCaptionDataset,
637
+ data_path=mdpv_brief_caption_vg_data_path,
638
+ image_folder=mdpv_brief_caption_vg_image_path,
639
+ tokenizer=tokenizer,
640
+ image_processor=image_processor,
641
+ dataset_map_fn=mdpv_points_map_fn,
642
+ template_map_fn=dict(
643
+ type=template_map_fn_factory, template=prompt_template),
644
+ max_length=max_length,
645
+ pad_image_to_square=True,
646
+ debug=False,
647
+ repeats=1,
648
+ )
649
+
650
+ mdpv_brief_description_cocostuff10k_dataset = dict(
651
+ type=MDPVPointBriefCaptionDataset,
652
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
653
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
654
+ tokenizer=tokenizer,
655
+ image_processor=image_processor,
656
+ dataset_map_fn=mdpv_points_map_fn,
657
+ template_map_fn=dict(
658
+ type=template_map_fn_factory, template=prompt_template),
659
+ max_length=max_length,
660
+ pad_image_to_square=True,
661
+ debug=False,
662
+ repeats=1,
663
+ )
664
+
665
+ mdpv_brief_description_cocostuff164k_dataset = dict(
666
+ type=MDPVPointBriefCaptionDataset,
667
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
668
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
669
+ tokenizer=tokenizer,
670
+ image_processor=image_processor,
671
+ dataset_map_fn=mdpv_points_map_fn,
672
+ template_map_fn=dict(
673
+ type=template_map_fn_factory, template=prompt_template),
674
+ max_length=max_length,
675
+ pad_image_to_square=True,
676
+ debug=False,
677
+ repeats=1,
678
+ )
679
+
680
+ mdpv_brief_description_ade20k_dataset = dict(
681
+ type=MDPVPointBriefCaptionDataset,
682
+ data_path=mdpv_brief_caption_ade20k_data_path,
683
+ image_folder=mdpv_brief_caption_ade20k_image_path,
684
+ tokenizer=tokenizer,
685
+ image_processor=image_processor,
686
+ dataset_map_fn=mdpv_points_map_fn,
687
+ template_map_fn=dict(
688
+ type=template_map_fn_factory, template=prompt_template),
689
+ max_length=max_length,
690
+ pad_image_to_square=True,
691
+ debug=False,
692
+ repeats=1,
693
+ )
694
+
695
+ mdpv_brief_description_lvis_dataset = dict(
696
+ type=MDPVPointBriefCaptionDataset,
697
+ data_path=mdpv_brief_caption_lvis_data_path,
698
+ image_folder=mdpv_brief_caption_lvis_image_path,
699
+ tokenizer=tokenizer,
700
+ image_processor=image_processor,
701
+ dataset_map_fn=mdpv_points_map_fn,
702
+ template_map_fn=dict(
703
+ type=template_map_fn_factory, template=prompt_template),
704
+ max_length=max_length,
705
+ pad_image_to_square=True,
706
+ debug=False,
707
+ repeats=1,
708
+ )
709
+
710
+ mdpv_qa_vg_dataset = dict(
711
+ type=MDPVPointBriefCaptionDataset,
712
+ data_path=mdpv_qa_vg_data_path,
713
+ image_folder=mdpv_qa_vg_image_path,
714
+ tokenizer=tokenizer,
715
+ image_processor=image_processor,
716
+ dataset_map_fn=mdpv_points_map_fn,
717
+ template_map_fn=dict(
718
+ type=template_map_fn_factory, template=prompt_template),
719
+ max_length=max_length,
720
+ pad_image_to_square=True,
721
+ debug=False,
722
+ repeats=1,
723
+ )
724
+
725
+ mdpv_qa_ade20k_dataset = dict(
726
+ type=MDPVPointBriefCaptionDataset,
727
+ data_path=mdpv_qa_ade20k_data_path,
728
+ image_folder=mdpv_qa_ade20k_image_path,
729
+ tokenizer=tokenizer,
730
+ image_processor=image_processor,
731
+ dataset_map_fn=mdpv_points_map_fn,
732
+ template_map_fn=dict(
733
+ type=template_map_fn_factory, template=prompt_template),
734
+ max_length=max_length,
735
+ pad_image_to_square=True,
736
+ debug=False,
737
+ repeats=1,
738
+ )
739
+
740
+ mdpv_qa_lvis_dataset = dict(
741
+ type=MDPVPointBriefCaptionDataset,
742
+ data_path=mdpv_qa_lvis_data_path,
743
+ image_folder=mdpv_qa_lvis_image_path,
744
+ tokenizer=tokenizer,
745
+ image_processor=image_processor,
746
+ dataset_map_fn=mdpv_points_map_fn,
747
+ template_map_fn=dict(
748
+ type=template_map_fn_factory, template=prompt_template),
749
+ max_length=max_length,
750
+ pad_image_to_square=True,
751
+ debug=False,
752
+ repeats=1,
753
+ )
754
+
755
+ mdpv_qa_cocostuff10k_dataset = dict(
756
+ type=MDPVPointBriefCaptionDataset,
757
+ data_path=mdpv_qa_cocostuff10k_data_path,
758
+ image_folder=mdpv_qa_cocostuff10k_image_path,
759
+ tokenizer=tokenizer,
760
+ image_processor=image_processor,
761
+ dataset_map_fn=mdpv_points_map_fn,
762
+ template_map_fn=dict(
763
+ type=template_map_fn_factory, template=prompt_template),
764
+ max_length=max_length,
765
+ pad_image_to_square=True,
766
+ debug=False,
767
+ repeats=1,
768
+ )
769
+
770
+ mdpv_qa_cocostuff164k_dataset = dict(
771
+ type=MDPVPointBriefCaptionDataset,
772
+ data_path=mdpv_qa_cocostuff164k_data_path,
773
+ image_folder=mdpv_qa_cocostuff164k_image_path,
774
+ tokenizer=tokenizer,
775
+ image_processor=image_processor,
776
+ dataset_map_fn=mdpv_points_map_fn,
777
+ template_map_fn=dict(
778
+ type=template_map_fn_factory, template=prompt_template),
779
+ max_length=max_length,
780
+ pad_image_to_square=True,
781
+ debug=False,
782
+ repeats=1,
783
+ )
784
+
785
+ mdpv_multi_points_openpsg_dataset = dict(
786
+ type=MDPVPointBriefCaptionDataset,
787
+ data_path=mdpv_multi_points_openpsg_data_path,
788
+ image_folder=mdpv_multi_points_openpsg_image_path,
789
+ tokenizer=tokenizer,
790
+ image_processor=image_processor,
791
+ dataset_map_fn=mdpv_points_map_fn,
792
+ template_map_fn=dict(
793
+ type=template_map_fn_factory, template=prompt_template),
794
+ max_length=max_length,
795
+ pad_image_to_square=True,
796
+ debug=False,
797
+ repeats=1,
798
+ )
799
+
800
+ mdpv_multi_points_flicker30k_dataset = dict(
801
+ type=MDPVPointBriefCaptionDataset,
802
+ data_path=mdpv_multi_points_flicker30k_data_path,
803
+ image_folder=mdpv_multi_points_flicker30k_image_path,
804
+ tokenizer=tokenizer,
805
+ image_processor=image_processor,
806
+ dataset_map_fn=mdpv_points_map_fn,
807
+ template_map_fn=dict(
808
+ type=template_map_fn_factory, template=prompt_template),
809
+ max_length=max_length,
810
+ pad_image_to_square=True,
811
+ debug=False,
812
+ repeats=1,
813
+ )
814
+
815
+ train_dataset = dict(
816
+ type=CombineDataset,
817
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
818
+ glamm_grandf_dataset, glamm_psg_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
827
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
828
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
829
+ mdpv_detailed_description_ade20k_dataset,
830
+ mdpv_detailed_description_cocostuff_10k_dataset,
831
+ mdpv_detailed_description_cocostuff_164k_dataset,
832
+ mdpv_detailed_description_vg_dataset,
833
+ mdpv_brief_description_lvis_dataset,
834
+ mdpv_brief_description_vg_dataset,
835
+ mdpv_brief_description_ade20k_dataset,
836
+ mdpv_brief_description_cocostuff10k_dataset,
837
+ mdpv_brief_description_cocostuff164k_dataset,
838
+ mdpv_qa_vg_dataset,
839
+ mdpv_qa_lvis_dataset,
840
+ mdpv_qa_ade20k_dataset,
841
+ mdpv_qa_cocostuff10k_dataset,
842
+ mdpv_qa_cocostuff164k_dataset,
843
+ mdpv_multi_points_flicker30k_dataset,
844
+ mdpv_multi_points_openpsg_dataset,],
845
+ )
846
+
847
+ train_dataloader = dict(
848
+ batch_size=batch_size,
849
+ num_workers=dataloader_num_workers,
850
+ dataset=train_dataset,
851
+ sampler=dict(
852
+ type=LengthGroupedSampler,
853
+ length_property='modality_length',
854
+ per_device_batch_size=batch_size * accumulative_counts),
855
+ collate_fn=dict(type=omg_llava_collate_fn))
856
+
857
+ #######################################################################
858
+ # PART 4 Scheduler & Optimizer #
859
+ #######################################################################
860
+ # optimizer
861
+ optim_wrapper = dict(
862
+ type=AmpOptimWrapper,
863
+ optimizer=dict(
864
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
865
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
866
+ accumulative_counts=accumulative_counts,
867
+ loss_scale='dynamic',
868
+ dtype='float16')
869
+
870
+ # learning policy
871
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
872
+ param_scheduler = [
873
+ dict(
874
+ type=LinearLR,
875
+ start_factor=1e-5,
876
+ by_epoch=True,
877
+ begin=0,
878
+ end=warmup_ratio * max_epochs,
879
+ convert_to_iter_based=True),
880
+ dict(
881
+ type=CosineAnnealingLR,
882
+ eta_min=0.0,
883
+ by_epoch=True,
884
+ begin=warmup_ratio * max_epochs,
885
+ end=max_epochs,
886
+ convert_to_iter_based=True)
887
+ ]
888
+
889
+ # train, val, test setting
890
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
891
+
892
+ #######################################################################
893
+ # PART 5 Runtime #
894
+ #######################################################################
895
+ # Log the dialogue periodically during the training process, optional
896
+ custom_hooks = [
897
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
898
+ dict(
899
+ type=EvaluateChatHook_withSpecialTokens,
900
+ tokenizer=tokenizer,
901
+ image_processor=image_processor,
902
+ every_n_iters=evaluation_freq,
903
+ evaluation_inputs=evaluation_inputs,
904
+ evaluation_images=evaluation_images,
905
+ system=SYSTEM,
906
+ prompt_template=prompt_template)
907
+ ]
908
+
909
+ # configure default hooks
910
+ default_hooks = dict(
911
+ # record the time of every iteration.
912
+ timer=dict(type=IterTimerHook),
913
+ # print log every 10 iterations.
914
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
915
+ # enable the parameter scheduler.
916
+ param_scheduler=dict(type=ParamSchedulerHook),
917
+ # save checkpoint per `save_steps`.
918
+ checkpoint=dict(
919
+ type=CheckpointHook,
920
+ by_epoch=False,
921
+ interval=save_steps,
922
+ max_keep_ckpts=save_total_limit),
923
+ # set sampler seed in distributed environment.
924
+ sampler_seed=dict(type=DistSamplerSeedHook),
925
+ )
926
+
927
+ # configure environment
928
+ env_cfg = dict(
929
+ # whether to enable cudnn benchmark
930
+ cudnn_benchmark=False,
931
+ # set multi process parameters
932
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
933
+ # set distributed parameters
934
+ dist_cfg=dict(backend='nccl'),
935
+ )
936
+
937
+ # set visualizer
938
+ visualizer = None
939
+
940
+ # set log level
941
+ log_level = 'INFO'
942
+
943
+ # load from which checkpoint
944
+ load_from = None
945
+
946
+ # whether to resume training from the loaded checkpoint
947
+ resume = False
948
+
949
+ # Use a random seed by default and disable `deterministic`
950
+ randomness = dict(seed=None, deterministic=False)
951
+
952
+ # set log processor
953
+ log_processor = dict(by_epoch=False)
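For readers less familiar with xtuner's dict-style lazy configs, here is a hypothetical standalone sketch of what the `quantization_config` and `llm_lora` settings above correspond to in plain Transformers/PEFT code. It is an illustration only; it assumes the InternLM2 checkpoint path from the config is available locally, and it is not how the repo itself instantiates the model.

# Hypothetical standalone equivalent of the 4-bit QLoRA settings above; illustration only.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

llm = AutoModelForCausalLM.from_pretrained(
    './pretrained/omg_llava/internlm2-chat-7b',  # path taken from the config above
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config)

lora_config = LoraConfig(
    r=512, lora_alpha=256, lora_dropout=0.05,
    bias='none', task_type='CAUSAL_LM')
llm = get_peft_model(llm, lora_config)  # LoRA scaling factor = lora_alpha / r = 0.5
llm.print_trainable_parameters()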
omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross.py ADDED
@@ -0,0 +1,953 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
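
As a quick worked check on the arithmetic above (not part of the config itself): with 1024-pixel inputs and an effective 64x downsampling, the image contributes (1024 / 64)**2 = 256 visual tokens, so

    >>> int(2048 - (1024 / 64) ** 2 - 100)
    1692

text tokens remain for max_length, with the extra 100 presumably left as headroom for special/segmentation tokens (that purpose is an inference; only the formula appears in the config). The 64x factor is consistent with the backbone's stride of 32 combined with the pixel_shuffle_down_ratio=2 declared further down, though that link is likewise an inference.
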
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
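
For reference, the effective global batch size implied by these settings, assuming the 8-GPU setup that the surrounding *_8gpus configs are named after (the GPU count itself is not stated in this file):

    per_device_batch = 8     # batch_size above
    grad_accumulation = 2    # accumulative_counts above
    num_gpus = 8             # assumption taken from the config naming convention
    print(per_device_batch * grad_accumulation * num_gpus)   # 128 samples per optimizer step
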
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
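
The image processor above resizes and center-crops inputs to 1024x1024 and normalizes with OpenCLIP statistics. A minimal usage sketch of the same settings applied directly (the image path is a placeholder, not a file from this repo):

    from PIL import Image
    from transformers import CLIPImageProcessor

    processor = CLIPImageProcessor(
        do_resize=True, size=1024, resample=3,
        do_center_crop=True, crop_size=1024,
        do_rescale=True, do_normalize=True,
        image_mean=[0.4814, 0.4578, 0.4082],
        image_std=[0.2686, 0.2613, 0.2757],
        do_convert_rgb=True)

    batch = processor(images=Image.open('example.jpg'), return_tensors='pt')
    print(batch['pixel_values'].shape)   # expected: torch.Size([1, 3, 1024, 1024])
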
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ visual_prompt_proj=True,
350
+ add_cross_attn_layer=True,
351
+ llm=dict(
352
+ type=AutoModelForCausalLM.from_pretrained,
353
+ pretrained_model_name_or_path=llm_name_or_path,
354
+ trust_remote_code=True,
355
+ torch_dtype=torch.float16,
356
+ quantization_config=dict(
357
+ type=BitsAndBytesConfig,
358
+ load_in_4bit=True,
359
+ load_in_8bit=False,
360
+ llm_int8_threshold=6.0,
361
+ llm_int8_has_fp16_weight=False,
362
+ bnb_4bit_compute_dtype=torch.float16,
363
+ bnb_4bit_use_double_quant=True,
364
+ bnb_4bit_quant_type='nf4')),
365
+ llm_lora=dict(
366
+ type=LoraConfig,
367
+ r=512,
368
+ lora_alpha=256,
369
+ lora_dropout=0.05,
370
+ bias='none',
371
+ task_type='CAUSAL_LM'),
372
+ visual_encoder=omgseg_model,
373
+ tokenizer=tokenizer,
374
+ )
375
+
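
The llm and llm_lora entries above describe a 4-bit (NF4) QLoRA setup. A standalone sketch of the equivalent quantization and adapter objects built directly with transformers and peft follows; xtuner instantiates these from the dicts above, and which modules the adapters attach to is left out here because the config does not name target_modules:

    import torch
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig

    # 4-bit NF4 quantization with double quantization and fp16 compute,
    # mirroring the quantization_config dict above.
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4')

    # LoRA adapters with the same rank/alpha/dropout as llm_lora above.
    lora_cfg = LoraConfig(
        r=512, lora_alpha=256, lora_dropout=0.05,
        bias='none', task_type='CAUSAL_LM')

With freeze_llm=True and freeze_visual_encoder=True, presumably only the LoRA adapters and the projector components remain trainable; that is an inference from the flags above rather than something this diff states.
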
376
+ #######################################################################
377
+ # PART 3 Dataset & Dataloader #
378
+ #######################################################################
379
+ debug=False
380
+ llava_dataset = dict(
381
+ type=LLaVADataset,
382
+ data_path=data_path,
383
+ image_folder=image_folder,
384
+ tokenizer=tokenizer,
385
+ image_processor=image_processor,
386
+ dataset_map_fn=llava_map_fn,
387
+ template_map_fn=dict(
388
+ type=template_map_fn_factory, template=prompt_template),
389
+ max_length=max_length,
390
+ pad_image_to_square=True)
391
+
392
+ glamm_refcocog_dataset = dict(
393
+ type=RefCOCOgGCGDataset,
394
+ data_path=refcocog_ann_file,
395
+ image_folder=refcocog_image_path,
396
+ tokenizer=tokenizer,
397
+ image_processor=image_processor,
398
+ dataset_map_fn=glamm_refcocog_map_fn,
399
+ template_map_fn=dict(
400
+ type=template_map_fn_factory, template=prompt_template),
401
+ max_length=max_length,
402
+ pad_image_to_square=True,
403
+ debug=False,
404
+ repeats=1,
405
+ )
406
+
407
+ glamm_grandf_dataset = dict(
408
+ type=GranDfGCGDataset,
409
+ data_path=grandf_ann_file,
410
+ image_folder=grandf_image_path,
411
+ tokenizer=tokenizer,
412
+ image_processor=image_processor,
413
+ dataset_map_fn=glamm_granf_map_fn,
414
+ template_map_fn=dict(
415
+ type=template_map_fn_factory, template=prompt_template),
416
+ max_length=max_length,
417
+ pad_image_to_square=True,
418
+ debug=debug,
419
+ repeats=10,
420
+ )
421
+
422
+ glamm_psg_dataset = dict(
423
+ type=OpenPsgGCGDataset,
424
+ data_path=psg_ann_file,
425
+ image_folder=psg_image_path,
426
+ tokenizer=tokenizer,
427
+ image_processor=image_processor,
428
+ dataset_map_fn=glamm_openpsg_map_fn,
429
+ template_map_fn=dict(
430
+ type=template_map_fn_factory, template=prompt_template),
431
+ max_length=max_length,
432
+ pad_image_to_square=True,
433
+ debug=debug,
434
+ repeats=1,
435
+ )
436
+
437
+ glamm_flickr_dataset = dict(
438
+ type=FlickrGCGDataset,
439
+ data_path=flickr_ann_file,
440
+ image_folder=flickr_image_path,
441
+ tokenizer=tokenizer,
442
+ image_processor=image_processor,
443
+ dataset_map_fn=glamm_flickr_map_fn,
444
+ template_map_fn=dict(
445
+ type=template_map_fn_factory, template=prompt_template),
446
+ max_length=max_length,
447
+ pad_image_to_square=True,
448
+ debug=debug,
449
+ repeats=1,
450
+ )
451
+
452
+ semantic_seg_ade20k_dataset = dict(
453
+ type=ADE20kSemanticSegDataset,
454
+ data_path=ade20k_class_file,
455
+ image_folder=ade20k_image_path,
456
+ tokenizer=tokenizer,
457
+ image_processor=image_processor,
458
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
459
+ template_map_fn=dict(
460
+ type=template_map_fn_factory, template=prompt_template),
461
+ max_length=max_length,
462
+ pad_image_to_square=True,
463
+ debug=False,
464
+ repeats=1,
465
+ gcg_format=True,
466
+ )
467
+
468
+ semantic_seg_cocostuff_dataset = dict(
469
+ type=COCOStuffSemanticSegDataset,
470
+ data_path=cocostuff_class_file,
471
+ image_folder=cocostuff_image_path,
472
+ label_path=cocostuff_label_path,
473
+ tokenizer=tokenizer,
474
+ image_processor=image_processor,
475
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
476
+ template_map_fn=dict(
477
+ type=template_map_fn_factory, template=prompt_template),
478
+ max_length=max_length,
479
+ pad_image_to_square=True,
480
+ debug=False,
481
+ repeats=1,
482
+ gcg_format=True,
483
+ )
484
+
485
+ referring_seg_refcoco_dataset = dict(
486
+ type=RefcocoReferringSegDataset,
487
+ data_path=referring_refcoco_data_path,
488
+ image_folder=referring_refcoco_image_path,
489
+ tokenizer=tokenizer,
490
+ image_processor=image_processor,
491
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
492
+ template_map_fn=dict(
493
+ type=template_map_fn_factory, template=prompt_template),
494
+ max_length=max_length,
495
+ pad_image_to_square=True,
496
+ debug=False,
497
+ repeats=1,
498
+ )
499
+
500
+ referring_seg_refcoco_plus_dataset = dict(
501
+ type=Refcoco_plus_ReferringSegDataset,
502
+ data_path=referring_refcoco_plus_data_path,
503
+ image_folder=referring_refcoco_plus_image_path,
504
+ tokenizer=tokenizer,
505
+ image_processor=image_processor,
506
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
507
+ template_map_fn=dict(
508
+ type=template_map_fn_factory, template=prompt_template),
509
+ max_length=max_length,
510
+ pad_image_to_square=True,
511
+ debug=False,
512
+ repeats=1,
513
+ )
514
+
515
+ referring_seg_refcocog_dataset = dict(
516
+ type=Refcocog_ReferringSegDataset,
517
+ data_path=referring_refcocog_data_path,
518
+ image_folder=referring_refcocog_image_path,
519
+ tokenizer=tokenizer,
520
+ image_processor=image_processor,
521
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
522
+ template_map_fn=dict(
523
+ type=template_map_fn_factory, template=prompt_template),
524
+ max_length=max_length,
525
+ pad_image_to_square=True,
526
+ debug=False,
527
+ repeats=1,
528
+ )
529
+
530
+ referring_seg_refclef_dataset = dict(
531
+ type=Refclef_ReferringSegDataset,
532
+ data_path=referring_refclef_data_path,
533
+ image_folder=referring_refclef_image_path,
534
+ tokenizer=tokenizer,
535
+ image_processor=image_processor,
536
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
537
+ template_map_fn=dict(
538
+ type=template_map_fn_factory, template=prompt_template),
539
+ max_length=max_length,
540
+ pad_image_to_square=True,
541
+ debug=False,
542
+ repeats=1,
543
+ )
544
+
545
+ region_cap_osprey_dataset = dict(
546
+ type=OspreyRegionCaptionDataset,
547
+ data_path=region_cap_osprey_data_path,
548
+ image_folder=region_cap_osprey_image_path,
549
+ tokenizer=tokenizer,
550
+ image_processor=image_processor,
551
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
552
+ template_map_fn=dict(
553
+ type=template_map_fn_factory, template=prompt_template),
554
+ max_length=max_length,
555
+ pad_image_to_square=True,
556
+ debug=False,
557
+ repeats=1,
558
+ )
559
+
560
+ region_conversation_osprey_dataset = dict(
561
+ type=OspreyRegionConversationDataset,
562
+ data_path=region_conversation_osprey_data_path,
563
+ image_folder=region_conversation_osprey_image_path,
564
+ tokenizer=tokenizer,
565
+ image_processor=image_processor,
566
+ dataset_map_fn=osprey_region_conversation_map_fn,
567
+ template_map_fn=dict(
568
+ type=template_map_fn_factory, template=prompt_template),
569
+ max_length=max_length,
570
+ pad_image_to_square=True,
571
+ debug=False,
572
+ repeats=1,
573
+ )
574
+
575
+ mdpv_detailed_description_ade20k_dataset = dict(
576
+ type=MDPVPointDetailedCaptionDataset,
577
+ data_path=mdpv_detailed_caption_ade20k_data_path,
578
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
579
+ tokenizer=tokenizer,
580
+ image_processor=image_processor,
581
+ dataset_map_fn=mdpv_points_map_fn,
582
+ template_map_fn=dict(
583
+ type=template_map_fn_factory, template=prompt_template),
584
+ max_length=max_length,
585
+ pad_image_to_square=True,
586
+ debug=False,
587
+ repeats=1,
588
+ )
589
+
590
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
591
+ type=MDPVPointDetailedCaptionDataset,
592
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
593
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
594
+ tokenizer=tokenizer,
595
+ image_processor=image_processor,
596
+ dataset_map_fn=mdpv_points_map_fn,
597
+ template_map_fn=dict(
598
+ type=template_map_fn_factory, template=prompt_template),
599
+ max_length=max_length,
600
+ pad_image_to_square=True,
601
+ debug=False,
602
+ repeats=1,
603
+ )
604
+
605
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
606
+ type=MDPVPointDetailedCaptionDataset,
607
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
608
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
609
+ tokenizer=tokenizer,
610
+ image_processor=image_processor,
611
+ dataset_map_fn=mdpv_points_map_fn,
612
+ template_map_fn=dict(
613
+ type=template_map_fn_factory, template=prompt_template),
614
+ max_length=max_length,
615
+ pad_image_to_square=True,
616
+ debug=False,
617
+ repeats=1,
618
+ )
619
+
620
+ mdpv_detailed_description_vg_dataset = dict(
621
+ type=MDPVPointDetailedCaptionDataset,
622
+ data_path=mdpv_detailed_caption_vg_data_path,
623
+ image_folder=mdpv_detailed_caption_vg_image_path,
624
+ tokenizer=tokenizer,
625
+ image_processor=image_processor,
626
+ dataset_map_fn=mdpv_points_map_fn,
627
+ template_map_fn=dict(
628
+ type=template_map_fn_factory, template=prompt_template),
629
+ max_length=max_length,
630
+ pad_image_to_square=True,
631
+ debug=False,
632
+ repeats=1,
633
+ )
634
+
635
+ mdpv_brief_description_vg_dataset = dict(
636
+ type=MDPVPointBriefCaptionDataset,
637
+ data_path=mdpv_brief_caption_vg_data_path,
638
+ image_folder=mdpv_brief_caption_vg_image_path,
639
+ tokenizer=tokenizer,
640
+ image_processor=image_processor,
641
+ dataset_map_fn=mdpv_points_map_fn,
642
+ template_map_fn=dict(
643
+ type=template_map_fn_factory, template=prompt_template),
644
+ max_length=max_length,
645
+ pad_image_to_square=True,
646
+ debug=False,
647
+ repeats=1,
648
+ )
649
+
650
+ mdpv_brief_description_cocostuff10k_dataset = dict(
651
+ type=MDPVPointBriefCaptionDataset,
652
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
653
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
654
+ tokenizer=tokenizer,
655
+ image_processor=image_processor,
656
+ dataset_map_fn=mdpv_points_map_fn,
657
+ template_map_fn=dict(
658
+ type=template_map_fn_factory, template=prompt_template),
659
+ max_length=max_length,
660
+ pad_image_to_square=True,
661
+ debug=False,
662
+ repeats=1,
663
+ )
664
+
665
+ mdpv_brief_description_cocostuff164k_dataset = dict(
666
+ type=MDPVPointBriefCaptionDataset,
667
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
668
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
669
+ tokenizer=tokenizer,
670
+ image_processor=image_processor,
671
+ dataset_map_fn=mdpv_points_map_fn,
672
+ template_map_fn=dict(
673
+ type=template_map_fn_factory, template=prompt_template),
674
+ max_length=max_length,
675
+ pad_image_to_square=True,
676
+ debug=False,
677
+ repeats=1,
678
+ )
679
+
680
+ mdpv_brief_description_ade20k_dataset = dict(
681
+ type=MDPVPointBriefCaptionDataset,
682
+ data_path=mdpv_brief_caption_ade20k_data_path,
683
+ image_folder=mdpv_brief_caption_ade20k_image_path,
684
+ tokenizer=tokenizer,
685
+ image_processor=image_processor,
686
+ dataset_map_fn=mdpv_points_map_fn,
687
+ template_map_fn=dict(
688
+ type=template_map_fn_factory, template=prompt_template),
689
+ max_length=max_length,
690
+ pad_image_to_square=True,
691
+ debug=False,
692
+ repeats=1,
693
+ )
694
+
695
+ mdpv_brief_description_lvis_dataset = dict(
696
+ type=MDPVPointBriefCaptionDataset,
697
+ data_path=mdpv_brief_caption_lvis_data_path,
698
+ image_folder=mdpv_brief_caption_lvis_image_path,
699
+ tokenizer=tokenizer,
700
+ image_processor=image_processor,
701
+ dataset_map_fn=mdpv_points_map_fn,
702
+ template_map_fn=dict(
703
+ type=template_map_fn_factory, template=prompt_template),
704
+ max_length=max_length,
705
+ pad_image_to_square=True,
706
+ debug=False,
707
+ repeats=1,
708
+ )
709
+
710
+ mdpv_qa_vg_dataset = dict(
711
+ type=MDPVPointBriefCaptionDataset,
712
+ data_path=mdpv_qa_vg_data_path,
713
+ image_folder=mdpv_qa_vg_image_path,
714
+ tokenizer=tokenizer,
715
+ image_processor=image_processor,
716
+ dataset_map_fn=mdpv_points_map_fn,
717
+ template_map_fn=dict(
718
+ type=template_map_fn_factory, template=prompt_template),
719
+ max_length=max_length,
720
+ pad_image_to_square=True,
721
+ debug=False,
722
+ repeats=1,
723
+ )
724
+
725
+ mdpv_qa_ade20k_dataset = dict(
726
+ type=MDPVPointBriefCaptionDataset,
727
+ data_path=mdpv_qa_ade20k_data_path,
728
+ image_folder=mdpv_qa_ade20k_image_path,
729
+ tokenizer=tokenizer,
730
+ image_processor=image_processor,
731
+ dataset_map_fn=mdpv_points_map_fn,
732
+ template_map_fn=dict(
733
+ type=template_map_fn_factory, template=prompt_template),
734
+ max_length=max_length,
735
+ pad_image_to_square=True,
736
+ debug=False,
737
+ repeats=1,
738
+ )
739
+
740
+ mdpv_qa_lvis_dataset = dict(
741
+ type=MDPVPointBriefCaptionDataset,
742
+ data_path=mdpv_qa_lvis_data_path,
743
+ image_folder=mdpv_qa_lvis_image_path,
744
+ tokenizer=tokenizer,
745
+ image_processor=image_processor,
746
+ dataset_map_fn=mdpv_points_map_fn,
747
+ template_map_fn=dict(
748
+ type=template_map_fn_factory, template=prompt_template),
749
+ max_length=max_length,
750
+ pad_image_to_square=True,
751
+ debug=False,
752
+ repeats=1,
753
+ )
754
+
755
+ mdpv_qa_cocostuff10k_dataset = dict(
756
+ type=MDPVPointBriefCaptionDataset,
757
+ data_path=mdpv_qa_cocostuff10k_data_path,
758
+ image_folder=mdpv_qa_cocostuff10k_image_path,
759
+ tokenizer=tokenizer,
760
+ image_processor=image_processor,
761
+ dataset_map_fn=mdpv_points_map_fn,
762
+ template_map_fn=dict(
763
+ type=template_map_fn_factory, template=prompt_template),
764
+ max_length=max_length,
765
+ pad_image_to_square=True,
766
+ debug=False,
767
+ repeats=1,
768
+ )
769
+
770
+ mdpv_qa_cocostuff164k_dataset = dict(
771
+ type=MDPVPointBriefCaptionDataset,
772
+ data_path=mdpv_qa_cocostuff164k_data_path,
773
+ image_folder=mdpv_qa_cocostuff164k_image_path,
774
+ tokenizer=tokenizer,
775
+ image_processor=image_processor,
776
+ dataset_map_fn=mdpv_points_map_fn,
777
+ template_map_fn=dict(
778
+ type=template_map_fn_factory, template=prompt_template),
779
+ max_length=max_length,
780
+ pad_image_to_square=True,
781
+ debug=False,
782
+ repeats=1,
783
+ )
784
+
785
+ mdpv_multi_points_openpsg_dataset = dict(
786
+ type=MDPVPointBriefCaptionDataset,
787
+ data_path=mdpv_multi_points_openpsg_data_path,
788
+ image_folder=mdpv_multi_points_openpsg_image_path,
789
+ tokenizer=tokenizer,
790
+ image_processor=image_processor,
791
+ dataset_map_fn=mdpv_points_map_fn,
792
+ template_map_fn=dict(
793
+ type=template_map_fn_factory, template=prompt_template),
794
+ max_length=max_length,
795
+ pad_image_to_square=True,
796
+ debug=False,
797
+ repeats=1,
798
+ )
799
+
800
+ mdpv_multi_points_flicker30k_dataset = dict(
801
+ type=MDPVPointBriefCaptionDataset,
802
+ data_path=mdpv_multi_points_flicker30k_data_path,
803
+ image_folder=mdpv_multi_points_flicker30k_image_path,
804
+ tokenizer=tokenizer,
805
+ image_processor=image_processor,
806
+ dataset_map_fn=mdpv_points_map_fn,
807
+ template_map_fn=dict(
808
+ type=template_map_fn_factory, template=prompt_template),
809
+ max_length=max_length,
810
+ pad_image_to_square=True,
811
+ debug=False,
812
+ repeats=1,
813
+ )
814
+
815
+ train_dataset = dict(
816
+ type=CombineDataset,
817
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
818
+ glamm_grandf_dataset, glamm_psg_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
821
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
827
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
828
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
829
+ mdpv_detailed_description_ade20k_dataset,
830
+ mdpv_detailed_description_cocostuff_10k_dataset,
831
+ mdpv_detailed_description_cocostuff_164k_dataset,
832
+ mdpv_detailed_description_vg_dataset,
833
+ mdpv_brief_description_lvis_dataset,
834
+ mdpv_brief_description_vg_dataset,
835
+ mdpv_brief_description_ade20k_dataset,
836
+ mdpv_brief_description_cocostuff10k_dataset,
837
+ mdpv_brief_description_cocostuff164k_dataset,
838
+ mdpv_qa_vg_dataset,
839
+ mdpv_qa_lvis_dataset,
840
+ mdpv_qa_ade20k_dataset,
841
+ mdpv_qa_cocostuff10k_dataset,
842
+ mdpv_qa_cocostuff164k_dataset,
843
+ mdpv_multi_points_flicker30k_dataset,
844
+ mdpv_multi_points_openpsg_dataset,],
845
+ )
846
+
847
+ train_dataloader = dict(
848
+ batch_size=batch_size,
849
+ num_workers=dataloader_num_workers,
850
+ dataset=train_dataset,
851
+ sampler=dict(
852
+ type=LengthGroupedSampler,
853
+ length_property='modality_length',
854
+ per_device_batch_size=batch_size * accumulative_counts),
855
+ collate_fn=dict(type=omg_llava_collate_fn))
856
+
857
+ #######################################################################
858
+ # PART 4 Scheduler & Optimizer #
859
+ #######################################################################
860
+ # optimizer
861
+ optim_wrapper = dict(
862
+ type=AmpOptimWrapper,
863
+ optimizer=dict(
864
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
865
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
866
+ accumulative_counts=accumulative_counts,
867
+ loss_scale='dynamic',
868
+ dtype='float16')
869
+
870
+ # learning policy
871
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
872
+ param_scheduler = [
873
+ dict(
874
+ type=LinearLR,
875
+ start_factor=1e-5,
876
+ by_epoch=True,
877
+ begin=0,
878
+ end=warmup_ratio * max_epochs,
879
+ convert_to_iter_based=True),
880
+ dict(
881
+ type=CosineAnnealingLR,
882
+ eta_min=0.0,
883
+ by_epoch=True,
884
+ begin=warmup_ratio * max_epochs,
885
+ end=max_epochs,
886
+ convert_to_iter_based=True)
887
+ ]
888
+
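
To make the schedule above concrete: with warmup_ratio = 0.03 and max_epochs = 1, the LinearLR stage covers the first 3% of iterations, ramping the learning rate from lr * 1e-5 up to lr = 2e-4, after which CosineAnnealingLR decays it to 0 over the remaining 97%. A rough standalone sketch of that curve (plain arithmetic, not mmengine's implementation):

    import math

    def lr_at(progress, base_lr=2e-4, warmup=0.03, start_factor=1e-5):
        # progress is the fraction of total training iterations, in [0, 1]
        if progress < warmup:
            frac = progress / warmup
            return base_lr * (start_factor + (1.0 - start_factor) * frac)
        t = (progress - warmup) / (1.0 - warmup)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))

    print(lr_at(0.0), lr_at(0.03), lr_at(1.0))   # ~2e-9, 2e-4, ~0.0
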
889
+ # train, val, test setting
890
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
891
+
892
+ #######################################################################
893
+ # PART 5 Runtime #
894
+ #######################################################################
895
+ # Log the dialogue periodically during the training process, optional
896
+ custom_hooks = [
897
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
898
+ dict(
899
+ type=EvaluateChatHook_withSpecialTokens,
900
+ tokenizer=tokenizer,
901
+ image_processor=image_processor,
902
+ every_n_iters=evaluation_freq,
903
+ evaluation_inputs=evaluation_inputs,
904
+ evaluation_images=evaluation_images,
905
+ system=SYSTEM,
906
+ prompt_template=prompt_template)
907
+ ]
908
+
909
+ # configure default hooks
910
+ default_hooks = dict(
911
+ # record the time of every iteration.
912
+ timer=dict(type=IterTimerHook),
913
+ # print log every 10 iterations.
914
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
915
+ # enable the parameter scheduler.
916
+ param_scheduler=dict(type=ParamSchedulerHook),
917
+ # save checkpoint per `save_steps`.
918
+ checkpoint=dict(
919
+ type=CheckpointHook,
920
+ by_epoch=False,
921
+ interval=save_steps,
922
+ max_keep_ckpts=save_total_limit),
923
+ # set sampler seed in distributed environment.
924
+ sampler_seed=dict(type=DistSamplerSeedHook),
925
+ )
926
+
927
+ # configure environment
928
+ env_cfg = dict(
929
+ # whether to enable cudnn benchmark
930
+ cudnn_benchmark=False,
931
+ # set multi process parameters
932
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
933
+ # set distributed parameters
934
+ dist_cfg=dict(backend='nccl'),
935
+ )
936
+
937
+ # set visualizer
938
+ visualizer = None
939
+
940
+ # set log level
941
+ log_level = 'INFO'
942
+
943
+ # load from which checkpoint
944
+ load_from = None
945
+
946
+ # whether to resume training from the loaded checkpoint
947
+ resume = False
948
+
949
+ # By default, use a random seed and disable `deterministic`
950
+ randomness = dict(seed=None, deterministic=False)
951
+
952
+ # set log processor
953
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/ablation_projector/finetune_ablation_projector_seperate_cross_debug.py ADDED
@@ -0,0 +1,926 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_seperate_cross_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ visual_prompt_proj=True,
350
+ add_cross_attn_layer=True,
351
+ llm=dict(
352
+ type=AutoModelForCausalLM.from_pretrained,
353
+ pretrained_model_name_or_path=llm_name_or_path,
354
+ trust_remote_code=True,
355
+ torch_dtype=torch.float16,
356
+ quantization_config=dict(
357
+ type=BitsAndBytesConfig,
358
+ load_in_4bit=True,
359
+ load_in_8bit=False,
360
+ llm_int8_threshold=6.0,
361
+ llm_int8_has_fp16_weight=False,
362
+ bnb_4bit_compute_dtype=torch.float16,
363
+ bnb_4bit_use_double_quant=True,
364
+ bnb_4bit_quant_type='nf4')),
365
+ llm_lora=dict(
366
+ type=LoraConfig,
367
+ r=512,
368
+ lora_alpha=256,
369
+ lora_dropout=0.05,
370
+ bias='none',
371
+ task_type='CAUSAL_LM'),
372
+ visual_encoder=omgseg_model,
373
+ tokenizer=tokenizer,
374
+ )
375
+
376
+ #######################################################################
377
+ # PART 3 Dataset & Dataloader #
378
+ #######################################################################
379
+ debug=False
380
+ llava_dataset = dict(
381
+ type=LLaVADataset,
382
+ data_path=data_path,
383
+ image_folder=image_folder,
384
+ tokenizer=tokenizer,
385
+ image_processor=image_processor,
386
+ dataset_map_fn=llava_map_fn,
387
+ template_map_fn=dict(
388
+ type=template_map_fn_factory, template=prompt_template),
389
+ max_length=max_length,
390
+ pad_image_to_square=True)
391
+
392
+ glamm_refcocog_dataset = dict(
393
+ type=RefCOCOgGCGDataset,
394
+ data_path=refcocog_ann_file,
395
+ image_folder=refcocog_image_path,
396
+ tokenizer=tokenizer,
397
+ image_processor=image_processor,
398
+ dataset_map_fn=glamm_refcocog_map_fn,
399
+ template_map_fn=dict(
400
+ type=template_map_fn_factory, template=prompt_template),
401
+ max_length=max_length,
402
+ pad_image_to_square=True,
403
+ debug=False,
404
+ repeats=1,
405
+ )
406
+
407
+ glamm_grandf_dataset = dict(
408
+ type=GranDfGCGDataset,
409
+ data_path=grandf_ann_file,
410
+ image_folder=grandf_image_path,
411
+ tokenizer=tokenizer,
412
+ image_processor=image_processor,
413
+ dataset_map_fn=glamm_granf_map_fn,
414
+ template_map_fn=dict(
415
+ type=template_map_fn_factory, template=prompt_template),
416
+ max_length=max_length,
417
+ pad_image_to_square=True,
418
+ debug=debug,
419
+ repeats=10,
420
+ )
421
+
422
+ glamm_psg_dataset = dict(
423
+ type=OpenPsgGCGDataset,
424
+ data_path=psg_ann_file,
425
+ image_folder=psg_image_path,
426
+ tokenizer=tokenizer,
427
+ image_processor=image_processor,
428
+ dataset_map_fn=glamm_openpsg_map_fn,
429
+ template_map_fn=dict(
430
+ type=template_map_fn_factory, template=prompt_template),
431
+ max_length=max_length,
432
+ pad_image_to_square=True,
433
+ debug=debug,
434
+ repeats=1,
435
+ )
436
+
437
+ glamm_flickr_dataset = dict(
438
+ type=FlickrGCGDataset,
439
+ data_path=flickr_ann_file,
440
+ image_folder=flickr_image_path,
441
+ tokenizer=tokenizer,
442
+ image_processor=image_processor,
443
+ dataset_map_fn=glamm_flickr_map_fn,
444
+ template_map_fn=dict(
445
+ type=template_map_fn_factory, template=prompt_template),
446
+ max_length=max_length,
447
+ pad_image_to_square=True,
448
+ debug=debug,
449
+ repeats=1,
450
+ )
451
+
452
+ semantic_seg_ade20k_dataset = dict(
453
+ type=ADE20kSemanticSegDataset,
454
+ data_path=ade20k_class_file,
455
+ image_folder=ade20k_image_path,
456
+ tokenizer=tokenizer,
457
+ image_processor=image_processor,
458
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
459
+ template_map_fn=dict(
460
+ type=template_map_fn_factory, template=prompt_template),
461
+ max_length=max_length,
462
+ pad_image_to_square=True,
463
+ debug=False,
464
+ repeats=1,
465
+ gcg_format=True,
466
+ )
467
+
468
+ semantic_seg_cocostuff_dataset = dict(
469
+ type=COCOStuffSemanticSegDataset,
470
+ data_path=cocostuff_class_file,
471
+ image_folder=cocostuff_image_path,
472
+ label_path=cocostuff_label_path,
473
+ tokenizer=tokenizer,
474
+ image_processor=image_processor,
475
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
476
+ template_map_fn=dict(
477
+ type=template_map_fn_factory, template=prompt_template),
478
+ max_length=max_length,
479
+ pad_image_to_square=True,
480
+ debug=False,
481
+ repeats=1,
482
+ gcg_format=True,
483
+ )
484
+
485
+ referring_seg_refcoco_dataset = dict(
486
+ type=RefcocoReferringSegDataset,
487
+ data_path=referring_refcoco_data_path,
488
+ image_folder=referring_refcoco_image_path,
489
+ tokenizer=tokenizer,
490
+ image_processor=image_processor,
491
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
492
+ template_map_fn=dict(
493
+ type=template_map_fn_factory, template=prompt_template),
494
+ max_length=max_length,
495
+ pad_image_to_square=True,
496
+ debug=False,
497
+ repeats=1,
498
+ )
499
+
500
+ referring_seg_refcoco_plus_dataset = dict(
501
+ type=Refcoco_plus_ReferringSegDataset,
502
+ data_path=referring_refcoco_plus_data_path,
503
+ image_folder=referring_refcoco_plus_image_path,
504
+ tokenizer=tokenizer,
505
+ image_processor=image_processor,
506
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
507
+ template_map_fn=dict(
508
+ type=template_map_fn_factory, template=prompt_template),
509
+ max_length=max_length,
510
+ pad_image_to_square=True,
511
+ debug=False,
512
+ repeats=1,
513
+ )
514
+
515
+ referring_seg_refcocog_dataset = dict(
516
+ type=Refcocog_ReferringSegDataset,
517
+ data_path=referring_refcocog_data_path,
518
+ image_folder=referring_refcocog_image_path,
519
+ tokenizer=tokenizer,
520
+ image_processor=image_processor,
521
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
522
+ template_map_fn=dict(
523
+ type=template_map_fn_factory, template=prompt_template),
524
+ max_length=max_length,
525
+ pad_image_to_square=True,
526
+ debug=False,
527
+ repeats=1,
528
+ )
529
+
530
+ referring_seg_refclef_dataset = dict(
531
+ type=Refclef_ReferringSegDataset,
532
+ data_path=referring_refclef_data_path,
533
+ image_folder=referring_refclef_image_path,
534
+ tokenizer=tokenizer,
535
+ image_processor=image_processor,
536
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
537
+ template_map_fn=dict(
538
+ type=template_map_fn_factory, template=prompt_template),
539
+ max_length=max_length,
540
+ pad_image_to_square=True,
541
+ debug=False,
542
+ repeats=1,
543
+ )
544
+
545
+ region_cap_osprey_dataset = dict(
546
+ type=OspreyRegionCaptionDataset,
547
+ data_path=region_cap_osprey_data_path,
548
+ image_folder=region_cap_osprey_image_path,
549
+ tokenizer=tokenizer,
550
+ image_processor=image_processor,
551
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
552
+ template_map_fn=dict(
553
+ type=template_map_fn_factory, template=prompt_template),
554
+ max_length=max_length,
555
+ pad_image_to_square=True,
556
+ debug=False,
557
+ repeats=1,
558
+ )
559
+
560
+ region_conversation_osprey_dataset = dict(
561
+ type=OspreyRegionConversationDataset,
562
+ data_path=region_conversation_osprey_data_path,
563
+ image_folder=region_conversation_osprey_image_path,
564
+ tokenizer=tokenizer,
565
+ image_processor=image_processor,
566
+ dataset_map_fn=osprey_region_conversation_map_fn,
567
+ template_map_fn=dict(
568
+ type=template_map_fn_factory, template=prompt_template),
569
+ max_length=max_length,
570
+ pad_image_to_square=True,
571
+ debug=False,
572
+ repeats=1,
573
+ )
574
+
575
+ mdpv_detailed_description_ade20k_dataset = dict(
576
+ type=MDPVPointDetailedCaptionDataset,
577
+ data_path=mdpv_detailed_caption_ade20k_data_path,
578
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
579
+ tokenizer=tokenizer,
580
+ image_processor=image_processor,
581
+ dataset_map_fn=mdpv_points_map_fn,
582
+ template_map_fn=dict(
583
+ type=template_map_fn_factory, template=prompt_template),
584
+ max_length=max_length,
585
+ pad_image_to_square=True,
586
+ debug=False,
587
+ repeats=1,
588
+ )
589
+
590
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
591
+ type=MDPVPointDetailedCaptionDataset,
592
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
593
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
594
+ tokenizer=tokenizer,
595
+ image_processor=image_processor,
596
+ dataset_map_fn=mdpv_points_map_fn,
597
+ template_map_fn=dict(
598
+ type=template_map_fn_factory, template=prompt_template),
599
+ max_length=max_length,
600
+ pad_image_to_square=True,
601
+ debug=False,
602
+ repeats=1,
603
+ )
604
+
605
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
606
+ type=MDPVPointDetailedCaptionDataset,
607
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
608
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
609
+ tokenizer=tokenizer,
610
+ image_processor=image_processor,
611
+ dataset_map_fn=mdpv_points_map_fn,
612
+ template_map_fn=dict(
613
+ type=template_map_fn_factory, template=prompt_template),
614
+ max_length=max_length,
615
+ pad_image_to_square=True,
616
+ debug=False,
617
+ repeats=1,
618
+ )
619
+
620
+ mdpv_detailed_description_vg_dataset = dict(
621
+ type=MDPVPointDetailedCaptionDataset,
622
+ data_path=mdpv_detailed_caption_vg_data_path,
623
+ image_folder=mdpv_detailed_caption_vg_image_path,
624
+ tokenizer=tokenizer,
625
+ image_processor=image_processor,
626
+ dataset_map_fn=mdpv_points_map_fn,
627
+ template_map_fn=dict(
628
+ type=template_map_fn_factory, template=prompt_template),
629
+ max_length=max_length,
630
+ pad_image_to_square=True,
631
+ debug=False,
632
+ repeats=1,
633
+ )
634
+
635
+ mdpv_brief_description_vg_dataset = dict(
636
+ type=MDPVPointBriefCaptionDataset,
637
+ data_path=mdpv_brief_caption_vg_data_path,
638
+ image_folder=mdpv_brief_caption_vg_image_path,
639
+ tokenizer=tokenizer,
640
+ image_processor=image_processor,
641
+ dataset_map_fn=mdpv_points_map_fn,
642
+ template_map_fn=dict(
643
+ type=template_map_fn_factory, template=prompt_template),
644
+ max_length=max_length,
645
+ pad_image_to_square=True,
646
+ debug=False,
647
+ repeats=1,
648
+ )
649
+
650
+ mdpv_brief_description_cocostuff10k_dataset = dict(
651
+ type=MDPVPointBriefCaptionDataset,
652
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
653
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
654
+ tokenizer=tokenizer,
655
+ image_processor=image_processor,
656
+ dataset_map_fn=mdpv_points_map_fn,
657
+ template_map_fn=dict(
658
+ type=template_map_fn_factory, template=prompt_template),
659
+ max_length=max_length,
660
+ pad_image_to_square=True,
661
+ debug=False,
662
+ repeats=1,
663
+ )
664
+
665
+ mdpv_brief_description_cocostuff164k_dataset = dict(
666
+ type=MDPVPointBriefCaptionDataset,
667
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
668
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
669
+ tokenizer=tokenizer,
670
+ image_processor=image_processor,
671
+ dataset_map_fn=mdpv_points_map_fn,
672
+ template_map_fn=dict(
673
+ type=template_map_fn_factory, template=prompt_template),
674
+ max_length=max_length,
675
+ pad_image_to_square=True,
676
+ debug=False,
677
+ repeats=1,
678
+ )
679
+
680
+ mdpv_brief_description_ade20k_dataset = dict(
681
+ type=MDPVPointBriefCaptionDataset,
682
+ data_path=mdpv_brief_caption_ade20k_data_path,
683
+ image_folder=mdpv_brief_caption_ade20k_image_path,
684
+ tokenizer=tokenizer,
685
+ image_processor=image_processor,
686
+ dataset_map_fn=mdpv_points_map_fn,
687
+ template_map_fn=dict(
688
+ type=template_map_fn_factory, template=prompt_template),
689
+ max_length=max_length,
690
+ pad_image_to_square=True,
691
+ debug=False,
692
+ repeats=1,
693
+ )
694
+
695
+ mdpv_brief_description_lvis_dataset = dict(
696
+ type=MDPVPointBriefCaptionDataset,
697
+ data_path=mdpv_brief_caption_lvis_data_path,
698
+ image_folder=mdpv_brief_caption_lvis_image_path,
699
+ tokenizer=tokenizer,
700
+ image_processor=image_processor,
701
+ dataset_map_fn=mdpv_points_map_fn,
702
+ template_map_fn=dict(
703
+ type=template_map_fn_factory, template=prompt_template),
704
+ max_length=max_length,
705
+ pad_image_to_square=True,
706
+ debug=False,
707
+ repeats=1,
708
+ )
709
+
710
+ mdpv_qa_vg_dataset = dict(
711
+ type=MDPVPointBriefCaptionDataset,
712
+ data_path=mdpv_qa_vg_data_path,
713
+ image_folder=mdpv_qa_vg_image_path,
714
+ tokenizer=tokenizer,
715
+ image_processor=image_processor,
716
+ dataset_map_fn=mdpv_points_map_fn,
717
+ template_map_fn=dict(
718
+ type=template_map_fn_factory, template=prompt_template),
719
+ max_length=max_length,
720
+ pad_image_to_square=True,
721
+ debug=False,
722
+ repeats=1,
723
+ )
724
+
725
+ mdpv_qa_ade20k_dataset = dict(
726
+ type=MDPVPointBriefCaptionDataset,
727
+ data_path=mdpv_qa_ade20k_data_path,
728
+ image_folder=mdpv_qa_ade20k_image_path,
729
+ tokenizer=tokenizer,
730
+ image_processor=image_processor,
731
+ dataset_map_fn=mdpv_points_map_fn,
732
+ template_map_fn=dict(
733
+ type=template_map_fn_factory, template=prompt_template),
734
+ max_length=max_length,
735
+ pad_image_to_square=True,
736
+ debug=False,
737
+ repeats=1,
738
+ )
739
+
740
+ mdpv_qa_lvis_dataset = dict(
741
+ type=MDPVPointBriefCaptionDataset,
742
+ data_path=mdpv_qa_lvis_data_path,
743
+ image_folder=mdpv_qa_lvis_image_path,
744
+ tokenizer=tokenizer,
745
+ image_processor=image_processor,
746
+ dataset_map_fn=mdpv_points_map_fn,
747
+ template_map_fn=dict(
748
+ type=template_map_fn_factory, template=prompt_template),
749
+ max_length=max_length,
750
+ pad_image_to_square=True,
751
+ debug=False,
752
+ repeats=1,
753
+ )
754
+
755
+ mdpv_qa_cocostuff10k_dataset = dict(
756
+ type=MDPVPointBriefCaptionDataset,
757
+ data_path=mdpv_qa_cocostuff10k_data_path,
758
+ image_folder=mdpv_qa_cocostuff10k_image_path,
759
+ tokenizer=tokenizer,
760
+ image_processor=image_processor,
761
+ dataset_map_fn=mdpv_points_map_fn,
762
+ template_map_fn=dict(
763
+ type=template_map_fn_factory, template=prompt_template),
764
+ max_length=max_length,
765
+ pad_image_to_square=True,
766
+ debug=False,
767
+ repeats=1,
768
+ )
769
+
770
+ mdpv_qa_cocostuff164k_dataset = dict(
771
+ type=MDPVPointBriefCaptionDataset,
772
+ data_path=mdpv_qa_cocostuff164k_data_path,
773
+ image_folder=mdpv_qa_cocostuff164k_image_path,
774
+ tokenizer=tokenizer,
775
+ image_processor=image_processor,
776
+ dataset_map_fn=mdpv_points_map_fn,
777
+ template_map_fn=dict(
778
+ type=template_map_fn_factory, template=prompt_template),
779
+ max_length=max_length,
780
+ pad_image_to_square=True,
781
+ debug=False,
782
+ repeats=1,
783
+ )
784
+
785
+ mdpv_multi_points_openpsg_dataset = dict(
786
+ type=MDPVPointBriefCaptionDataset,
787
+ data_path=mdpv_multi_points_openpsg_data_path,
788
+ image_folder=mdpv_multi_points_openpsg_image_path,
789
+ tokenizer=tokenizer,
790
+ image_processor=image_processor,
791
+ dataset_map_fn=mdpv_points_map_fn,
792
+ template_map_fn=dict(
793
+ type=template_map_fn_factory, template=prompt_template),
794
+ max_length=max_length,
795
+ pad_image_to_square=True,
796
+ debug=False,
797
+ repeats=1,
798
+ )
799
+
800
+ mdpv_multi_points_flicker30k_dataset = dict(
801
+ type=MDPVPointBriefCaptionDataset,
802
+ data_path=mdpv_multi_points_flicker30k_data_path,
803
+ image_folder=mdpv_multi_points_flicker30k_image_path,
804
+ tokenizer=tokenizer,
805
+ image_processor=image_processor,
806
+ dataset_map_fn=mdpv_points_map_fn,
807
+ template_map_fn=dict(
808
+ type=template_map_fn_factory, template=prompt_template),
809
+ max_length=max_length,
810
+ pad_image_to_square=True,
811
+ debug=False,
812
+ repeats=1,
813
+ )
814
+
815
+ train_dataset = dict(
816
+ type=CombineDataset,
817
+ datasets_cfgs=[mdpv_brief_description_lvis_dataset,],
818
+ )
819
+
820
+ train_dataloader = dict(
821
+ batch_size=batch_size,
822
+ num_workers=dataloader_num_workers,
823
+ dataset=train_dataset,
824
+ sampler=dict(
825
+ type=LengthGroupedSampler,
826
+ length_property='modality_length',
827
+ per_device_batch_size=batch_size * accumulative_counts),
828
+ collate_fn=dict(type=omg_llava_collate_fn))
829
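+ # LengthGroupedSampler buckets samples by their 'modality_length' so that each effective
+ # per-device batch (batch_size * accumulative_counts) draws sequences of similar length,
+ # which keeps padding overhead low when omg_llava_collate_fn collates the batch.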
+
830
+ #######################################################################
831
+ # PART 4 Scheduler & Optimizer #
832
+ #######################################################################
833
+ # optimizer
834
+ optim_wrapper = dict(
835
+ type=AmpOptimWrapper,
836
+ optimizer=dict(
837
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
838
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
839
+ accumulative_counts=accumulative_counts,
840
+ loss_scale='dynamic',
841
+ dtype='float16')
842
+
843
+ # learning policy
844
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
845
+ param_scheduler = [
846
+ dict(
847
+ type=LinearLR,
848
+ start_factor=1e-5,
849
+ by_epoch=True,
850
+ begin=0,
851
+ end=warmup_ratio * max_epochs,
852
+ convert_to_iter_based=True),
853
+ dict(
854
+ type=CosineAnnealingLR,
855
+ eta_min=0.0,
856
+ by_epoch=True,
857
+ begin=warmup_ratio * max_epochs,
858
+ end=max_epochs,
859
+ convert_to_iter_based=True)
860
+ ]
861
+
862
+ # train, val, test setting
863
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
864
+
865
+ #######################################################################
866
+ # PART 5 Runtime #
867
+ #######################################################################
868
+ # Log the dialogue periodically during the training process, optional
869
+ custom_hooks = [
870
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
871
+ dict(
872
+ type=EvaluateChatHook_withSpecialTokens,
873
+ tokenizer=tokenizer,
874
+ image_processor=image_processor,
875
+ every_n_iters=evaluation_freq,
876
+ evaluation_inputs=evaluation_inputs,
877
+ evaluation_images=evaluation_images,
878
+ system=SYSTEM,
879
+ prompt_template=prompt_template)
880
+ ]
881
+
882
+ # configure default hooks
883
+ default_hooks = dict(
884
+ # record the time of every iteration.
885
+ timer=dict(type=IterTimerHook),
886
+ # print log every 10 iterations.
887
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
888
+ # enable the parameter scheduler.
889
+ param_scheduler=dict(type=ParamSchedulerHook),
890
+ # save checkpoint per `save_steps`.
891
+ checkpoint=dict(
892
+ type=CheckpointHook,
893
+ by_epoch=False,
894
+ interval=save_steps,
895
+ max_keep_ckpts=save_total_limit),
896
+ # set sampler seed in distributed environment.
897
+ sampler_seed=dict(type=DistSamplerSeedHook),
898
+ )
899
+
900
+ # configure environment
901
+ env_cfg = dict(
902
+ # whether to enable cudnn benchmark
903
+ cudnn_benchmark=False,
904
+ # set multi process parameters
905
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
906
+ # set distributed parameters
907
+ dist_cfg=dict(backend='nccl'),
908
+ )
909
+
910
+ # set visualizer
911
+ visualizer = None
912
+
913
+ # set log level
914
+ log_level = 'INFO'
915
+
916
+ # load from which checkpoint
917
+ load_from = None
918
+
919
+ # whether to resume training from the loaded checkpoint
920
+ resume = False
921
+
922
+ # Defaults to use random seed and disable `deterministic`
923
+ randomness = dict(seed=None, deterministic=False)
924
+
925
+ # set log processor
926
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/debug.py ADDED
@@ -0,0 +1,967 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset
24
+ from xtuner.dataset.samplers import LengthGroupedSampler
25
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
26
+ from xtuner.engine.runner import TrainLoop
27
+ from omg_llava.model import OMG_LLaVA
28
+ from xtuner.utils import PROMPT_TEMPLATE
29
+ from omg_llava.model import OpenCLIPBackbone_omgseg
30
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
31
+
32
+ from torch.nn import GroupNorm, ReLU
33
+
34
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
35
+ DiceLoss, MaskFormerFusionHead, FocalLoss
36
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
37
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
38
+
39
+ #######################################################################
40
+ # PART 1 Settings #
41
+ #######################################################################
42
+ # Model
43
+ llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path
44
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
45
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
46
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
47
+
48
+ # Data
49
+ data_root = './data/llava_data/'
50
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
51
+ image_folder = data_root + 'llava_images'
52
+
53
+ glamm_data_root = './data/glamm_data/'
54
+
55
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
56
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
57
+
58
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
59
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
60
+
61
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
62
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
63
+
64
+ psg_image_path = glamm_data_root + 'images/coco2017/'
65
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
66
+
67
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
68
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
69
+
70
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
71
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
72
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
73
+
74
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
75
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
76
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
77
+
78
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
79
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
80
+
81
+ paco_image_path = './data/glamm_data/images/coco2017/'
82
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
83
+
84
+ referring_refcoco_image_path = refcocog_image_path
85
+ referring_refcoco_data_path = "./data/ref_seg/"
86
+
87
+ referring_refcoco_plus_image_path = refcocog_image_path
88
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
89
+
90
+ referring_refcocog_image_path = refcocog_image_path
91
+ referring_refcocog_data_path = "./data/ref_seg/"
92
+
93
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
94
+ referring_refclef_data_path = "./data/ref_seg/"
95
+
96
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
97
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
98
+
99
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
100
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
101
+
102
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
103
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
104
+
105
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
106
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
107
+
108
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
109
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
110
+
111
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
112
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
113
+
114
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
115
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
116
+
117
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
118
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
119
+
120
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
121
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
122
+
123
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
124
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
125
+
126
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
127
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
128
+
129
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
130
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
131
+
132
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
133
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
134
+
135
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
136
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
137
+
138
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
139
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
140
+
141
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
142
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
143
+
144
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
145
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
146
+
147
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
148
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
149
+
150
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
151
+ max_length = int(2048 - (1024 / 64)**2 - 100)
152
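+ # With 1024-px inputs and an effective 64-px patch stride, the image contributes
+ # (1024 / 64)**2 = 256 visual tokens; reserving ~100 tokens for the chat template leaves
+ # max_length = 2048 - 256 - 100 = 1692 tokens of text within the 2048-token context.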
+
153
+ # Scheduler & Optimizer
154
+ batch_size = 8 # per_device
155
+ accumulative_counts = 2
156
+ dataloader_num_workers = 4
157
+ max_epochs = 1
158
+ optim_type = AdamW
159
+ lr = 2e-4
160
+ betas = (0.9, 0.999)
161
+ weight_decay = 0
162
+ max_norm = 1 # grad clip
163
+ warmup_ratio = 0.03
164
+
165
+
166
+ # Save
167
+ save_steps = 2000
168
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
169
+
170
+ # Evaluate the generation performance during the training
171
+ evaluation_freq = 2000
172
+ SYSTEM = ''
173
+ evaluation_images = './work_dirs/test.jpg'
174
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
175
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
176
+
177
+ #######################################################################
178
+ # PART 2 Model & Tokenizer & Image Processor #
179
+ #######################################################################
180
+ tokenizer = dict(
181
+ type=AutoTokenizer.from_pretrained,
182
+ pretrained_model_name_or_path=llm_name_or_path,
183
+ trust_remote_code=True,
184
+ padding_side='right')
185
+
186
+ image_processor = dict(
187
+ type=CLIPImageProcessor,
188
+ do_resize=True,
189
+ size=1024,
190
+ resample=3,
191
+ do_center_crop=True,
192
+ crop_size=1024,
193
+ do_rescale=True,
194
+ do_normalize=True,
195
+ image_mean=[0.4814, 0.4578, 0.4082],
196
+ image_std=[0.2686, 0.2613, 0.2757],
197
+ do_convert_rgb=True
198
+ )
199
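+ # The CLIP image processor resizes and center-crops to 1024x1024 (resample=3 is bicubic) and
+ # normalizes with the standard CLIP mean/std, matching the input resolution expected by the
+ # ConvNeXt-L OMG-Seg visual encoder configured below.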
+
200
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
201
+ num_things_classes = 80
202
+ num_stuff_classes = 53
203
+ num_classes = num_things_classes + num_stuff_classes
204
+
205
+ omgseg_model = dict(
206
+ type=OMGSegVisualEncoder,
207
+ data_preprocessor=None,
208
+ pixel_shuffle_down_ratio=2,
209
+ backbone=dict(
210
+ type=OpenCLIPBackbone_omgseg,
211
+ model_name='convnext_large_d_320',
212
+ fix=True,
213
+ init_cfg=dict(
214
+ type='clip_pretrain',
215
+ checkpoint='laion2b_s29b_b131k_ft_soup'
216
+ )
217
+ ),
218
+ panoptic_head=dict(
219
+ type=Mask2FormerVideoSemSamHead,
220
+ sphere_cls=True,
221
+ ov_path=omg_ov_class_embed_path,
222
+ enable_box_query=False,
223
+ ov_classifier_name=class_embed,
224
+ logit=None,
225
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
226
+ strides=[4, 8, 16, 32],
227
+ feat_channels=256,
228
+ out_channels=256,
229
+ num_things_classes=num_things_classes,
230
+ num_stuff_classes=num_stuff_classes,
231
+ num_queries=300,
232
+ num_transformer_feat_level=3,
233
+ pixel_decoder=dict(
234
+ type=MSDeformAttnPixelDecoder,
235
+ num_outs=3,
236
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
237
+ act_cfg=dict(type=ReLU),
238
+ encoder=dict( # DeformableDetrTransformerEncoder
239
+ num_layers=6,
240
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
241
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
242
+ embed_dims=256,
243
+ num_heads=8,
244
+ num_levels=3,
245
+ num_points=4,
246
+ dropout=0.0,
247
+ batch_first=True),
248
+ ffn_cfg=dict(
249
+ embed_dims=256,
250
+ feedforward_channels=1024,
251
+ num_fcs=2,
252
+ ffn_drop=0.0,
253
+ act_cfg=dict(type=ReLU, inplace=True)))),
254
+ positional_encoding=dict(num_feats=128, normalize=True)),
255
+ enforce_decoder_input_project=False,
256
+ positional_encoding=dict(num_feats=128, normalize=True),
257
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
258
+ return_intermediate=True,
259
+ num_layers=9,
260
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
261
+ self_attn_cfg=dict( # MultiheadAttention
262
+ embed_dims=256,
263
+ num_heads=8,
264
+ dropout=0.0,
265
+ batch_first=True),
266
+ cross_attn_cfg=dict( # MultiheadAttention
267
+ embed_dims=256,
268
+ num_heads=8,
269
+ dropout=0.0,
270
+ batch_first=True),
271
+ ffn_cfg=dict(
272
+ embed_dims=256,
273
+ feedforward_channels=2048,
274
+ num_fcs=2,
275
+ ffn_drop=0.0,
276
+ act_cfg=dict(type='ReLU', inplace=True))),
277
+ init_cfg=None),
278
+ loss_cls=dict(
279
+ type=CrossEntropyLoss,
280
+ use_sigmoid=False,
281
+ loss_weight=2.0,
282
+ reduction='mean',
283
+ class_weight=[1.0] * 240 + [0.1]),
284
+ loss_mask=dict(
285
+ type=CrossEntropyLoss,
286
+ use_sigmoid=True,
287
+ reduction='mean',
288
+ loss_weight=5.0),
289
+ loss_dice=dict(
290
+ type=DiceLoss,
291
+ use_sigmoid=True,
292
+ activate=True,
293
+ reduction='mean',
294
+ naive_dice=True,
295
+ eps=1.0,
296
+ loss_weight=5.0),
297
+ loss_iou=dict(
298
+ type=FocalLoss,
299
+ use_sigmoid=True,
300
+ loss_weight=2.0,
301
+ reduction='mean')
302
+ ),
303
+ panoptic_fusion_head=dict(
304
+ type=MaskFormerFusionHead,
305
+ num_things_classes=num_things_classes,
306
+ num_stuff_classes=num_stuff_classes,
307
+ loss_panoptic=None,
308
+ init_cfg=None),
309
+ train_cfg=dict(
310
+ num_points=12544,
311
+ oversample_ratio=3.0,
312
+ importance_sample_ratio=0.75,
313
+ assigner=dict(
314
+ type=HungarianAssigner,
315
+ match_costs=[
316
+ # dict(type=FlexibleClassificationCost, weight=2.0),
317
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
318
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
319
+ ]),
320
+ sampler=dict(type=MaskPseudoSampler)),
321
+ test_cfg=dict(
322
+ panoptic_on=True,
323
+ # For now, the dataset does not support
324
+ # evaluating the semantic segmentation metric.
325
+ semantic_on=False,
326
+ instance_on=True,
327
+ # max_per_image is for instance segmentation.
328
+ max_per_image=100,
329
+ iou_thr=0.8,
330
+ # In Mask2Former's panoptic postprocessing,
331
+ # it will filter out mask areas whose score is less than 0.5.
332
+ filter_low_score=True),
333
+ init_cfg=dict(
334
+ type='Pretrained',
335
+ checkpoint=omg_head_pretrain_pth_path,
336
+ )
337
+ )
338
+
339
+ model = dict(
340
+ type=OMG_LLaVA,
341
+ freeze_llm=True,
342
+ freeze_visual_encoder=True,
343
+ require_omg_decoder=False,
344
+ pretrained_pth=pretrained_pth,
345
+ text2vision_projector=True,
346
+ pixel_shuffle_ratio=2,
347
+ llm=dict(
348
+ type=AutoModelForCausalLM.from_pretrained,
349
+ pretrained_model_name_or_path=llm_name_or_path,
350
+ trust_remote_code=True,
351
+ torch_dtype=torch.float16,
352
+ quantization_config=dict(
353
+ type=BitsAndBytesConfig,
354
+ load_in_4bit=True,
355
+ load_in_8bit=False,
356
+ llm_int8_threshold=6.0,
357
+ llm_int8_has_fp16_weight=False,
358
+ bnb_4bit_compute_dtype=torch.float16,
359
+ bnb_4bit_use_double_quant=True,
360
+ bnb_4bit_quant_type='nf4')),
361
+ llm_lora=dict(
362
+ type=LoraConfig,
363
+ r=512,
364
+ lora_alpha=256,
365
+ lora_dropout=0.05,
366
+ bias='none',
367
+ task_type='CAUSAL_LM'),
368
+ visual_encoder=omgseg_model,
369
+ tokenizer=tokenizer,
370
+ )
371
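+ # QLoRA-style setup: the LLM is loaded in 4-bit NF4 via BitsAndBytesConfig and adapted with
+ # LoRA (r=512, alpha=256, dropout=0.05); the base LLM and the OMG-Seg visual encoder stay
+ # frozen, so presumably only the LoRA adapters and the unfrozen projector modules are trained.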
+
372
+ #######################################################################
373
+ # PART 3 Dataset & Dataloader #
374
+ #######################################################################
375
+ debug=False
376
+ llava_dataset = dict(
377
+ type=LLaVADataset,
378
+ data_path=data_path,
379
+ image_folder=image_folder,
380
+ tokenizer=tokenizer,
381
+ image_processor=image_processor,
382
+ dataset_map_fn=llava_map_fn,
383
+ template_map_fn=dict(
384
+ type=template_map_fn_factory, template=prompt_template),
385
+ max_length=max_length,
386
+ pad_image_to_square=True)
387
+
388
+ glamm_refcocog_dataset = dict(
389
+ type=RefCOCOgGCGDataset,
390
+ data_path=refcocog_ann_file,
391
+ image_folder=refcocog_image_path,
392
+ tokenizer=tokenizer,
393
+ image_processor=image_processor,
394
+ dataset_map_fn=glamm_refcocog_map_fn,
395
+ template_map_fn=dict(
396
+ type=template_map_fn_factory, template=prompt_template),
397
+ max_length=max_length,
398
+ pad_image_to_square=True,
399
+ debug=False,
400
+ repeats=1,
401
+ )
402
+
403
+ glamm_grandf_dataset = dict(
404
+ type=GranDfGCGDataset,
405
+ data_path=grandf_ann_file,
406
+ image_folder=grandf_image_path,
407
+ tokenizer=tokenizer,
408
+ image_processor=image_processor,
409
+ dataset_map_fn=glamm_granf_map_fn,
410
+ template_map_fn=dict(
411
+ type=template_map_fn_factory, template=prompt_template),
412
+ max_length=max_length,
413
+ pad_image_to_square=True,
414
+ debug=debug,
415
+ repeats=10,
416
+ )
417
+
418
+ glamm_psg_dataset = dict(
419
+ type=OpenPsgGCGDataset,
420
+ data_path=psg_ann_file,
421
+ image_folder=psg_image_path,
422
+ tokenizer=tokenizer,
423
+ image_processor=image_processor,
424
+ dataset_map_fn=glamm_openpsg_map_fn,
425
+ template_map_fn=dict(
426
+ type=template_map_fn_factory, template=prompt_template),
427
+ max_length=max_length,
428
+ pad_image_to_square=True,
429
+ debug=debug,
430
+ repeats=1,
431
+ )
432
+
433
+ glamm_flickr_dataset = dict(
434
+ type=FlickrGCGDataset,
435
+ data_path=flickr_ann_file,
436
+ image_folder=flickr_image_path,
437
+ tokenizer=tokenizer,
438
+ image_processor=image_processor,
439
+ dataset_map_fn=glamm_flickr_map_fn,
440
+ template_map_fn=dict(
441
+ type=template_map_fn_factory, template=prompt_template),
442
+ max_length=max_length,
443
+ pad_image_to_square=True,
444
+ debug=debug,
445
+ repeats=1,
446
+ )
447
+
448
+ semantic_seg_ade20k_dataset = dict(
449
+ type=ADE20kSemanticSegDataset,
450
+ data_path=ade20k_class_file,
451
+ image_folder=ade20k_image_path,
452
+ tokenizer=tokenizer,
453
+ image_processor=image_processor,
454
+ dataset_map_fn=semantic_seg_map_fn,
455
+ template_map_fn=dict(
456
+ type=template_map_fn_factory, template=prompt_template),
457
+ max_length=max_length,
458
+ pad_image_to_square=True,
459
+ debug=False,
460
+ repeats=1,
461
+ )
462
+
463
+ semantic_seg_cocostuff_dataset = dict(
464
+ type=COCOStuffSemanticSegDataset,
465
+ data_path=cocostuff_class_file,
466
+ image_folder=cocostuff_image_path,
467
+ label_path=cocostuff_label_path,
468
+ tokenizer=tokenizer,
469
+ image_processor=image_processor,
470
+ dataset_map_fn=semantic_seg_map_fn,
471
+ template_map_fn=dict(
472
+ type=template_map_fn_factory, template=prompt_template),
473
+ max_length=max_length,
474
+ pad_image_to_square=True,
475
+ debug=False,
476
+ repeats=1,
477
+ )
478
+
479
+ semantic_seg_mapillary_dataset = dict(
480
+ type=MapillarySemanticSegDataset,
481
+ data_path=mapillary_class_file,
482
+ image_folder=mapillary_image_path,
483
+ label_path=mapillary_label_path,
484
+ tokenizer=tokenizer,
485
+ image_processor=image_processor,
486
+ dataset_map_fn=semantic_seg_map_fn,
487
+ template_map_fn=dict(
488
+ type=template_map_fn_factory, template=prompt_template),
489
+ max_length=max_length,
490
+ pad_image_to_square=True,
491
+ debug=False,
492
+ repeats=1,
493
+ )
494
+
495
+ semantic_seg_pascal_part_dataset = dict(
496
+ type=PascalPartSemanticSegDataset,
497
+ data_path=pascal_file,
498
+ image_folder=pascal_part_image_path,
499
+ tokenizer=tokenizer,
500
+ image_processor=image_processor,
501
+ dataset_map_fn=pascal_part_map_fn,
502
+ template_map_fn=dict(
503
+ type=template_map_fn_factory, template=prompt_template),
504
+ max_length=max_length,
505
+ pad_image_to_square=True,
506
+ debug=False,
507
+ repeats=1,
508
+ )
509
+
510
+ semantic_seg_paco_dataset = dict(
511
+ type=PacoSemanticSegDataset,
512
+ data_path=paco_file,
513
+ image_folder=paco_image_path,
514
+ tokenizer=tokenizer,
515
+ image_processor=image_processor,
516
+ dataset_map_fn=pascal_part_map_fn,
517
+ template_map_fn=dict(
518
+ type=template_map_fn_factory, template=prompt_template),
519
+ max_length=max_length,
520
+ pad_image_to_square=True,
521
+ debug=False,
522
+ repeats=1,
523
+ )
524
+
525
+ referring_seg_refcoco_dataset = dict(
526
+ type=RefcocoReferringSegDataset,
527
+ data_path=referring_refcoco_data_path,
528
+ image_folder=referring_refcoco_image_path,
529
+ tokenizer=tokenizer,
530
+ image_processor=image_processor,
531
+ dataset_map_fn=referring_seg_map_fn,
532
+ template_map_fn=dict(
533
+ type=template_map_fn_factory, template=prompt_template),
534
+ max_length=max_length,
535
+ pad_image_to_square=True,
536
+ debug=False,
537
+ repeats=1,
538
+ )
539
+
540
+ referring_seg_refcoco_plus_dataset = dict(
541
+ type=Refcoco_plus_ReferringSegDataset,
542
+ data_path=referring_refcoco_plus_data_path,
543
+ image_folder=referring_refcoco_plus_image_path,
544
+ tokenizer=tokenizer,
545
+ image_processor=image_processor,
546
+ dataset_map_fn=referring_seg_map_fn,
547
+ template_map_fn=dict(
548
+ type=template_map_fn_factory, template=prompt_template),
549
+ max_length=max_length,
550
+ pad_image_to_square=True,
551
+ debug=False,
552
+ repeats=1,
553
+ )
554
+
555
+ referring_seg_refcocog_dataset = dict(
556
+ type=Refcocog_ReferringSegDataset,
557
+ data_path=referring_refcocog_data_path,
558
+ image_folder=referring_refcocog_image_path,
559
+ tokenizer=tokenizer,
560
+ image_processor=image_processor,
561
+ dataset_map_fn=referring_seg_map_fn,
562
+ template_map_fn=dict(
563
+ type=template_map_fn_factory, template=prompt_template),
564
+ max_length=max_length,
565
+ pad_image_to_square=True,
566
+ debug=False,
567
+ repeats=1,
568
+ )
569
+
570
+ referring_seg_refclef_dataset = dict(
571
+ type=Refclef_ReferringSegDataset,
572
+ data_path=referring_refclef_data_path,
573
+ image_folder=referring_refclef_image_path,
574
+ tokenizer=tokenizer,
575
+ image_processor=image_processor,
576
+ dataset_map_fn=referring_seg_map_fn,
577
+ template_map_fn=dict(
578
+ type=template_map_fn_factory, template=prompt_template),
579
+ max_length=max_length,
580
+ pad_image_to_square=True,
581
+ debug=False,
582
+ repeats=1,
583
+ )
584
+
585
+ region_cap_osprey_dataset = dict(
586
+ type=OspreyRegionCaptionDataset,
587
+ data_path=region_cap_osprey_data_path,
588
+ image_folder=region_cap_osprey_image_path,
589
+ tokenizer=tokenizer,
590
+ image_processor=image_processor,
591
+ dataset_map_fn=osprey_region_caption_map_fn,
592
+ template_map_fn=dict(
593
+ type=template_map_fn_factory, template=prompt_template),
594
+ max_length=max_length,
595
+ pad_image_to_square=True,
596
+ debug=False,
597
+ repeats=1,
598
+ )
599
+
600
+ region_conversation_osprey_dataset = dict(
601
+ type=OspreyRegionConversationDataset,
602
+ data_path=region_conversation_osprey_data_path,
603
+ image_folder=region_conversation_osprey_image_path,
604
+ tokenizer=tokenizer,
605
+ image_processor=image_processor,
606
+ dataset_map_fn=osprey_region_conversation_map_fn,
607
+ template_map_fn=dict(
608
+ type=template_map_fn_factory, template=prompt_template),
609
+ max_length=max_length,
610
+ pad_image_to_square=True,
611
+ debug=False,
612
+ repeats=1,
613
+ )
614
+
615
+ mdpv_detailed_description_ade20k_dataset = dict(
616
+ type=MDPVPointDetailedCaptionDataset,
617
+ data_path=mdpv_detailed_caption_ade20k_data_path,
618
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
619
+ tokenizer=tokenizer,
620
+ image_processor=image_processor,
621
+ dataset_map_fn=mdpv_points_map_fn,
622
+ template_map_fn=dict(
623
+ type=template_map_fn_factory, template=prompt_template),
624
+ max_length=max_length,
625
+ pad_image_to_square=True,
626
+ debug=False,
627
+ repeats=1,
628
+ )
629
+
630
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
631
+ type=MDPVPointDetailedCaptionDataset,
632
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
633
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
634
+ tokenizer=tokenizer,
635
+ image_processor=image_processor,
636
+ dataset_map_fn=mdpv_points_map_fn,
637
+ template_map_fn=dict(
638
+ type=template_map_fn_factory, template=prompt_template),
639
+ max_length=max_length,
640
+ pad_image_to_square=True,
641
+ debug=False,
642
+ repeats=1,
643
+ )
644
+
645
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
646
+ type=MDPVPointDetailedCaptionDataset,
647
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
648
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
649
+ tokenizer=tokenizer,
650
+ image_processor=image_processor,
651
+ dataset_map_fn=mdpv_points_map_fn,
652
+ template_map_fn=dict(
653
+ type=template_map_fn_factory, template=prompt_template),
654
+ max_length=max_length,
655
+ pad_image_to_square=True,
656
+ debug=False,
657
+ repeats=1,
658
+ )
659
+
660
+ mdpv_detailed_description_vg_dataset = dict(
661
+ type=MDPVPointDetailedCaptionDataset,
662
+ data_path=mdpv_detailed_caption_vg_data_path,
663
+ image_folder=mdpv_detailed_caption_vg_image_path,
664
+ tokenizer=tokenizer,
665
+ image_processor=image_processor,
666
+ dataset_map_fn=mdpv_points_map_fn,
667
+ template_map_fn=dict(
668
+ type=template_map_fn_factory, template=prompt_template),
669
+ max_length=max_length,
670
+ pad_image_to_square=True,
671
+ debug=False,
672
+ repeats=1,
673
+ )
674
+
675
+ mdpv_brief_description_vg_dataset = dict(
676
+ type=MDPVPointBriefCaptionDataset,
677
+ data_path=mdpv_brief_caption_vg_data_path,
678
+ image_folder=mdpv_brief_caption_vg_image_path,
679
+ tokenizer=tokenizer,
680
+ image_processor=image_processor,
681
+ dataset_map_fn=mdpv_points_map_fn,
682
+ template_map_fn=dict(
683
+ type=template_map_fn_factory, template=prompt_template),
684
+ max_length=max_length,
685
+ pad_image_to_square=True,
686
+ debug=False,
687
+ repeats=1,
688
+ )
689
+
690
+ mdpv_brief_description_cocostuff10k_dataset = dict(
691
+ type=MDPVPointBriefCaptionDataset,
692
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
693
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
694
+ tokenizer=tokenizer,
695
+ image_processor=image_processor,
696
+ dataset_map_fn=mdpv_points_map_fn,
697
+ template_map_fn=dict(
698
+ type=template_map_fn_factory, template=prompt_template),
699
+ max_length=max_length,
700
+ pad_image_to_square=True,
701
+ debug=False,
702
+ repeats=1,
703
+ )
704
+
705
+ mdpv_brief_description_cocostuff164k_dataset = dict(
706
+ type=MDPVPointBriefCaptionDataset,
707
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
708
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
709
+ tokenizer=tokenizer,
710
+ image_processor=image_processor,
711
+ dataset_map_fn=mdpv_points_map_fn,
712
+ template_map_fn=dict(
713
+ type=template_map_fn_factory, template=prompt_template),
714
+ max_length=max_length,
715
+ pad_image_to_square=True,
716
+ debug=False,
717
+ repeats=1,
718
+ )
719
+
720
+ mdpv_brief_description_ade20k_dataset = dict(
721
+ type=MDPVPointBriefCaptionDataset,
722
+ data_path=mdpv_brief_caption_ade20k_data_path,
723
+ image_folder=mdpv_brief_caption_ade20k_image_path,
724
+ tokenizer=tokenizer,
725
+ image_processor=image_processor,
726
+ dataset_map_fn=mdpv_points_map_fn,
727
+ template_map_fn=dict(
728
+ type=template_map_fn_factory, template=prompt_template),
729
+ max_length=max_length,
730
+ pad_image_to_square=True,
731
+ debug=False,
732
+ repeats=1,
733
+ )
734
+
735
+ mdpv_brief_description_lvis_dataset = dict(
736
+ type=MDPVPointBriefCaptionDataset,
737
+ data_path=mdpv_brief_caption_lvis_data_path,
738
+ image_folder=mdpv_brief_caption_lvis_image_path,
739
+ tokenizer=tokenizer,
740
+ image_processor=image_processor,
741
+ dataset_map_fn=mdpv_points_map_fn,
742
+ template_map_fn=dict(
743
+ type=template_map_fn_factory, template=prompt_template),
744
+ max_length=max_length,
745
+ pad_image_to_square=True,
746
+ debug=False,
747
+ repeats=1,
748
+ )
749
+
750
+ mdpv_qa_vg_dataset = dict(
751
+ type=MDPVPointBriefCaptionDataset,
752
+ data_path=mdpv_qa_vg_data_path,
753
+ image_folder=mdpv_qa_vg_image_path,
754
+ tokenizer=tokenizer,
755
+ image_processor=image_processor,
756
+ dataset_map_fn=mdpv_points_map_fn,
757
+ template_map_fn=dict(
758
+ type=template_map_fn_factory, template=prompt_template),
759
+ max_length=max_length,
760
+ pad_image_to_square=True,
761
+ debug=False,
762
+ repeats=1,
763
+ )
764
+
765
+ mdpv_qa_ade20k_dataset = dict(
766
+ type=MDPVPointBriefCaptionDataset,
767
+ data_path=mdpv_qa_ade20k_data_path,
768
+ image_folder=mdpv_qa_ade20k_image_path,
769
+ tokenizer=tokenizer,
770
+ image_processor=image_processor,
771
+ dataset_map_fn=mdpv_points_map_fn,
772
+ template_map_fn=dict(
773
+ type=template_map_fn_factory, template=prompt_template),
774
+ max_length=max_length,
775
+ pad_image_to_square=True,
776
+ debug=False,
777
+ repeats=1,
778
+ )
779
+
780
+ mdpv_qa_lvis_dataset = dict(
781
+ type=MDPVPointBriefCaptionDataset,
782
+ data_path=mdpv_qa_lvis_data_path,
783
+ image_folder=mdpv_qa_lvis_image_path,
784
+ tokenizer=tokenizer,
785
+ image_processor=image_processor,
786
+ dataset_map_fn=mdpv_points_map_fn,
787
+ template_map_fn=dict(
788
+ type=template_map_fn_factory, template=prompt_template),
789
+ max_length=max_length,
790
+ pad_image_to_square=True,
791
+ debug=False,
792
+ repeats=1,
793
+ )
794
+
795
+ mdpv_qa_cocostuff10k_dataset = dict(
796
+ type=MDPVPointBriefCaptionDataset,
797
+ data_path=mdpv_qa_cocostuff10k_data_path,
798
+ image_folder=mdpv_qa_cocostuff10k_image_path,
799
+ tokenizer=tokenizer,
800
+ image_processor=image_processor,
801
+ dataset_map_fn=mdpv_points_map_fn,
802
+ template_map_fn=dict(
803
+ type=template_map_fn_factory, template=prompt_template),
804
+ max_length=max_length,
805
+ pad_image_to_square=True,
806
+ debug=False,
807
+ repeats=1,
808
+ )
809
+
810
+ mdpv_qa_cocostuff164k_dataset = dict(
811
+ type=MDPVPointBriefCaptionDataset,
812
+ data_path=mdpv_qa_cocostuff164k_data_path,
813
+ image_folder=mdpv_qa_cocostuff164k_image_path,
814
+ tokenizer=tokenizer,
815
+ image_processor=image_processor,
816
+ dataset_map_fn=mdpv_points_map_fn,
817
+ template_map_fn=dict(
818
+ type=template_map_fn_factory, template=prompt_template),
819
+ max_length=max_length,
820
+ pad_image_to_square=True,
821
+ debug=False,
822
+ repeats=1,
823
+ )
824
+
825
+ mdpv_multi_points_openpsg_dataset = dict(
826
+ type=MDPVPointBriefCaptionDataset,
827
+ data_path=mdpv_multi_points_openpsg_data_path,
828
+ image_folder=mdpv_multi_points_openpsg_image_path,
829
+ tokenizer=tokenizer,
830
+ image_processor=image_processor,
831
+ dataset_map_fn=mdpv_points_map_fn,
832
+ template_map_fn=dict(
833
+ type=template_map_fn_factory, template=prompt_template),
834
+ max_length=max_length,
835
+ pad_image_to_square=True,
836
+ debug=False,
837
+ repeats=1,
838
+ )
839
+
840
+ mdpv_multi_points_flicker30k_dataset = dict(
841
+ type=MDPVPointBriefCaptionDataset,
842
+ data_path=mdpv_multi_points_flicker30k_data_path,
843
+ image_folder=mdpv_multi_points_flicker30k_image_path,
844
+ tokenizer=tokenizer,
845
+ image_processor=image_processor,
846
+ dataset_map_fn=mdpv_points_map_fn,
847
+ template_map_fn=dict(
848
+ type=template_map_fn_factory, template=prompt_template),
849
+ max_length=max_length,
850
+ pad_image_to_square=True,
851
+ debug=False,
852
+ repeats=1,
853
+ )
854
+
855
+ train_dataset = dict(
856
+ type=CombineDataset,
857
+ datasets_cfgs=[glamm_refcocog_dataset,
858
+ ],
859
+ )
860
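+ # CombineDataset mixes the dataset configs listed in `datasets_cfgs`; this debug config keeps
+ # only glamm_refcocog_dataset for a quick dry run -- append the other *_dataset dicts defined
+ # above to restore the full finetuning mixture.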
+
861
+ train_dataloader = dict(
862
+ batch_size=batch_size,
863
+ num_workers=dataloader_num_workers,
864
+ dataset=train_dataset,
865
+ sampler=dict(
866
+ type=LengthGroupedSampler,
867
+ length_property='modality_length',
868
+ per_device_batch_size=batch_size * accumulative_counts),
869
+ collate_fn=dict(type=omg_llava_collate_fn))
870
+
871
+ #######################################################################
872
+ # PART 4 Scheduler & Optimizer #
873
+ #######################################################################
874
+ # optimizer
875
+ optim_wrapper = dict(
876
+ type=AmpOptimWrapper,
877
+ optimizer=dict(
878
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
879
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
880
+ accumulative_counts=accumulative_counts,
881
+ loss_scale='dynamic',
882
+ dtype='float16')
883
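+ # fp16 AMP with dynamic loss scaling; gradients are accumulated over accumulative_counts=2
+ # steps (effective per-device batch size 8 * 2 = 16) and clipped to max_norm=1 before each
+ # AdamW update (lr=2e-4, betas=(0.9, 0.999), weight_decay=0).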
+
884
+ # learning policy
885
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
886
+ param_scheduler = [
887
+ dict(
888
+ type=LinearLR,
889
+ start_factor=1e-5,
890
+ by_epoch=True,
891
+ begin=0,
892
+ end=warmup_ratio * max_epochs,
893
+ convert_to_iter_based=True),
894
+ dict(
895
+ type=CosineAnnealingLR,
896
+ eta_min=0.0,
897
+ by_epoch=True,
898
+ begin=warmup_ratio * max_epochs,
899
+ end=max_epochs,
900
+ convert_to_iter_based=True)
901
+ ]
902
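+ # With warmup_ratio=0.03 and max_epochs=1, LinearLR ramps the LR from lr * 1e-5 up to lr over
+ # the first 3% of training, then CosineAnnealingLR decays it to 0 for the rest;
+ # convert_to_iter_based=True applies both schedules per iteration rather than per epoch.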
+
903
+ # train, val, test setting
904
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
905
+
906
+ #######################################################################
907
+ # PART 5 Runtime #
908
+ #######################################################################
909
+ # Log the dialogue periodically during the training process, optional
910
+ custom_hooks = [
911
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
912
+ dict(
913
+ type=EvaluateChatHook_withSpecialTokens,
914
+ tokenizer=tokenizer,
915
+ image_processor=image_processor,
916
+ every_n_iters=evaluation_freq,
917
+ evaluation_inputs=evaluation_inputs,
918
+ evaluation_images=evaluation_images,
919
+ system=SYSTEM,
920
+ prompt_template=prompt_template)
921
+ ]
922
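+ # DatasetInfoHook_withSpecoalTokens logs example tokenized samples before training starts, and
+ # EvaluateChatHook_withSpecialTokens generates answers for evaluation_inputs on evaluation_images
+ # every evaluation_freq (2000) iterations so output quality can be spot-checked during training.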
+
923
+ # configure default hooks
924
+ default_hooks = dict(
925
+ # record the time of every iteration.
926
+ timer=dict(type=IterTimerHook),
927
+ # print log every 10 iterations.
928
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
929
+ # enable the parameter scheduler.
930
+ param_scheduler=dict(type=ParamSchedulerHook),
931
+ # save checkpoint per `save_steps`.
932
+ checkpoint=dict(
933
+ type=CheckpointHook,
934
+ by_epoch=False,
935
+ interval=save_steps,
936
+ max_keep_ckpts=save_total_limit),
937
+ # set sampler seed in distributed environment.
938
+ sampler_seed=dict(type=DistSamplerSeedHook),
939
+ )
940
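+ # Checkpoints are written every save_steps=2000 iterations (by_epoch=False) and only the
+ # save_total_limit=4 most recent ones are kept.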
+
941
+ # configure environment
942
+ env_cfg = dict(
943
+ # whether to enable cudnn benchmark
944
+ cudnn_benchmark=False,
945
+ # set multi process parameters
946
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
947
+ # set distributed parameters
948
+ dist_cfg=dict(backend='nccl'),
949
+ )
950
+
951
+ # set visualizer
952
+ visualizer = None
953
+
954
+ # set log level
955
+ log_level = 'INFO'
956
+
957
+ # load from which checkpoint
958
+ load_from = None
959
+
960
+ # whether to resume training from the loaded checkpoint
961
+ resume = False
962
+
963
+ # Defaults to use random seed and disable `deterministic`
964
+ randomness = dict(seed=None, deterministic=False)
965
+
966
+ # set log processor
967
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/fix_unfrozen_bug_omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py ADDED
@@ -0,0 +1,951 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 0
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
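Note: resample=3 selects PIL bicubic interpolation, and the image_mean / image_std values above are the standard CLIP normalization statistics rounded to four decimals; images are resized and center-cropped to 1024 x 1024 before normalization.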
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # passed to the pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
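Note: the visual encoder above keeps the ConvNeXt CLIP backbone frozen (fix=True) and initializes the segmentation head from omg_head_pretrain_pth_path; in the model block that follows, the whole encoder also stays frozen via freeze_visual_encoder=True.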
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
816
+ glamm_grandf_dataset, glamm_psg_dataset,
817
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
818
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
820
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
821
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
827
+ mdpv_detailed_description_ade20k_dataset,
828
+ mdpv_detailed_description_cocostuff_10k_dataset,
829
+ mdpv_detailed_description_cocostuff_164k_dataset,
830
+ mdpv_detailed_description_vg_dataset,
831
+ mdpv_brief_description_lvis_dataset,
832
+ mdpv_brief_description_vg_dataset,
833
+ mdpv_brief_description_ade20k_dataset,
834
+ mdpv_brief_description_cocostuff10k_dataset,
835
+ mdpv_brief_description_cocostuff164k_dataset,
836
+ mdpv_qa_vg_dataset,
837
+ mdpv_qa_lvis_dataset,
838
+ mdpv_qa_ade20k_dataset,
839
+ mdpv_qa_cocostuff10k_dataset,
840
+ mdpv_qa_cocostuff164k_dataset,
841
+ mdpv_multi_points_flicker30k_dataset,
842
+ mdpv_multi_points_openpsg_dataset,],
843
+ )
844
+
845
+ train_dataloader = dict(
846
+ batch_size=batch_size,
847
+ num_workers=dataloader_num_workers,
848
+ dataset=train_dataset,
849
+ sampler=dict(
850
+ type=LengthGroupedSampler,
851
+ length_property='modality_length',
852
+ per_device_batch_size=batch_size * accumulative_counts),
853
+ collate_fn=dict(type=omg_llava_collate_fn))
854
+
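Note: LengthGroupedSampler batches samples with similar modality_length together to reduce padding. A rough effective global batch size, assuming the 8-GPU setup implied by the config filenames in this repo:

# Sketch only: the GPU count is an assumption taken from the "8gpus" config names.
per_device_batch_size = 8
accumulative_counts = 2
num_gpus = 8
effective_batch_size = per_device_batch_size * accumulative_counts * num_gpus  # 128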
855
+ #######################################################################
856
+ # PART 4 Scheduler & Optimizer #
857
+ #######################################################################
858
+ # optimizer
859
+ optim_wrapper = dict(
860
+ type=AmpOptimWrapper,
861
+ optimizer=dict(
862
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
863
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
864
+ accumulative_counts=accumulative_counts,
865
+ loss_scale='dynamic',
866
+ dtype='float16')
867
+
868
+ # learning policy
869
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
870
+ param_scheduler = [
871
+ dict(
872
+ type=LinearLR,
873
+ start_factor=1e-5,
874
+ by_epoch=True,
875
+ begin=0,
876
+ end=warmup_ratio * max_epochs,
877
+ convert_to_iter_based=True),
878
+ dict(
879
+ type=CosineAnnealingLR,
880
+ eta_min=0.0,
881
+ by_epoch=True,
882
+ begin=warmup_ratio * max_epochs,
883
+ end=max_epochs,
884
+ convert_to_iter_based=True)
885
+ ]
886
+
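Note: the two schedulers chain a linear warmup over the first warmup_ratio of training with cosine decay to zero for the rest; since max_epochs=1, the switch happens 3% of the way through the run. Illustration with a hypothetical iteration count:

# Sketch only: total_iters depends on the combined dataset size and batch size.
max_epochs = 1
warmup_ratio = 0.03
total_iters = 10_000                                            # hypothetical
warmup_end_iter = int(warmup_ratio * max_epochs * total_iters)  # 300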
887
+ # train, val, test setting
888
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
889
+
890
+ #######################################################################
891
+ # PART 5 Runtime #
892
+ #######################################################################
893
+ # Log the dialogue periodically during the training process, optional
894
+ custom_hooks = [
895
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
896
+ dict(
897
+ type=EvaluateChatHook_withSpecialTokens,
898
+ tokenizer=tokenizer,
899
+ image_processor=image_processor,
900
+ every_n_iters=evaluation_freq,
901
+ evaluation_inputs=evaluation_inputs,
902
+ evaluation_images=evaluation_images,
903
+ system=SYSTEM,
904
+ prompt_template=prompt_template)
905
+ ]
906
+
907
+ # configure default hooks
908
+ default_hooks = dict(
909
+ # record the time of every iteration.
910
+ timer=dict(type=IterTimerHook),
911
+ # print log every 10 iterations.
912
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
913
+ # enable the parameter scheduler.
914
+ param_scheduler=dict(type=ParamSchedulerHook),
915
+ # save checkpoint per `save_steps`.
916
+ checkpoint=dict(
917
+ type=CheckpointHook,
918
+ by_epoch=False,
919
+ interval=save_steps,
920
+ max_keep_ckpts=save_total_limit),
921
+ # set sampler seed in distributed environment.
922
+ sampler_seed=dict(type=DistSamplerSeedHook),
923
+ )
924
+
925
+ # configure environment
926
+ env_cfg = dict(
927
+ # whether to enable cudnn benchmark
928
+ cudnn_benchmark=False,
929
+ # set multi process parameters
930
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
931
+ # set distributed parameters
932
+ dist_cfg=dict(backend='nccl'),
933
+ )
934
+
935
+ # set visualizer
936
+ visualizer = None
937
+
938
+ # set log level
939
+ log_level = 'INFO'
940
+
941
+ # load from which checkpoint
942
+ load_from = None
943
+
944
+ # whether to resume training from the loaded checkpoint
945
+ resume = False
946
+
947
+ # By default, use a random seed and disable `deterministic`
948
+ randomness = dict(seed=None, deterministic=False)
949
+
950
+ # set log processor
951
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/hf_app.py ADDED
@@ -0,0 +1,951 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # passed to the pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
816
+ glamm_grandf_dataset, glamm_psg_dataset,
817
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
818
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
820
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
821
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
827
+ mdpv_detailed_description_ade20k_dataset,
828
+ mdpv_detailed_description_cocostuff_10k_dataset,
829
+ mdpv_detailed_description_cocostuff_164k_dataset,
830
+ mdpv_detailed_description_vg_dataset,
831
+ mdpv_brief_description_lvis_dataset,
832
+ mdpv_brief_description_vg_dataset,
833
+ mdpv_brief_description_ade20k_dataset,
834
+ mdpv_brief_description_cocostuff10k_dataset,
835
+ mdpv_brief_description_cocostuff164k_dataset,
836
+ mdpv_qa_vg_dataset,
837
+ mdpv_qa_lvis_dataset,
838
+ mdpv_qa_ade20k_dataset,
839
+ mdpv_qa_cocostuff10k_dataset,
840
+ mdpv_qa_cocostuff164k_dataset,
841
+ mdpv_multi_points_flicker30k_dataset,
842
+ mdpv_multi_points_openpsg_dataset,],
843
+ )
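+ # Note: listing a dataset config multiple times in `datasets_cfgs` simply oversamples it;
+ # the semantic-seg and referring-seg configs above are each listed 3x for that reason.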
844
+
845
+ train_dataloader = dict(
846
+ batch_size=batch_size,
847
+ num_workers=dataloader_num_workers,
848
+ dataset=train_dataset,
849
+ sampler=dict(
850
+ type=LengthGroupedSampler,
851
+ length_property='modality_length',
852
+ per_device_batch_size=batch_size * accumulative_counts),
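+ # LengthGroupedSampler groups samples of similar `modality_length` within each effective
+ # batch (per-device batch_size x accumulative_counts) to reduce padding.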
853
+ collate_fn=dict(type=omg_llava_collate_fn))
854
+
855
+ #######################################################################
856
+ # PART 4 Scheduler & Optimizer #
857
+ #######################################################################
858
+ # optimizer
859
+ optim_wrapper = dict(
860
+ type=AmpOptimWrapper,
861
+ optimizer=dict(
862
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
863
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
864
+ accumulative_counts=accumulative_counts,
865
+ loss_scale='dynamic',
866
+ dtype='float16')
867
+
868
+ # learning policy
869
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
870
+ param_scheduler = [
871
+ dict(
872
+ type=LinearLR,
873
+ start_factor=1e-5,
874
+ by_epoch=True,
875
+ begin=0,
876
+ end=warmup_ratio * max_epochs,
877
+ convert_to_iter_based=True),
878
+ dict(
879
+ type=CosineAnnealingLR,
880
+ eta_min=0.0,
881
+ by_epoch=True,
882
+ begin=warmup_ratio * max_epochs,
883
+ end=max_epochs,
884
+ convert_to_iter_based=True)
885
+ ]
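+ # i.e. linear warmup over the first `warmup_ratio` (3%) of training, then cosine decay to 0.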
886
+
887
+ # train, val, test setting
888
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
889
+
890
+ #######################################################################
891
+ # PART 5 Runtime #
892
+ #######################################################################
893
+ # Log the dialogue periodically during the training process, optional
894
+ custom_hooks = [
895
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
896
+ dict(
897
+ type=EvaluateChatHook_withSpecialTokens,
898
+ tokenizer=tokenizer,
899
+ image_processor=image_processor,
900
+ every_n_iters=evaluation_freq,
901
+ evaluation_inputs=evaluation_inputs,
902
+ evaluation_images=evaluation_images,
903
+ system=SYSTEM,
904
+ prompt_template=prompt_template)
905
+ ]
906
+
907
+ # configure default hooks
908
+ default_hooks = dict(
909
+ # record the time of every iteration.
910
+ timer=dict(type=IterTimerHook),
911
+ # print log every 10 iterations.
912
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
913
+ # enable the parameter scheduler.
914
+ param_scheduler=dict(type=ParamSchedulerHook),
915
+ # save checkpoint per `save_steps`.
916
+ checkpoint=dict(
917
+ type=CheckpointHook,
918
+ by_epoch=False,
919
+ interval=save_steps,
920
+ max_keep_ckpts=save_total_limit),
921
+ # set sampler seed in distributed environment.
922
+ sampler_seed=dict(type=DistSamplerSeedHook),
923
+ )
924
+
925
+ # configure environment
926
+ env_cfg = dict(
927
+ # whether to enable cudnn benchmark
928
+ cudnn_benchmark=False,
929
+ # set multi process parameters
930
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
931
+ # set distributed parameters
932
+ dist_cfg=dict(backend='nccl'),
933
+ )
934
+
935
+ # set visualizer
936
+ visualizer = None
937
+
938
+ # set log level
939
+ log_level = 'INFO'
940
+
941
+ # load from which checkpoint
942
+ load_from = None
943
+
944
+ # whether to resume training from the loaded checkpoint
945
+ resume = False
946
+
947
+ # Defaults to use random seed and disable `deterministic`
948
+ randomness = dict(seed=None, deterministic=False)
949
+
950
+ # set log processor
951
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_20b_finetune_stage1_1024image_8gpus.py ADDED
@@ -0,0 +1,993 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset
24
+ from xtuner.dataset.samplers import LengthGroupedSampler
25
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
26
+ from xtuner.engine.runner import TrainLoop
27
+ from omg_llava.model import OMG_LLaVA
28
+ from xtuner.utils import PROMPT_TEMPLATE
29
+ from omg_llava.model import OpenCLIPBackbone_omgseg
30
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
31
+
32
+ from torch.nn import GroupNorm, ReLU
33
+
34
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
35
+ DiceLoss, MaskFormerFusionHead, FocalLoss
36
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
37
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
38
+
39
+ #######################################################################
40
+ # PART 1 Settings #
41
+ #######################################################################
42
+ # Model
43
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path
44
+ pretrained_pth = './pretrained/omg_llava/omg_llava_20b_pretrain_1024image_8gpus.pth'
45
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
46
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
47
+
48
+ # Data
49
+ data_root = './data/llava_data/'
50
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
51
+ image_folder = data_root + 'llava_images'
52
+
53
+ glamm_data_root = './data/glamm_data/'
54
+
55
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
56
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
57
+
58
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
59
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
60
+
61
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
62
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
63
+
64
+ psg_image_path = glamm_data_root + 'images/coco2017/'
65
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
66
+
67
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
68
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
69
+
70
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
71
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
72
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
73
+
74
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
75
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
76
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
77
+
78
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
79
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
80
+
81
+ paco_image_path = './data/glamm_data/images/coco2017/'
82
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
83
+
84
+ referring_refcoco_image_path = refcocog_image_path
85
+ referring_refcoco_data_path = "./data/ref_seg/"
86
+
87
+ referring_refcoco_plus_image_path = refcocog_image_path
88
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
89
+
90
+ referring_refcocog_image_path = refcocog_image_path
91
+ referring_refcocog_data_path = "./data/ref_seg/"
92
+
93
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
94
+ referring_refclef_data_path = "./data/ref_seg/"
95
+
96
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
97
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
98
+
99
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
100
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
101
+
102
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
103
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
104
+
105
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
106
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
107
+
108
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
109
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
110
+
111
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
112
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
113
+
114
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
115
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
116
+
117
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
118
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
119
+
120
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
121
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
122
+
123
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
124
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
125
+
126
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
127
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
128
+
129
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
130
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
131
+
132
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
133
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
134
+
135
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
136
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
137
+
138
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
139
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
140
+
141
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
142
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
143
+
144
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
145
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
146
+
147
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
148
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
149
+
150
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
151
+ max_length = int(2048 - (1024 / 64)**2 - 100)
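+ # 2048-token context minus (1024 / 64)**2 = 256 visual tokens and ~100 slots reserved for special tokens.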
152
+
153
+ # Scheduler & Optimizer
154
+ batch_size = 8 # per_device
155
+ accumulative_counts = 4
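+ # With 8 GPUs (per the config name), the effective global batch size is 8 x 4 x 8 = 256.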
156
+ dataloader_num_workers = 4
157
+ max_epochs = 1
158
+ optim_type = AdamW
159
+ lr = 2e-4
160
+ betas = (0.9, 0.999)
161
+ weight_decay = 0
162
+ max_norm = 1 # grad clip
163
+ warmup_ratio = 0.03
164
+
165
+
166
+ # Save
167
+ save_steps = 2000
168
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
169
+
170
+ # Evaluate the generation performance during the training
171
+ evaluation_freq = 2000
172
+ SYSTEM = ''
173
+ evaluation_images = './work_dirs/test.jpg'
174
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
175
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
176
+
177
+ #######################################################################
178
+ # PART 2 Model & Tokenizer & Image Processor #
179
+ #######################################################################
180
+ tokenizer = dict(
181
+ type=AutoTokenizer.from_pretrained,
182
+ pretrained_model_name_or_path=llm_name_or_path,
183
+ trust_remote_code=True,
184
+ padding_side='right')
185
+
186
+ image_processor = dict(
187
+ type=CLIPImageProcessor,
188
+ do_resize=True,
189
+ size=1024,
190
+ resample=3,
191
+ do_center_crop=True,
192
+ crop_size=1024,
193
+ do_rescale=True,
194
+ do_normalize=True,
195
+ image_mean=[0.4814, 0.4578, 0.4082],
196
+ image_std=[0.2686, 0.2613, 0.2757],
197
+ do_convert_rgb=True
198
+ )
199
+
200
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
201
+ num_things_classes = 80
202
+ num_stuff_classes = 53
203
+ num_classes = num_things_classes + num_stuff_classes
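+ # 80 thing + 53 stuff classes = 133 (COCO panoptic).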
204
+
205
+ omgseg_model = dict(
206
+ type=OMGSegVisualEncoder,
207
+ data_preprocessor=None,
208
+ pixel_shuffle_down_ratio=2,
209
+ backbone=dict(
210
+ type=OpenCLIPBackbone_omgseg,
211
+ model_name='convnext_large_d_320',
212
+ fix=True,
213
+ init_cfg=dict(
214
+ type='clip_pretrain',
215
+ checkpoint='laion2b_s29b_b131k_ft_soup'
216
+ )
217
+ ),
218
+ panoptic_head=dict(
219
+ type=Mask2FormerVideoSemSamHead,
220
+ sphere_cls=True,
221
+ ov_path=omg_ov_class_embed_path,
222
+ enable_box_query=False,
223
+ ov_classifier_name=class_embed,
224
+ logit=None,
225
+ in_channels=[192, 384, 768, 1536],  # passed to the pixel_decoder inside
226
+ strides=[4, 8, 16, 32],
227
+ feat_channels=256,
228
+ out_channels=256,
229
+ num_things_classes=num_things_classes,
230
+ num_stuff_classes=num_stuff_classes,
231
+ num_queries=300,
232
+ num_transformer_feat_level=3,
233
+ pixel_decoder=dict(
234
+ type=MSDeformAttnPixelDecoder,
235
+ num_outs=3,
236
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
237
+ act_cfg=dict(type=ReLU),
238
+ encoder=dict( # DeformableDetrTransformerEncoder
239
+ num_layers=6,
240
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
241
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
242
+ embed_dims=256,
243
+ num_heads=8,
244
+ num_levels=3,
245
+ num_points=4,
246
+ dropout=0.0,
247
+ batch_first=True),
248
+ ffn_cfg=dict(
249
+ embed_dims=256,
250
+ feedforward_channels=1024,
251
+ num_fcs=2,
252
+ ffn_drop=0.0,
253
+ act_cfg=dict(type=ReLU, inplace=True)))),
254
+ positional_encoding=dict(num_feats=128, normalize=True)),
255
+ enforce_decoder_input_project=False,
256
+ positional_encoding=dict(num_feats=128, normalize=True),
257
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
258
+ return_intermediate=True,
259
+ num_layers=9,
260
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
261
+ self_attn_cfg=dict( # MultiheadAttention
262
+ embed_dims=256,
263
+ num_heads=8,
264
+ dropout=0.0,
265
+ batch_first=True),
266
+ cross_attn_cfg=dict( # MultiheadAttention
267
+ embed_dims=256,
268
+ num_heads=8,
269
+ dropout=0.0,
270
+ batch_first=True),
271
+ ffn_cfg=dict(
272
+ embed_dims=256,
273
+ feedforward_channels=2048,
274
+ num_fcs=2,
275
+ ffn_drop=0.0,
276
+ act_cfg=dict(type='ReLU', inplace=True))),
277
+ init_cfg=None),
278
+ loss_cls=dict(
279
+ type=CrossEntropyLoss,
280
+ use_sigmoid=False,
281
+ loss_weight=2.0,
282
+ reduction='mean',
283
+ class_weight=[1.0] * 240 + [0.1]),
284
+ loss_mask=dict(
285
+ type=CrossEntropyLoss,
286
+ use_sigmoid=True,
287
+ reduction='mean',
288
+ loss_weight=5.0),
289
+ loss_dice=dict(
290
+ type=DiceLoss,
291
+ use_sigmoid=True,
292
+ activate=True,
293
+ reduction='mean',
294
+ naive_dice=True,
295
+ eps=1.0,
296
+ loss_weight=5.0),
297
+ loss_iou=dict(
298
+ type=FocalLoss,
299
+ use_sigmoid=True,
300
+ loss_weight=2.0,
301
+ reduction='mean')
302
+ ),
303
+ panoptic_fusion_head=dict(
304
+ type=MaskFormerFusionHead,
305
+ num_things_classes=num_things_classes,
306
+ num_stuff_classes=num_stuff_classes,
307
+ loss_panoptic=None,
308
+ init_cfg=None),
309
+ train_cfg=dict(
310
+ num_points=12544,
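+ # Mask2Former-style point sampling: 12544 = 112 * 112 points per mask for the mask losses.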
311
+ oversample_ratio=3.0,
312
+ importance_sample_ratio=0.75,
313
+ assigner=dict(
314
+ type=HungarianAssigner,
315
+ match_costs=[
316
+ # dict(type=FlexibleClassificationCost, weight=2.0),
317
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
318
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
319
+ ]),
320
+ sampler=dict(type=MaskPseudoSampler)),
321
+ test_cfg=dict(
322
+ panoptic_on=True,
323
+ # For now, the dataset does not support
324
+ # evaluating semantic segmentation metric.
325
+ semantic_on=False,
326
+ instance_on=True,
327
+ # max_per_image is for instance segmentation.
328
+ max_per_image=100,
329
+ iou_thr=0.8,
330
+ # In Mask2Former's panoptic postprocessing,
331
+ # masks whose score is below 0.5 are filtered out.
332
+ filter_low_score=True),
333
+ init_cfg=dict(
334
+ type='Pretrained',
335
+ checkpoint=omg_head_pretrain_pth_path,
336
+ )
337
+ )
338
+
339
+ model = dict(
340
+ type=OMG_LLaVA,
341
+ freeze_llm=True,
342
+ freeze_visual_encoder=True,
343
+ require_omg_decoder=False,
344
+ pretrained_pth=pretrained_pth,
345
+ text2vision_projector=True,
346
+ pixel_shuffle_ratio=2,
347
+ llm=dict(
348
+ type=AutoModelForCausalLM.from_pretrained,
349
+ pretrained_model_name_or_path=llm_name_or_path,
350
+ trust_remote_code=True,
351
+ torch_dtype=torch.float16,
352
+ quantization_config=dict(
353
+ type=BitsAndBytesConfig,
354
+ load_in_4bit=True,
355
+ load_in_8bit=False,
356
+ llm_int8_threshold=6.0,
357
+ llm_int8_has_fp16_weight=False,
358
+ bnb_4bit_compute_dtype=torch.float16,
359
+ bnb_4bit_use_double_quant=True,
360
+ bnb_4bit_quant_type='nf4')),
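+ # QLoRA-style finetuning: the base LLM is kept frozen in 4-bit NF4 and only the LoRA adapters below are trained.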
361
+ llm_lora=dict(
362
+ type=LoraConfig,
363
+ r=512,
364
+ lora_alpha=256,
365
+ lora_dropout=0.05,
366
+ bias='none',
367
+ task_type='CAUSAL_LM'),
368
+ visual_encoder=omgseg_model,
369
+ tokenizer=tokenizer,
370
+ )
371
+
372
+ #######################################################################
373
+ # PART 3 Dataset & Dataloader #
374
+ #######################################################################
375
+ debug=False
376
+ llava_dataset = dict(
377
+ type=LLaVADataset,
378
+ data_path=data_path,
379
+ image_folder=image_folder,
380
+ tokenizer=tokenizer,
381
+ image_processor=image_processor,
382
+ dataset_map_fn=llava_map_fn,
383
+ template_map_fn=dict(
384
+ type=template_map_fn_factory, template=prompt_template),
385
+ max_length=max_length,
386
+ pad_image_to_square=True)
387
+
388
+ glamm_refcocog_dataset = dict(
389
+ type=RefCOCOgGCGDataset,
390
+ data_path=refcocog_ann_file,
391
+ image_folder=refcocog_image_path,
392
+ tokenizer=tokenizer,
393
+ image_processor=image_processor,
394
+ dataset_map_fn=glamm_refcocog_map_fn,
395
+ template_map_fn=dict(
396
+ type=template_map_fn_factory, template=prompt_template),
397
+ max_length=max_length,
398
+ pad_image_to_square=True,
399
+ debug=False,
400
+ repeats=1,
401
+ )
402
+
403
+ glamm_grandf_dataset = dict(
404
+ type=GranDfGCGDataset,
405
+ data_path=grandf_ann_file,
406
+ image_folder=grandf_image_path,
407
+ tokenizer=tokenizer,
408
+ image_processor=image_processor,
409
+ dataset_map_fn=glamm_granf_map_fn,
410
+ template_map_fn=dict(
411
+ type=template_map_fn_factory, template=prompt_template),
412
+ max_length=max_length,
413
+ pad_image_to_square=True,
414
+ debug=debug,
415
+ repeats=10,
416
+ )
417
+
418
+ glamm_psg_dataset = dict(
419
+ type=OpenPsgGCGDataset,
420
+ data_path=psg_ann_file,
421
+ image_folder=psg_image_path,
422
+ tokenizer=tokenizer,
423
+ image_processor=image_processor,
424
+ dataset_map_fn=glamm_openpsg_map_fn,
425
+ template_map_fn=dict(
426
+ type=template_map_fn_factory, template=prompt_template),
427
+ max_length=max_length,
428
+ pad_image_to_square=True,
429
+ debug=debug,
430
+ repeats=1,
431
+ )
432
+
433
+ glamm_flickr_dataset = dict(
434
+ type=FlickrGCGDataset,
435
+ data_path=flickr_ann_file,
436
+ image_folder=flickr_image_path,
437
+ tokenizer=tokenizer,
438
+ image_processor=image_processor,
439
+ dataset_map_fn=glamm_flickr_map_fn,
440
+ template_map_fn=dict(
441
+ type=template_map_fn_factory, template=prompt_template),
442
+ max_length=max_length,
443
+ pad_image_to_square=True,
444
+ debug=debug,
445
+ repeats=1,
446
+ )
447
+
448
+ semantic_seg_ade20k_dataset = dict(
449
+ type=ADE20kSemanticSegDataset,
450
+ data_path=ade20k_class_file,
451
+ image_folder=ade20k_image_path,
452
+ tokenizer=tokenizer,
453
+ image_processor=image_processor,
454
+ dataset_map_fn=semantic_seg_map_fn,
455
+ template_map_fn=dict(
456
+ type=template_map_fn_factory, template=prompt_template),
457
+ max_length=max_length,
458
+ pad_image_to_square=True,
459
+ debug=False,
460
+ repeats=1,
461
+ )
462
+
463
+ semantic_seg_cocostuff_dataset = dict(
464
+ type=COCOStuffSemanticSegDataset,
465
+ data_path=cocostuff_class_file,
466
+ image_folder=cocostuff_image_path,
467
+ label_path=cocostuff_label_path,
468
+ tokenizer=tokenizer,
469
+ image_processor=image_processor,
470
+ dataset_map_fn=semantic_seg_map_fn,
471
+ template_map_fn=dict(
472
+ type=template_map_fn_factory, template=prompt_template),
473
+ max_length=max_length,
474
+ pad_image_to_square=True,
475
+ debug=False,
476
+ repeats=1,
477
+ )
478
+
479
+ semantic_seg_mapillary_dataset = dict(
480
+ type=MapillarySemanticSegDataset,
481
+ data_path=mapillary_class_file,
482
+ image_folder=mapillary_image_path,
483
+ label_path=mapillary_label_path,
484
+ tokenizer=tokenizer,
485
+ image_processor=image_processor,
486
+ dataset_map_fn=semantic_seg_map_fn,
487
+ template_map_fn=dict(
488
+ type=template_map_fn_factory, template=prompt_template),
489
+ max_length=max_length,
490
+ pad_image_to_square=True,
491
+ debug=False,
492
+ repeats=1,
493
+ )
494
+
495
+ semantic_seg_pascal_part_dataset = dict(
496
+ type=PascalPartSemanticSegDataset,
497
+ data_path=pascal_file,
498
+ image_folder=pascal_part_image_path,
499
+ tokenizer=tokenizer,
500
+ image_processor=image_processor,
501
+ dataset_map_fn=pascal_part_map_fn,
502
+ template_map_fn=dict(
503
+ type=template_map_fn_factory, template=prompt_template),
504
+ max_length=max_length,
505
+ pad_image_to_square=True,
506
+ debug=False,
507
+ repeats=1,
508
+ )
509
+
510
+ semantic_seg_paco_dataset = dict(
511
+ type=PacoSemanticSegDataset,
512
+ data_path=paco_file,
513
+ image_folder=paco_image_path,
514
+ tokenizer=tokenizer,
515
+ image_processor=image_processor,
516
+ dataset_map_fn=pascal_part_map_fn,
517
+ template_map_fn=dict(
518
+ type=template_map_fn_factory, template=prompt_template),
519
+ max_length=max_length,
520
+ pad_image_to_square=True,
521
+ debug=False,
522
+ repeats=1,
523
+ )
524
+
525
+ referring_seg_refcoco_dataset = dict(
526
+ type=RefcocoReferringSegDataset,
527
+ data_path=referring_refcoco_data_path,
528
+ image_folder=referring_refcoco_image_path,
529
+ tokenizer=tokenizer,
530
+ image_processor=image_processor,
531
+ dataset_map_fn=referring_seg_map_fn,
532
+ template_map_fn=dict(
533
+ type=template_map_fn_factory, template=prompt_template),
534
+ max_length=max_length,
535
+ pad_image_to_square=True,
536
+ debug=False,
537
+ repeats=1,
538
+ )
539
+
540
+ referring_seg_refcoco_plus_dataset = dict(
541
+ type=Refcoco_plus_ReferringSegDataset,
542
+ data_path=referring_refcoco_plus_data_path,
543
+ image_folder=referring_refcoco_plus_image_path,
544
+ tokenizer=tokenizer,
545
+ image_processor=image_processor,
546
+ dataset_map_fn=referring_seg_map_fn,
547
+ template_map_fn=dict(
548
+ type=template_map_fn_factory, template=prompt_template),
549
+ max_length=max_length,
550
+ pad_image_to_square=True,
551
+ debug=False,
552
+ repeats=1,
553
+ )
554
+
555
+ referring_seg_refcocog_dataset = dict(
556
+ type=Refcocog_ReferringSegDataset,
557
+ data_path=referring_refcocog_data_path,
558
+ image_folder=referring_refcocog_image_path,
559
+ tokenizer=tokenizer,
560
+ image_processor=image_processor,
561
+ dataset_map_fn=referring_seg_map_fn,
562
+ template_map_fn=dict(
563
+ type=template_map_fn_factory, template=prompt_template),
564
+ max_length=max_length,
565
+ pad_image_to_square=True,
566
+ debug=False,
567
+ repeats=1,
568
+ )
569
+
570
+ referring_seg_refclef_dataset = dict(
571
+ type=Refclef_ReferringSegDataset,
572
+ data_path=referring_refclef_data_path,
573
+ image_folder=referring_refclef_image_path,
574
+ tokenizer=tokenizer,
575
+ image_processor=image_processor,
576
+ dataset_map_fn=referring_seg_map_fn,
577
+ template_map_fn=dict(
578
+ type=template_map_fn_factory, template=prompt_template),
579
+ max_length=max_length,
580
+ pad_image_to_square=True,
581
+ debug=False,
582
+ repeats=1,
583
+ )
584
+
585
+ region_cap_osprey_dataset = dict(
586
+ type=OspreyRegionCaptionDataset,
587
+ data_path=region_cap_osprey_data_path,
588
+ image_folder=region_cap_osprey_image_path,
589
+ tokenizer=tokenizer,
590
+ image_processor=image_processor,
591
+ dataset_map_fn=osprey_region_caption_map_fn,
592
+ template_map_fn=dict(
593
+ type=template_map_fn_factory, template=prompt_template),
594
+ max_length=max_length,
595
+ pad_image_to_square=True,
596
+ debug=False,
597
+ repeats=1,
598
+ )
599
+
600
+ region_conversation_osprey_dataset = dict(
601
+ type=OspreyRegionConversationDataset,
602
+ data_path=region_conversation_osprey_data_path,
603
+ image_folder=region_conversation_osprey_image_path,
604
+ tokenizer=tokenizer,
605
+ image_processor=image_processor,
606
+ dataset_map_fn=osprey_region_conversation_map_fn,
607
+ template_map_fn=dict(
608
+ type=template_map_fn_factory, template=prompt_template),
609
+ max_length=max_length,
610
+ pad_image_to_square=True,
611
+ debug=False,
612
+ repeats=1,
613
+ )
614
+
615
+ mdpv_detailed_description_ade20k_dataset = dict(
616
+ type=MDPVPointDetailedCaptionDataset,
617
+ data_path=mdpv_detailed_caption_ade20k_data_path,
618
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
619
+ tokenizer=tokenizer,
620
+ image_processor=image_processor,
621
+ dataset_map_fn=mdpv_points_map_fn,
622
+ template_map_fn=dict(
623
+ type=template_map_fn_factory, template=prompt_template),
624
+ max_length=max_length,
625
+ pad_image_to_square=True,
626
+ debug=False,
627
+ repeats=1,
628
+ )
629
+
630
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
631
+ type=MDPVPointDetailedCaptionDataset,
632
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
633
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
634
+ tokenizer=tokenizer,
635
+ image_processor=image_processor,
636
+ dataset_map_fn=mdpv_points_map_fn,
637
+ template_map_fn=dict(
638
+ type=template_map_fn_factory, template=prompt_template),
639
+ max_length=max_length,
640
+ pad_image_to_square=True,
641
+ debug=False,
642
+ repeats=1,
643
+ )
644
+
645
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
646
+ type=MDPVPointDetailedCaptionDataset,
647
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
648
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
649
+ tokenizer=tokenizer,
650
+ image_processor=image_processor,
651
+ dataset_map_fn=mdpv_points_map_fn,
652
+ template_map_fn=dict(
653
+ type=template_map_fn_factory, template=prompt_template),
654
+ max_length=max_length,
655
+ pad_image_to_square=True,
656
+ debug=False,
657
+ repeats=1,
658
+ )
659
+
660
+ mdpv_detailed_description_vg_dataset = dict(
661
+ type=MDPVPointDetailedCaptionDataset,
662
+ data_path=mdpv_detailed_caption_vg_data_path,
663
+ image_folder=mdpv_detailed_caption_vg_image_path,
664
+ tokenizer=tokenizer,
665
+ image_processor=image_processor,
666
+ dataset_map_fn=mdpv_points_map_fn,
667
+ template_map_fn=dict(
668
+ type=template_map_fn_factory, template=prompt_template),
669
+ max_length=max_length,
670
+ pad_image_to_square=True,
671
+ debug=False,
672
+ repeats=1,
673
+ )
674
+
675
+ mdpv_brief_description_vg_dataset = dict(
676
+ type=MDPVPointBriefCaptionDataset,
677
+ data_path=mdpv_brief_caption_vg_data_path,
678
+ image_folder=mdpv_brief_caption_vg_image_path,
679
+ tokenizer=tokenizer,
680
+ image_processor=image_processor,
681
+ dataset_map_fn=mdpv_points_map_fn,
682
+ template_map_fn=dict(
683
+ type=template_map_fn_factory, template=prompt_template),
684
+ max_length=max_length,
685
+ pad_image_to_square=True,
686
+ debug=False,
687
+ repeats=1,
688
+ )
689
+
690
+ mdpv_brief_description_cocostuff10k_dataset = dict(
691
+ type=MDPVPointBriefCaptionDataset,
692
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
693
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
694
+ tokenizer=tokenizer,
695
+ image_processor=image_processor,
696
+ dataset_map_fn=mdpv_points_map_fn,
697
+ template_map_fn=dict(
698
+ type=template_map_fn_factory, template=prompt_template),
699
+ max_length=max_length,
700
+ pad_image_to_square=True,
701
+ debug=False,
702
+ repeats=1,
703
+ )
704
+
705
+ mdpv_brief_description_cocostuff164k_dataset = dict(
706
+ type=MDPVPointBriefCaptionDataset,
707
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
708
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
709
+ tokenizer=tokenizer,
710
+ image_processor=image_processor,
711
+ dataset_map_fn=mdpv_points_map_fn,
712
+ template_map_fn=dict(
713
+ type=template_map_fn_factory, template=prompt_template),
714
+ max_length=max_length,
715
+ pad_image_to_square=True,
716
+ debug=False,
717
+ repeats=1,
718
+ )
719
+
720
+ mdpv_brief_description_ade20k_dataset = dict(
721
+ type=MDPVPointBriefCaptionDataset,
722
+ data_path=mdpv_brief_caption_ade20k_data_path,
723
+ image_folder=mdpv_brief_caption_ade20k_image_path,
724
+ tokenizer=tokenizer,
725
+ image_processor=image_processor,
726
+ dataset_map_fn=mdpv_points_map_fn,
727
+ template_map_fn=dict(
728
+ type=template_map_fn_factory, template=prompt_template),
729
+ max_length=max_length,
730
+ pad_image_to_square=True,
731
+ debug=False,
732
+ repeats=1,
733
+ )
734
+
735
+ mdpv_brief_description_lvis_dataset = dict(
736
+ type=MDPVPointBriefCaptionDataset,
737
+ data_path=mdpv_brief_caption_lvis_data_path,
738
+ image_folder=mdpv_brief_caption_lvis_image_path,
739
+ tokenizer=tokenizer,
740
+ image_processor=image_processor,
741
+ dataset_map_fn=mdpv_points_map_fn,
742
+ template_map_fn=dict(
743
+ type=template_map_fn_factory, template=prompt_template),
744
+ max_length=max_length,
745
+ pad_image_to_square=True,
746
+ debug=False,
747
+ repeats=1,
748
+ )
749
+
750
+ mdpv_qa_vg_dataset = dict(
751
+ type=MDPVPointBriefCaptionDataset,
752
+ data_path=mdpv_qa_vg_data_path,
753
+ image_folder=mdpv_qa_vg_image_path,
754
+ tokenizer=tokenizer,
755
+ image_processor=image_processor,
756
+ dataset_map_fn=mdpv_points_map_fn,
757
+ template_map_fn=dict(
758
+ type=template_map_fn_factory, template=prompt_template),
759
+ max_length=max_length,
760
+ pad_image_to_square=True,
761
+ debug=False,
762
+ repeats=1,
763
+ )
764
+
765
+ mdpv_qa_ade20k_dataset = dict(
766
+ type=MDPVPointBriefCaptionDataset,
767
+ data_path=mdpv_qa_ade20k_data_path,
768
+ image_folder=mdpv_qa_ade20k_image_path,
769
+ tokenizer=tokenizer,
770
+ image_processor=image_processor,
771
+ dataset_map_fn=mdpv_points_map_fn,
772
+ template_map_fn=dict(
773
+ type=template_map_fn_factory, template=prompt_template),
774
+ max_length=max_length,
775
+ pad_image_to_square=True,
776
+ debug=False,
777
+ repeats=1,
778
+ )
779
+
780
+ mdpv_qa_lvis_dataset = dict(
781
+ type=MDPVPointBriefCaptionDataset,
782
+ data_path=mdpv_qa_lvis_data_path,
783
+ image_folder=mdpv_qa_lvis_image_path,
784
+ tokenizer=tokenizer,
785
+ image_processor=image_processor,
786
+ dataset_map_fn=mdpv_points_map_fn,
787
+ template_map_fn=dict(
788
+ type=template_map_fn_factory, template=prompt_template),
789
+ max_length=max_length,
790
+ pad_image_to_square=True,
791
+ debug=False,
792
+ repeats=1,
793
+ )
794
+
795
+ mdpv_qa_cocostuff10k_dataset = dict(
796
+ type=MDPVPointBriefCaptionDataset,
797
+ data_path=mdpv_qa_cocostuff10k_data_path,
798
+ image_folder=mdpv_qa_cocostuff10k_image_path,
799
+ tokenizer=tokenizer,
800
+ image_processor=image_processor,
801
+ dataset_map_fn=mdpv_points_map_fn,
802
+ template_map_fn=dict(
803
+ type=template_map_fn_factory, template=prompt_template),
804
+ max_length=max_length,
805
+ pad_image_to_square=True,
806
+ debug=False,
807
+ repeats=1,
808
+ )
809
+
810
+ mdpv_qa_cocostuff164k_dataset = dict(
811
+ type=MDPVPointBriefCaptionDataset,
812
+ data_path=mdpv_qa_cocostuff164k_data_path,
813
+ image_folder=mdpv_qa_cocostuff164k_image_path,
814
+ tokenizer=tokenizer,
815
+ image_processor=image_processor,
816
+ dataset_map_fn=mdpv_points_map_fn,
817
+ template_map_fn=dict(
818
+ type=template_map_fn_factory, template=prompt_template),
819
+ max_length=max_length,
820
+ pad_image_to_square=True,
821
+ debug=False,
822
+ repeats=1,
823
+ )
824
+
825
+ mdpv_multi_points_openpsg_dataset = dict(
826
+ type=MDPVPointBriefCaptionDataset,
827
+ data_path=mdpv_multi_points_openpsg_data_path,
828
+ image_folder=mdpv_multi_points_openpsg_image_path,
829
+ tokenizer=tokenizer,
830
+ image_processor=image_processor,
831
+ dataset_map_fn=mdpv_points_map_fn,
832
+ template_map_fn=dict(
833
+ type=template_map_fn_factory, template=prompt_template),
834
+ max_length=max_length,
835
+ pad_image_to_square=True,
836
+ debug=False,
837
+ repeats=1,
838
+ )
839
+
840
+ mdpv_multi_points_flicker30k_dataset = dict(
841
+ type=MDPVPointBriefCaptionDataset,
842
+ data_path=mdpv_multi_points_flicker30k_data_path,
843
+ image_folder=mdpv_multi_points_flicker30k_image_path,
844
+ tokenizer=tokenizer,
845
+ image_processor=image_processor,
846
+ dataset_map_fn=mdpv_points_map_fn,
847
+ template_map_fn=dict(
848
+ type=template_map_fn_factory, template=prompt_template),
849
+ max_length=max_length,
850
+ pad_image_to_square=True,
851
+ debug=False,
852
+ repeats=1,
853
+ )
854
+
855
+ train_dataset = dict(
856
+ type=CombineDataset,
857
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
858
+ glamm_grandf_dataset, glamm_psg_dataset,
859
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
860
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
861
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
862
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
863
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
864
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
865
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
866
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
867
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
868
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
869
+ mdpv_detailed_description_ade20k_dataset,
870
+ mdpv_detailed_description_cocostuff_10k_dataset,
871
+ mdpv_detailed_description_cocostuff_164k_dataset,
872
+ mdpv_detailed_description_vg_dataset,
873
+ mdpv_brief_description_lvis_dataset,
874
+ mdpv_brief_description_vg_dataset,
875
+ mdpv_brief_description_ade20k_dataset,
876
+ mdpv_brief_description_cocostuff10k_dataset,
877
+ mdpv_brief_description_cocostuff164k_dataset,
878
+ mdpv_qa_vg_dataset,
879
+ mdpv_qa_lvis_dataset,
880
+ mdpv_qa_ade20k_dataset,
881
+ mdpv_qa_cocostuff10k_dataset,
882
+ mdpv_qa_cocostuff164k_dataset,
883
+ mdpv_multi_points_flicker30k_dataset,
884
+ mdpv_multi_points_openpsg_dataset,],
885
+ )
886
+
887
+ train_dataloader = dict(
888
+ batch_size=batch_size,
889
+ num_workers=dataloader_num_workers,
890
+ dataset=train_dataset,
891
+ sampler=dict(
892
+ type=LengthGroupedSampler,
893
+ length_property='modality_length',
894
+ per_device_batch_size=batch_size * accumulative_counts),
895
+ collate_fn=dict(type=omg_llava_collate_fn))
896
+
897
+ #######################################################################
898
+ # PART 4 Scheduler & Optimizer #
899
+ #######################################################################
900
+ # optimizer
901
+ optim_wrapper = dict(
902
+ type=AmpOptimWrapper,
903
+ optimizer=dict(
904
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
905
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
906
+ accumulative_counts=accumulative_counts,
907
+ loss_scale='dynamic',
908
+ dtype='float16')
909
+
910
+ # learning policy
911
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
912
+ param_scheduler = [
913
+ dict(
914
+ type=LinearLR,
915
+ start_factor=1e-5,
916
+ by_epoch=True,
917
+ begin=0,
918
+ end=warmup_ratio * max_epochs,
919
+ convert_to_iter_based=True),
920
+ dict(
921
+ type=CosineAnnealingLR,
922
+ eta_min=0.0,
923
+ by_epoch=True,
924
+ begin=warmup_ratio * max_epochs,
925
+ end=max_epochs,
926
+ convert_to_iter_based=True)
927
+ ]
928
+
929
+ # train, val, test setting
930
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
931
+
932
+ #######################################################################
933
+ # PART 5 Runtime #
934
+ #######################################################################
935
+ # Log the dialogue periodically during the training process, optional
936
+ custom_hooks = [
937
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
938
+ dict(
939
+ type=EvaluateChatHook_withSpecialTokens,
940
+ tokenizer=tokenizer,
941
+ image_processor=image_processor,
942
+ every_n_iters=evaluation_freq,
943
+ evaluation_inputs=evaluation_inputs,
944
+ evaluation_images=evaluation_images,
945
+ system=SYSTEM,
946
+ prompt_template=prompt_template)
947
+ ]
948
+
949
+ # configure default hooks
950
+ default_hooks = dict(
951
+ # record the time of every iteration.
952
+ timer=dict(type=IterTimerHook),
953
+ # print log every 10 iterations.
954
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
955
+ # enable the parameter scheduler.
956
+ param_scheduler=dict(type=ParamSchedulerHook),
957
+ # save checkpoint per `save_steps`.
958
+ checkpoint=dict(
959
+ type=CheckpointHook,
960
+ by_epoch=False,
961
+ interval=save_steps,
962
+ max_keep_ckpts=save_total_limit),
963
+ # set sampler seed in distributed environment.
964
+ sampler_seed=dict(type=DistSamplerSeedHook),
965
+ )
966
+
967
+ # configure environment
968
+ env_cfg = dict(
969
+ # whether to enable cudnn benchmark
970
+ cudnn_benchmark=False,
971
+ # set multi process parameters
972
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
973
+ # set distributed parameters
974
+ dist_cfg=dict(backend='nccl'),
975
+ )
976
+
977
+ # set visualizer
978
+ visualizer = None
979
+
980
+ # set log level
981
+ log_level = 'INFO'
982
+
983
+ # load from which checkpoint
984
+ load_from = None
985
+
986
+ # whether to resume training from the loaded checkpoint
987
+ resume = False
988
+
989
+ # Defaults to use random seed and disable `deterministic`
990
+ randomness = dict(seed=None, deterministic=False)
991
+
992
+ # set log processor
993
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_7b_convnextXXL_finetune_stage1_1024image_uniSegFormat_8gpus.py ADDED
@@ -0,0 +1,952 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_convnextXXL.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
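+ # i.e. 1692 tokens: the 2048-token context minus (1024/64)**2 = 256 tokens (presumably the visual-token budget for a 1024px input) and a 100-token margin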
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
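+ # effective batch per optimizer step = batch_size x accumulative_counts x num GPUs, e.g. 8 x 2 x 8 = 128 on an 8-GPU run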
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
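+ # together with the param_scheduler below: linear LR warmup over the first 3% of the epoch, then cosine decay to 0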
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
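+ # 1024x1024 inputs with CLIP mean/std normalization; resample=3 is PIL's bicubic filter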
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_xxlarge',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s34b_b82k_augreg_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ clip_feat_channel=3072,
350
+ llm=dict(
351
+ type=AutoModelForCausalLM.from_pretrained,
352
+ pretrained_model_name_or_path=llm_name_or_path,
353
+ trust_remote_code=True,
354
+ torch_dtype=torch.float16,
355
+ quantization_config=dict(
356
+ type=BitsAndBytesConfig,
357
+ load_in_4bit=True,
358
+ load_in_8bit=False,
359
+ llm_int8_threshold=6.0,
360
+ llm_int8_has_fp16_weight=False,
361
+ bnb_4bit_compute_dtype=torch.float16,
362
+ bnb_4bit_use_double_quant=True,
363
+ bnb_4bit_quant_type='nf4')),
364
+ llm_lora=dict(
365
+ type=LoraConfig,
366
+ r=512,
367
+ lora_alpha=256,
368
+ lora_dropout=0.05,
369
+ bias='none',
370
+ task_type='CAUSAL_LM'),
371
+ visual_encoder=omgseg_model,
372
+ tokenizer=tokenizer,
373
+ )
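+ # Net effect: QLoRA-style finetuning. The InternLM2-chat-7B base is loaded in 4-bit NF4 and kept frozen, LoRA adapters (r=512, alpha=256) are trained on top, and the OMG-Seg visual encoder is frozen as well.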
374
+
375
+ #######################################################################
376
+ # PART 3 Dataset & Dataloader #
377
+ #######################################################################
378
+ debug=False
379
+ llava_dataset = dict(
380
+ type=LLaVADataset,
381
+ data_path=data_path,
382
+ image_folder=image_folder,
383
+ tokenizer=tokenizer,
384
+ image_processor=image_processor,
385
+ dataset_map_fn=llava_map_fn,
386
+ template_map_fn=dict(
387
+ type=template_map_fn_factory, template=prompt_template),
388
+ max_length=max_length,
389
+ pad_image_to_square=True)
390
+
391
+ glamm_refcocog_dataset = dict(
392
+ type=RefCOCOgGCGDataset,
393
+ data_path=refcocog_ann_file,
394
+ image_folder=refcocog_image_path,
395
+ tokenizer=tokenizer,
396
+ image_processor=image_processor,
397
+ dataset_map_fn=glamm_refcocog_map_fn,
398
+ template_map_fn=dict(
399
+ type=template_map_fn_factory, template=prompt_template),
400
+ max_length=max_length,
401
+ pad_image_to_square=True,
402
+ debug=False,
403
+ repeats=1,
404
+ )
405
+
406
+ glamm_grandf_dataset = dict(
407
+ type=GranDfGCGDataset,
408
+ data_path=grandf_ann_file,
409
+ image_folder=grandf_image_path,
410
+ tokenizer=tokenizer,
411
+ image_processor=image_processor,
412
+ dataset_map_fn=glamm_granf_map_fn,
413
+ template_map_fn=dict(
414
+ type=template_map_fn_factory, template=prompt_template),
415
+ max_length=max_length,
416
+ pad_image_to_square=True,
417
+ debug=debug,
418
+ repeats=10,
419
+ )
420
+
421
+ glamm_psg_dataset = dict(
422
+ type=OpenPsgGCGDataset,
423
+ data_path=psg_ann_file,
424
+ image_folder=psg_image_path,
425
+ tokenizer=tokenizer,
426
+ image_processor=image_processor,
427
+ dataset_map_fn=glamm_openpsg_map_fn,
428
+ template_map_fn=dict(
429
+ type=template_map_fn_factory, template=prompt_template),
430
+ max_length=max_length,
431
+ pad_image_to_square=True,
432
+ debug=debug,
433
+ repeats=1,
434
+ )
435
+
436
+ glamm_flickr_dataset = dict(
437
+ type=FlickrGCGDataset,
438
+ data_path=flickr_ann_file,
439
+ image_folder=flickr_image_path,
440
+ tokenizer=tokenizer,
441
+ image_processor=image_processor,
442
+ dataset_map_fn=glamm_flickr_map_fn,
443
+ template_map_fn=dict(
444
+ type=template_map_fn_factory, template=prompt_template),
445
+ max_length=max_length,
446
+ pad_image_to_square=True,
447
+ debug=debug,
448
+ repeats=1,
449
+ )
450
+
451
+ semantic_seg_ade20k_dataset = dict(
452
+ type=ADE20kSemanticSegDataset,
453
+ data_path=ade20k_class_file,
454
+ image_folder=ade20k_image_path,
455
+ tokenizer=tokenizer,
456
+ image_processor=image_processor,
457
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
458
+ template_map_fn=dict(
459
+ type=template_map_fn_factory, template=prompt_template),
460
+ max_length=max_length,
461
+ pad_image_to_square=True,
462
+ debug=False,
463
+ repeats=1,
464
+ gcg_format=True,
465
+ )
466
+
467
+ semantic_seg_cocostuff_dataset = dict(
468
+ type=COCOStuffSemanticSegDataset,
469
+ data_path=cocostuff_class_file,
470
+ image_folder=cocostuff_image_path,
471
+ label_path=cocostuff_label_path,
472
+ tokenizer=tokenizer,
473
+ image_processor=image_processor,
474
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
475
+ template_map_fn=dict(
476
+ type=template_map_fn_factory, template=prompt_template),
477
+ max_length=max_length,
478
+ pad_image_to_square=True,
479
+ debug=False,
480
+ repeats=1,
481
+ gcg_format=True,
482
+ )
483
+
484
+ referring_seg_refcoco_dataset = dict(
485
+ type=RefcocoReferringSegDataset,
486
+ data_path=referring_refcoco_data_path,
487
+ image_folder=referring_refcoco_image_path,
488
+ tokenizer=tokenizer,
489
+ image_processor=image_processor,
490
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
491
+ template_map_fn=dict(
492
+ type=template_map_fn_factory, template=prompt_template),
493
+ max_length=max_length,
494
+ pad_image_to_square=True,
495
+ debug=False,
496
+ repeats=1,
497
+ )
498
+
499
+ referring_seg_refcoco_plus_dataset = dict(
500
+ type=Refcoco_plus_ReferringSegDataset,
501
+ data_path=referring_refcoco_plus_data_path,
502
+ image_folder=referring_refcoco_plus_image_path,
503
+ tokenizer=tokenizer,
504
+ image_processor=image_processor,
505
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
506
+ template_map_fn=dict(
507
+ type=template_map_fn_factory, template=prompt_template),
508
+ max_length=max_length,
509
+ pad_image_to_square=True,
510
+ debug=False,
511
+ repeats=1,
512
+ )
513
+
514
+ referring_seg_refcocog_dataset = dict(
515
+ type=Refcocog_ReferringSegDataset,
516
+ data_path=referring_refcocog_data_path,
517
+ image_folder=referring_refcocog_image_path,
518
+ tokenizer=tokenizer,
519
+ image_processor=image_processor,
520
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
521
+ template_map_fn=dict(
522
+ type=template_map_fn_factory, template=prompt_template),
523
+ max_length=max_length,
524
+ pad_image_to_square=True,
525
+ debug=False,
526
+ repeats=1,
527
+ )
528
+
529
+ referring_seg_refclef_dataset = dict(
530
+ type=Refclef_ReferringSegDataset,
531
+ data_path=referring_refclef_data_path,
532
+ image_folder=referring_refclef_image_path,
533
+ tokenizer=tokenizer,
534
+ image_processor=image_processor,
535
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
536
+ template_map_fn=dict(
537
+ type=template_map_fn_factory, template=prompt_template),
538
+ max_length=max_length,
539
+ pad_image_to_square=True,
540
+ debug=False,
541
+ repeats=1,
542
+ )
543
+
544
+ region_cap_osprey_dataset = dict(
545
+ type=OspreyRegionCaptionDataset,
546
+ data_path=region_cap_osprey_data_path,
547
+ image_folder=region_cap_osprey_image_path,
548
+ tokenizer=tokenizer,
549
+ image_processor=image_processor,
550
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
551
+ template_map_fn=dict(
552
+ type=template_map_fn_factory, template=prompt_template),
553
+ max_length=max_length,
554
+ pad_image_to_square=True,
555
+ debug=False,
556
+ repeats=1,
557
+ )
558
+
559
+ region_conversation_osprey_dataset = dict(
560
+ type=OspreyRegionConversationDataset,
561
+ data_path=region_conversation_osprey_data_path,
562
+ image_folder=region_conversation_osprey_image_path,
563
+ tokenizer=tokenizer,
564
+ image_processor=image_processor,
565
+ dataset_map_fn=osprey_region_conversation_map_fn,
566
+ template_map_fn=dict(
567
+ type=template_map_fn_factory, template=prompt_template),
568
+ max_length=max_length,
569
+ pad_image_to_square=True,
570
+ debug=False,
571
+ repeats=1,
572
+ )
573
+
574
+ mdpv_detailed_description_ade20k_dataset = dict(
575
+ type=MDPVPointDetailedCaptionDataset,
576
+ data_path=mdpv_detailed_caption_ade20k_data_path,
577
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
578
+ tokenizer=tokenizer,
579
+ image_processor=image_processor,
580
+ dataset_map_fn=mdpv_points_map_fn,
581
+ template_map_fn=dict(
582
+ type=template_map_fn_factory, template=prompt_template),
583
+ max_length=max_length,
584
+ pad_image_to_square=True,
585
+ debug=False,
586
+ repeats=1,
587
+ )
588
+
589
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
590
+ type=MDPVPointDetailedCaptionDataset,
591
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
592
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
593
+ tokenizer=tokenizer,
594
+ image_processor=image_processor,
595
+ dataset_map_fn=mdpv_points_map_fn,
596
+ template_map_fn=dict(
597
+ type=template_map_fn_factory, template=prompt_template),
598
+ max_length=max_length,
599
+ pad_image_to_square=True,
600
+ debug=False,
601
+ repeats=1,
602
+ )
603
+
604
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
605
+ type=MDPVPointDetailedCaptionDataset,
606
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
607
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
608
+ tokenizer=tokenizer,
609
+ image_processor=image_processor,
610
+ dataset_map_fn=mdpv_points_map_fn,
611
+ template_map_fn=dict(
612
+ type=template_map_fn_factory, template=prompt_template),
613
+ max_length=max_length,
614
+ pad_image_to_square=True,
615
+ debug=False,
616
+ repeats=1,
617
+ )
618
+
619
+ mdpv_detailed_description_vg_dataset = dict(
620
+ type=MDPVPointDetailedCaptionDataset,
621
+ data_path=mdpv_detailed_caption_vg_data_path,
622
+ image_folder=mdpv_detailed_caption_vg_image_path,
623
+ tokenizer=tokenizer,
624
+ image_processor=image_processor,
625
+ dataset_map_fn=mdpv_points_map_fn,
626
+ template_map_fn=dict(
627
+ type=template_map_fn_factory, template=prompt_template),
628
+ max_length=max_length,
629
+ pad_image_to_square=True,
630
+ debug=False,
631
+ repeats=1,
632
+ )
633
+
634
+ mdpv_brief_description_vg_dataset = dict(
635
+ type=MDPVPointBriefCaptionDataset,
636
+ data_path=mdpv_brief_caption_vg_data_path,
637
+ image_folder=mdpv_brief_caption_vg_image_path,
638
+ tokenizer=tokenizer,
639
+ image_processor=image_processor,
640
+ dataset_map_fn=mdpv_points_map_fn,
641
+ template_map_fn=dict(
642
+ type=template_map_fn_factory, template=prompt_template),
643
+ max_length=max_length,
644
+ pad_image_to_square=True,
645
+ debug=False,
646
+ repeats=1,
647
+ )
648
+
649
+ mdpv_brief_description_cocostuff10k_dataset = dict(
650
+ type=MDPVPointBriefCaptionDataset,
651
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
652
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
653
+ tokenizer=tokenizer,
654
+ image_processor=image_processor,
655
+ dataset_map_fn=mdpv_points_map_fn,
656
+ template_map_fn=dict(
657
+ type=template_map_fn_factory, template=prompt_template),
658
+ max_length=max_length,
659
+ pad_image_to_square=True,
660
+ debug=False,
661
+ repeats=1,
662
+ )
663
+
664
+ mdpv_brief_description_cocostuff164k_dataset = dict(
665
+ type=MDPVPointBriefCaptionDataset,
666
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
667
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
668
+ tokenizer=tokenizer,
669
+ image_processor=image_processor,
670
+ dataset_map_fn=mdpv_points_map_fn,
671
+ template_map_fn=dict(
672
+ type=template_map_fn_factory, template=prompt_template),
673
+ max_length=max_length,
674
+ pad_image_to_square=True,
675
+ debug=False,
676
+ repeats=1,
677
+ )
678
+
679
+ mdpv_brief_description_ade20k_dataset = dict(
680
+ type=MDPVPointBriefCaptionDataset,
681
+ data_path=mdpv_brief_caption_ade20k_data_path,
682
+ image_folder=mdpv_brief_caption_ade20k_image_path,
683
+ tokenizer=tokenizer,
684
+ image_processor=image_processor,
685
+ dataset_map_fn=mdpv_points_map_fn,
686
+ template_map_fn=dict(
687
+ type=template_map_fn_factory, template=prompt_template),
688
+ max_length=max_length,
689
+ pad_image_to_square=True,
690
+ debug=False,
691
+ repeats=1,
692
+ )
693
+
694
+ mdpv_brief_description_lvis_dataset = dict(
695
+ type=MDPVPointBriefCaptionDataset,
696
+ data_path=mdpv_brief_caption_lvis_data_path,
697
+ image_folder=mdpv_brief_caption_lvis_image_path,
698
+ tokenizer=tokenizer,
699
+ image_processor=image_processor,
700
+ dataset_map_fn=mdpv_points_map_fn,
701
+ template_map_fn=dict(
702
+ type=template_map_fn_factory, template=prompt_template),
703
+ max_length=max_length,
704
+ pad_image_to_square=True,
705
+ debug=False,
706
+ repeats=1,
707
+ )
708
+
709
+ mdpv_qa_vg_dataset = dict(
710
+ type=MDPVPointBriefCaptionDataset,
711
+ data_path=mdpv_qa_vg_data_path,
712
+ image_folder=mdpv_qa_vg_image_path,
713
+ tokenizer=tokenizer,
714
+ image_processor=image_processor,
715
+ dataset_map_fn=mdpv_points_map_fn,
716
+ template_map_fn=dict(
717
+ type=template_map_fn_factory, template=prompt_template),
718
+ max_length=max_length,
719
+ pad_image_to_square=True,
720
+ debug=False,
721
+ repeats=1,
722
+ )
723
+
724
+ mdpv_qa_ade20k_dataset = dict(
725
+ type=MDPVPointBriefCaptionDataset,
726
+ data_path=mdpv_qa_ade20k_data_path,
727
+ image_folder=mdpv_qa_ade20k_image_path,
728
+ tokenizer=tokenizer,
729
+ image_processor=image_processor,
730
+ dataset_map_fn=mdpv_points_map_fn,
731
+ template_map_fn=dict(
732
+ type=template_map_fn_factory, template=prompt_template),
733
+ max_length=max_length,
734
+ pad_image_to_square=True,
735
+ debug=False,
736
+ repeats=1,
737
+ )
738
+
739
+ mdpv_qa_lvis_dataset = dict(
740
+ type=MDPVPointBriefCaptionDataset,
741
+ data_path=mdpv_qa_lvis_data_path,
742
+ image_folder=mdpv_qa_lvis_image_path,
743
+ tokenizer=tokenizer,
744
+ image_processor=image_processor,
745
+ dataset_map_fn=mdpv_points_map_fn,
746
+ template_map_fn=dict(
747
+ type=template_map_fn_factory, template=prompt_template),
748
+ max_length=max_length,
749
+ pad_image_to_square=True,
750
+ debug=False,
751
+ repeats=1,
752
+ )
753
+
754
+ mdpv_qa_cocostuff10k_dataset = dict(
755
+ type=MDPVPointBriefCaptionDataset,
756
+ data_path=mdpv_qa_cocostuff10k_data_path,
757
+ image_folder=mdpv_qa_cocostuff10k_image_path,
758
+ tokenizer=tokenizer,
759
+ image_processor=image_processor,
760
+ dataset_map_fn=mdpv_points_map_fn,
761
+ template_map_fn=dict(
762
+ type=template_map_fn_factory, template=prompt_template),
763
+ max_length=max_length,
764
+ pad_image_to_square=True,
765
+ debug=False,
766
+ repeats=1,
767
+ )
768
+
769
+ mdpv_qa_cocostuff164k_dataset = dict(
770
+ type=MDPVPointBriefCaptionDataset,
771
+ data_path=mdpv_qa_cocostuff164k_data_path,
772
+ image_folder=mdpv_qa_cocostuff164k_image_path,
773
+ tokenizer=tokenizer,
774
+ image_processor=image_processor,
775
+ dataset_map_fn=mdpv_points_map_fn,
776
+ template_map_fn=dict(
777
+ type=template_map_fn_factory, template=prompt_template),
778
+ max_length=max_length,
779
+ pad_image_to_square=True,
780
+ debug=False,
781
+ repeats=1,
782
+ )
783
+
784
+ mdpv_multi_points_openpsg_dataset = dict(
785
+ type=MDPVPointBriefCaptionDataset,
786
+ data_path=mdpv_multi_points_openpsg_data_path,
787
+ image_folder=mdpv_multi_points_openpsg_image_path,
788
+ tokenizer=tokenizer,
789
+ image_processor=image_processor,
790
+ dataset_map_fn=mdpv_points_map_fn,
791
+ template_map_fn=dict(
792
+ type=template_map_fn_factory, template=prompt_template),
793
+ max_length=max_length,
794
+ pad_image_to_square=True,
795
+ debug=False,
796
+ repeats=1,
797
+ )
798
+
799
+ mdpv_multi_points_flicker30k_dataset = dict(
800
+ type=MDPVPointBriefCaptionDataset,
801
+ data_path=mdpv_multi_points_flicker30k_data_path,
802
+ image_folder=mdpv_multi_points_flicker30k_image_path,
803
+ tokenizer=tokenizer,
804
+ image_processor=image_processor,
805
+ dataset_map_fn=mdpv_points_map_fn,
806
+ template_map_fn=dict(
807
+ type=template_map_fn_factory, template=prompt_template),
808
+ max_length=max_length,
809
+ pad_image_to_square=True,
810
+ debug=False,
811
+ repeats=1,
812
+ )
813
+
814
+ train_dataset = dict(
815
+ type=CombineDataset,
816
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
817
+ glamm_grandf_dataset, glamm_psg_dataset,
818
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
820
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
821
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
822
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
823
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
824
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
825
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
826
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
827
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
828
+ mdpv_detailed_description_ade20k_dataset,
829
+ mdpv_detailed_description_cocostuff_10k_dataset,
830
+ mdpv_detailed_description_cocostuff_164k_dataset,
831
+ mdpv_detailed_description_vg_dataset,
832
+ mdpv_brief_description_lvis_dataset,
833
+ mdpv_brief_description_vg_dataset,
834
+ mdpv_brief_description_ade20k_dataset,
835
+ mdpv_brief_description_cocostuff10k_dataset,
836
+ mdpv_brief_description_cocostuff164k_dataset,
837
+ mdpv_qa_vg_dataset,
838
+ mdpv_qa_lvis_dataset,
839
+ mdpv_qa_ade20k_dataset,
840
+ mdpv_qa_cocostuff10k_dataset,
841
+ mdpv_qa_cocostuff164k_dataset,
842
+ mdpv_multi_points_flicker30k_dataset,
843
+ mdpv_multi_points_openpsg_dataset,],
844
+ )
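+ # the semantic-seg and referring-seg configs are intentionally listed multiple times so that CombineDataset oversamples them (3x) relative to datasets listed once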
845
+
846
+ train_dataloader = dict(
847
+ batch_size=batch_size,
848
+ num_workers=dataloader_num_workers,
849
+ dataset=train_dataset,
850
+ sampler=dict(
851
+ type=LengthGroupedSampler,
852
+ length_property='modality_length',
853
+ per_device_batch_size=batch_size * accumulative_counts),
854
+ collate_fn=dict(type=omg_llava_collate_fn))
855
+
856
+ #######################################################################
857
+ # PART 4 Scheduler & Optimizer #
858
+ #######################################################################
859
+ # optimizer
860
+ optim_wrapper = dict(
861
+ type=AmpOptimWrapper,
862
+ optimizer=dict(
863
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
864
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
865
+ accumulative_counts=accumulative_counts,
866
+ loss_scale='dynamic',
867
+ dtype='float16')
868
+
869
+ # learning policy
870
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
871
+ param_scheduler = [
872
+ dict(
873
+ type=LinearLR,
874
+ start_factor=1e-5,
875
+ by_epoch=True,
876
+ begin=0,
877
+ end=warmup_ratio * max_epochs,
878
+ convert_to_iter_based=True),
879
+ dict(
880
+ type=CosineAnnealingLR,
881
+ eta_min=0.0,
882
+ by_epoch=True,
883
+ begin=warmup_ratio * max_epochs,
884
+ end=max_epochs,
885
+ convert_to_iter_based=True)
886
+ ]
887
+
888
+ # train, val, test setting
889
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
890
+
891
+ #######################################################################
892
+ # PART 5 Runtime #
893
+ #######################################################################
894
+ # Log the dialogue periodically during the training process, optional
895
+ custom_hooks = [
896
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
897
+ dict(
898
+ type=EvaluateChatHook_withSpecialTokens,
899
+ tokenizer=tokenizer,
900
+ image_processor=image_processor,
901
+ every_n_iters=evaluation_freq,
902
+ evaluation_inputs=evaluation_inputs,
903
+ evaluation_images=evaluation_images,
904
+ system=SYSTEM,
905
+ prompt_template=prompt_template)
906
+ ]
907
+
908
+ # configure default hooks
909
+ default_hooks = dict(
910
+ # record the time of every iteration.
911
+ timer=dict(type=IterTimerHook),
912
+ # print log every 10 iterations.
913
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
914
+ # enable the parameter scheduler.
915
+ param_scheduler=dict(type=ParamSchedulerHook),
916
+ # save checkpoint per `save_steps`.
917
+ checkpoint=dict(
918
+ type=CheckpointHook,
919
+ by_epoch=False,
920
+ interval=save_steps,
921
+ max_keep_ckpts=save_total_limit),
922
+ # set sampler seed in distributed environment.
923
+ sampler_seed=dict(type=DistSamplerSeedHook),
924
+ )
925
+
926
+ # configure environment
927
+ env_cfg = dict(
928
+ # whether to enable cudnn benchmark
929
+ cudnn_benchmark=False,
930
+ # set multi-process parameters
931
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
932
+ # set distributed parameters
933
+ dist_cfg=dict(backend='nccl'),
934
+ )
935
+
936
+ # set visualizer
937
+ visualizer = None
938
+
939
+ # set log level
940
+ log_level = 'INFO'
941
+
942
+ # load from which checkpoint
943
+ load_from = None
944
+
945
+ # whether to resume training from the loaded checkpoint
946
+ resume = False
947
+
948
+ # Default to a random seed with `deterministic` disabled
949
+ randomness = dict(seed=None, deterministic=False)
950
+
951
+ # set log processor
952
+ log_processor = dict(by_epoch=False)
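For reference, a minimal sketch of how a finetune config like the one above can be inspected and launched, assuming the standard mmengine `Config` API and the `xtuner` command-line entry point (the path below is a placeholder, not a file from this upload):

from mmengine.config import Config

# Loading executes the config file, so omg_llava, xtuner and mmdet must be importable.
cfg = Config.fromfile('path/to/omg_llava_finetune_config.py')  # placeholder path
print(cfg.max_epochs, cfg.train_dataloader.batch_size)

# Training itself is typically started through the xtuner launcher, e.g.:
#   NPROC_PER_NODE=8 xtuner train path/to/omg_llava_finetune_config.py --deepspeed deepspeed_zero2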
omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus.py ADDED
@@ -0,0 +1,993 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset
24
+ from xtuner.dataset.samplers import LengthGroupedSampler
25
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
26
+ from xtuner.engine.runner import TrainLoop
27
+ from omg_llava.model import OMG_LLaVA
28
+ from xtuner.utils import PROMPT_TEMPLATE
29
+ from omg_llava.model import OpenCLIPBackbone_omgseg
30
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
31
+
32
+ from torch.nn import GroupNorm, ReLU
33
+
34
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
35
+ DiceLoss, MaskFormerFusionHead, FocalLoss
36
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
37
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
38
+
39
+ #######################################################################
40
+ # PART 1 Settings #
41
+ #######################################################################
42
+ # Model
43
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
44
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
45
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
46
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
47
+
48
+ # Data
49
+ data_root = './data/llava_data/'
50
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
51
+ image_folder = data_root + 'llava_images'
52
+
53
+ glamm_data_root = './data/glamm_data/'
54
+
55
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
56
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
57
+
58
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
59
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
60
+
61
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
62
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
63
+
64
+ psg_image_path = glamm_data_root + 'images/coco2017/'
65
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
66
+
67
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
68
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
69
+
70
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
71
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
72
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
73
+
74
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
75
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
76
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
77
+
78
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
79
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
80
+
81
+ paco_image_path = './data/glamm_data/images/coco2017/'
82
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
83
+
84
+ referring_refcoco_image_path = refcocog_image_path
85
+ referring_refcoco_data_path = "./data/ref_seg/"
86
+
87
+ referring_refcoco_plus_image_path = refcocog_image_path
88
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
89
+
90
+ referring_refcocog_image_path = refcocog_image_path
91
+ referring_refcocog_data_path = "./data/ref_seg/"
92
+
93
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
94
+ referring_refclef_data_path = "./data/ref_seg/"
95
+
96
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
97
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
98
+
99
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
100
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
101
+
102
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
103
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
104
+
105
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
106
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
107
+
108
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
109
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
110
+
111
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
112
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
113
+
114
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
115
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
116
+
117
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
118
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
119
+
120
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
121
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
122
+
123
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
124
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
125
+
126
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
127
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
128
+
129
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
130
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
131
+
132
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
133
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
134
+
135
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
136
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
137
+
138
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
139
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
140
+
141
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
142
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
143
+
144
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
145
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
146
+
147
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
148
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
149
+
150
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
151
+ max_length = int(2048 - (1024 / 64)**2 - 100)
152
+
153
+ # Scheduler & Optimizer
154
+ batch_size = 8 # per_device
155
+ accumulative_counts = 2
156
+ dataloader_num_workers = 4
157
+ max_epochs = 1
158
+ optim_type = AdamW
159
+ lr = 2e-4
160
+ betas = (0.9, 0.999)
161
+ weight_decay = 0
162
+ max_norm = 1 # grad clip
163
+ warmup_ratio = 0.03
164
+
165
+
166
+ # Save
167
+ save_steps = 2000
168
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
169
+
170
+ # Evaluate generation performance during training
171
+ evaluation_freq = 2000
172
+ SYSTEM = ''
173
+ evaluation_images = './work_dirs/test.jpg'
174
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
175
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
176
+
177
+ #######################################################################
178
+ # PART 2 Model & Tokenizer & Image Processor #
179
+ #######################################################################
180
+ tokenizer = dict(
181
+ type=AutoTokenizer.from_pretrained,
182
+ pretrained_model_name_or_path=llm_name_or_path,
183
+ trust_remote_code=True,
184
+ padding_side='right')
185
+
186
+ image_processor = dict(
187
+ type=CLIPImageProcessor,
188
+ do_resize=True,
189
+ size=1024,
190
+ resample=3,
191
+ do_center_crop=True,
192
+ crop_size=1024,
193
+ do_rescale=True,
194
+ do_normalize=True,
195
+ image_mean=[0.4814, 0.4578, 0.4082],
196
+ image_std=[0.2686, 0.2613, 0.2757],
197
+ do_convert_rgb=True
198
+ )
199
+
200
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
201
+ num_things_classes = 80
202
+ num_stuff_classes = 53
203
+ num_classes = num_things_classes + num_stuff_classes
204
+
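+ # This variant pairs the ConvNeXt-Large (d_320) OpenCLIP backbone with multi-scale channels [192, 384, 768, 1536] in the panoptic head (the XXL config above uses [384, 768, 1536, 3072]).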
205
+ omgseg_model = dict(
206
+ type=OMGSegVisualEncoder,
207
+ data_preprocessor=None,
208
+ pixel_shuffle_down_ratio=2,
209
+ backbone=dict(
210
+ type=OpenCLIPBackbone_omgseg,
211
+ model_name='convnext_large_d_320',
212
+ fix=True,
213
+ init_cfg=dict(
214
+ type='clip_pretrain',
215
+ checkpoint='laion2b_s29b_b131k_ft_soup'
216
+ )
217
+ ),
218
+ panoptic_head=dict(
219
+ type=Mask2FormerVideoSemSamHead,
220
+ sphere_cls=True,
221
+ ov_path=omg_ov_class_embed_path,
222
+ enable_box_query=False,
223
+ ov_classifier_name=class_embed,
224
+ logit=None,
225
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
226
+ strides=[4, 8, 16, 32],
227
+ feat_channels=256,
228
+ out_channels=256,
229
+ num_things_classes=num_things_classes,
230
+ num_stuff_classes=num_stuff_classes,
231
+ num_queries=300,
232
+ num_transformer_feat_level=3,
233
+ pixel_decoder=dict(
234
+ type=MSDeformAttnPixelDecoder,
235
+ num_outs=3,
236
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
237
+ act_cfg=dict(type=ReLU),
238
+ encoder=dict( # DeformableDetrTransformerEncoder
239
+ num_layers=6,
240
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
241
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
242
+ embed_dims=256,
243
+ num_heads=8,
244
+ num_levels=3,
245
+ num_points=4,
246
+ dropout=0.0,
247
+ batch_first=True),
248
+ ffn_cfg=dict(
249
+ embed_dims=256,
250
+ feedforward_channels=1024,
251
+ num_fcs=2,
252
+ ffn_drop=0.0,
253
+ act_cfg=dict(type=ReLU, inplace=True)))),
254
+ positional_encoding=dict(num_feats=128, normalize=True)),
255
+ enforce_decoder_input_project=False,
256
+ positional_encoding=dict(num_feats=128, normalize=True),
257
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
258
+ return_intermediate=True,
259
+ num_layers=9,
260
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
261
+ self_attn_cfg=dict( # MultiheadAttention
262
+ embed_dims=256,
263
+ num_heads=8,
264
+ dropout=0.0,
265
+ batch_first=True),
266
+ cross_attn_cfg=dict( # MultiheadAttention
267
+ embed_dims=256,
268
+ num_heads=8,
269
+ dropout=0.0,
270
+ batch_first=True),
271
+ ffn_cfg=dict(
272
+ embed_dims=256,
273
+ feedforward_channels=2048,
274
+ num_fcs=2,
275
+ ffn_drop=0.0,
276
+ act_cfg=dict(type='ReLU', inplace=True))),
277
+ init_cfg=None),
278
+ loss_cls=dict(
279
+ type=CrossEntropyLoss,
280
+ use_sigmoid=False,
281
+ loss_weight=2.0,
282
+ reduction='mean',
283
+ class_weight=[1.0] * 240 + [0.1]),
284
+ loss_mask=dict(
285
+ type=CrossEntropyLoss,
286
+ use_sigmoid=True,
287
+ reduction='mean',
288
+ loss_weight=5.0),
289
+ loss_dice=dict(
290
+ type=DiceLoss,
291
+ use_sigmoid=True,
292
+ activate=True,
293
+ reduction='mean',
294
+ naive_dice=True,
295
+ eps=1.0,
296
+ loss_weight=5.0),
297
+ loss_iou=dict(
298
+ type=FocalLoss,
299
+ use_sigmoid=True,
300
+ loss_weight=2.0,
301
+ reduction='mean')
302
+ ),
303
+ panoptic_fusion_head=dict(
304
+ type=MaskFormerFusionHead,
305
+ num_things_classes=num_things_classes,
306
+ num_stuff_classes=num_stuff_classes,
307
+ loss_panoptic=None,
308
+ init_cfg=None),
309
+ train_cfg=dict(
310
+ num_points=12544,
311
+ oversample_ratio=3.0,
312
+ importance_sample_ratio=0.75,
313
+ assigner=dict(
314
+ type=HungarianAssigner,
315
+ match_costs=[
316
+ # dict(type=FlexibleClassificationCost, weight=2.0),
317
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
318
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
319
+ ]),
320
+ sampler=dict(type=MaskPseudoSampler)),
321
+ test_cfg=dict(
322
+ panoptic_on=True,
323
+ # For now, the dataset does not support
324
+ # evaluating the semantic segmentation metric.
325
+ semantic_on=False,
326
+ instance_on=True,
327
+ # max_per_image is for instance segmentation.
328
+ max_per_image=100,
329
+ iou_thr=0.8,
330
+ # In Mask2Former's panoptic postprocessing,
331
+ # it filters out mask areas whose score is less than 0.5.
332
+ filter_low_score=True),
333
+ init_cfg=dict(
334
+ type='Pretrained',
335
+ checkpoint=omg_head_pretrain_pth_path,
336
+ )
337
+ )
338
+
339
+ model = dict(
340
+ type=OMG_LLaVA,
341
+ freeze_llm=True,
342
+ freeze_visual_encoder=True,
343
+ require_omg_decoder=False,
344
+ pretrained_pth=pretrained_pth,
345
+ text2vision_projector=True,
346
+ pixel_shuffle_ratio=2,
347
+ llm=dict(
348
+ type=AutoModelForCausalLM.from_pretrained,
349
+ pretrained_model_name_or_path=llm_name_or_path,
350
+ trust_remote_code=True,
351
+ torch_dtype=torch.float16,
352
+ quantization_config=dict(
353
+ type=BitsAndBytesConfig,
354
+ load_in_4bit=True,
355
+ load_in_8bit=False,
356
+ llm_int8_threshold=6.0,
357
+ llm_int8_has_fp16_weight=False,
358
+ bnb_4bit_compute_dtype=torch.float16,
359
+ bnb_4bit_use_double_quant=True,
360
+ bnb_4bit_quant_type='nf4')),
361
+ llm_lora=dict(
362
+ type=LoraConfig,
363
+ r=512,
364
+ lora_alpha=256,
365
+ lora_dropout=0.05,
366
+ bias='none',
367
+ task_type='CAUSAL_LM'),
368
+ visual_encoder=omgseg_model,
369
+ tokenizer=tokenizer,
370
+ )
371
+
372
+ #######################################################################
373
+ # PART 3 Dataset & Dataloader #
374
+ #######################################################################
375
+ debug=False
376
+ llava_dataset = dict(
377
+ type=LLaVADataset,
378
+ data_path=data_path,
379
+ image_folder=image_folder,
380
+ tokenizer=tokenizer,
381
+ image_processor=image_processor,
382
+ dataset_map_fn=llava_map_fn,
383
+ template_map_fn=dict(
384
+ type=template_map_fn_factory, template=prompt_template),
385
+ max_length=max_length,
386
+ pad_image_to_square=True)
387
+
388
+ glamm_refcocog_dataset = dict(
389
+ type=RefCOCOgGCGDataset,
390
+ data_path=refcocog_ann_file,
391
+ image_folder=refcocog_image_path,
392
+ tokenizer=tokenizer,
393
+ image_processor=image_processor,
394
+ dataset_map_fn=glamm_refcocog_map_fn,
395
+ template_map_fn=dict(
396
+ type=template_map_fn_factory, template=prompt_template),
397
+ max_length=max_length,
398
+ pad_image_to_square=True,
399
+ debug=False,
400
+ repeats=1,
401
+ )
402
+
403
+ glamm_grandf_dataset = dict(
404
+ type=GranDfGCGDataset,
405
+ data_path=grandf_ann_file,
406
+ image_folder=grandf_image_path,
407
+ tokenizer=tokenizer,
408
+ image_processor=image_processor,
409
+ dataset_map_fn=glamm_granf_map_fn,
410
+ template_map_fn=dict(
411
+ type=template_map_fn_factory, template=prompt_template),
412
+ max_length=max_length,
413
+ pad_image_to_square=True,
414
+ debug=debug,
415
+ repeats=10,
416
+ )
417
+
418
+ glamm_psg_dataset = dict(
419
+ type=OpenPsgGCGDataset,
420
+ data_path=psg_ann_file,
421
+ image_folder=psg_image_path,
422
+ tokenizer=tokenizer,
423
+ image_processor=image_processor,
424
+ dataset_map_fn=glamm_openpsg_map_fn,
425
+ template_map_fn=dict(
426
+ type=template_map_fn_factory, template=prompt_template),
427
+ max_length=max_length,
428
+ pad_image_to_square=True,
429
+ debug=debug,
430
+ repeats=1,
431
+ )
432
+
433
+ glamm_flickr_dataset = dict(
434
+ type=FlickrGCGDataset,
435
+ data_path=flickr_ann_file,
436
+ image_folder=flickr_image_path,
437
+ tokenizer=tokenizer,
438
+ image_processor=image_processor,
439
+ dataset_map_fn=glamm_flickr_map_fn,
440
+ template_map_fn=dict(
441
+ type=template_map_fn_factory, template=prompt_template),
442
+ max_length=max_length,
443
+ pad_image_to_square=True,
444
+ debug=debug,
445
+ repeats=1,
446
+ )
447
+
448
+ semantic_seg_ade20k_dataset = dict(
449
+ type=ADE20kSemanticSegDataset,
450
+ data_path=ade20k_class_file,
451
+ image_folder=ade20k_image_path,
452
+ tokenizer=tokenizer,
453
+ image_processor=image_processor,
454
+ dataset_map_fn=semantic_seg_map_fn,
455
+ template_map_fn=dict(
456
+ type=template_map_fn_factory, template=prompt_template),
457
+ max_length=max_length,
458
+ pad_image_to_square=True,
459
+ debug=False,
460
+ repeats=1,
461
+ )
462
+
463
+ semantic_seg_cocostuff_dataset = dict(
464
+ type=COCOStuffSemanticSegDataset,
465
+ data_path=cocostuff_class_file,
466
+ image_folder=cocostuff_image_path,
467
+ label_path=cocostuff_label_path,
468
+ tokenizer=tokenizer,
469
+ image_processor=image_processor,
470
+ dataset_map_fn=semantic_seg_map_fn,
471
+ template_map_fn=dict(
472
+ type=template_map_fn_factory, template=prompt_template),
473
+ max_length=max_length,
474
+ pad_image_to_square=True,
475
+ debug=False,
476
+ repeats=1,
477
+ )
478
+
479
+ semantic_seg_mapillary_dataset = dict(
480
+ type=MapillarySemanticSegDataset,
481
+ data_path=mapillary_class_file,
482
+ image_folder=mapillary_image_path,
483
+ label_path=mapillary_label_path,
484
+ tokenizer=tokenizer,
485
+ image_processor=image_processor,
486
+ dataset_map_fn=semantic_seg_map_fn,
487
+ template_map_fn=dict(
488
+ type=template_map_fn_factory, template=prompt_template),
489
+ max_length=max_length,
490
+ pad_image_to_square=True,
491
+ debug=False,
492
+ repeats=1,
493
+ )
494
+
495
+ semantic_seg_pascal_part_dataset = dict(
496
+ type=PascalPartSemanticSegDataset,
497
+ data_path=pascal_file,
498
+ image_folder=pascal_part_image_path,
499
+ tokenizer=tokenizer,
500
+ image_processor=image_processor,
501
+ dataset_map_fn=pascal_part_map_fn,
502
+ template_map_fn=dict(
503
+ type=template_map_fn_factory, template=prompt_template),
504
+ max_length=max_length,
505
+ pad_image_to_square=True,
506
+ debug=False,
507
+ repeats=1,
508
+ )
509
+
510
+ semantic_seg_paco_dataset = dict(
511
+ type=PacoSemanticSegDataset,
512
+ data_path=paco_file,
513
+ image_folder=paco_image_path,
514
+ tokenizer=tokenizer,
515
+ image_processor=image_processor,
516
+ dataset_map_fn=pascal_part_map_fn,
517
+ template_map_fn=dict(
518
+ type=template_map_fn_factory, template=prompt_template),
519
+ max_length=max_length,
520
+ pad_image_to_square=True,
521
+ debug=False,
522
+ repeats=1,
523
+ )
524
+
525
+ referring_seg_refcoco_dataset = dict(
526
+ type=RefcocoReferringSegDataset,
527
+ data_path=referring_refcoco_data_path,
528
+ image_folder=referring_refcoco_image_path,
529
+ tokenizer=tokenizer,
530
+ image_processor=image_processor,
531
+ dataset_map_fn=referring_seg_map_fn,
532
+ template_map_fn=dict(
533
+ type=template_map_fn_factory, template=prompt_template),
534
+ max_length=max_length,
535
+ pad_image_to_square=True,
536
+ debug=False,
537
+ repeats=1,
538
+ )
539
+
540
+ referring_seg_refcoco_plus_dataset = dict(
541
+ type=Refcoco_plus_ReferringSegDataset,
542
+ data_path=referring_refcoco_plus_data_path,
543
+ image_folder=referring_refcoco_plus_image_path,
544
+ tokenizer=tokenizer,
545
+ image_processor=image_processor,
546
+ dataset_map_fn=referring_seg_map_fn,
547
+ template_map_fn=dict(
548
+ type=template_map_fn_factory, template=prompt_template),
549
+ max_length=max_length,
550
+ pad_image_to_square=True,
551
+ debug=False,
552
+ repeats=1,
553
+ )
554
+
555
+ referring_seg_refcocog_dataset = dict(
556
+ type=Refcocog_ReferringSegDataset,
557
+ data_path=referring_refcocog_data_path,
558
+ image_folder=referring_refcocog_image_path,
559
+ tokenizer=tokenizer,
560
+ image_processor=image_processor,
561
+ dataset_map_fn=referring_seg_map_fn,
562
+ template_map_fn=dict(
563
+ type=template_map_fn_factory, template=prompt_template),
564
+ max_length=max_length,
565
+ pad_image_to_square=True,
566
+ debug=False,
567
+ repeats=1,
568
+ )
569
+
570
+ referring_seg_refclef_dataset = dict(
571
+ type=Refclef_ReferringSegDataset,
572
+ data_path=referring_refclef_data_path,
573
+ image_folder=referring_refclef_image_path,
574
+ tokenizer=tokenizer,
575
+ image_processor=image_processor,
576
+ dataset_map_fn=referring_seg_map_fn,
577
+ template_map_fn=dict(
578
+ type=template_map_fn_factory, template=prompt_template),
579
+ max_length=max_length,
580
+ pad_image_to_square=True,
581
+ debug=False,
582
+ repeats=1,
583
+ )
584
+
585
+ region_cap_osprey_dataset = dict(
586
+ type=OspreyRegionCaptionDataset,
587
+ data_path=region_cap_osprey_data_path,
588
+ image_folder=region_cap_osprey_image_path,
589
+ tokenizer=tokenizer,
590
+ image_processor=image_processor,
591
+ dataset_map_fn=osprey_region_caption_map_fn,
592
+ template_map_fn=dict(
593
+ type=template_map_fn_factory, template=prompt_template),
594
+ max_length=max_length,
595
+ pad_image_to_square=True,
596
+ debug=False,
597
+ repeats=1,
598
+ )
599
+
600
+ region_conversation_osprey_dataset = dict(
601
+ type=OspreyRegionConversationDataset,
602
+ data_path=region_conversation_osprey_data_path,
603
+ image_folder=region_conversation_osprey_image_path,
604
+ tokenizer=tokenizer,
605
+ image_processor=image_processor,
606
+ dataset_map_fn=osprey_region_conversation_map_fn,
607
+ template_map_fn=dict(
608
+ type=template_map_fn_factory, template=prompt_template),
609
+ max_length=max_length,
610
+ pad_image_to_square=True,
611
+ debug=False,
612
+ repeats=1,
613
+ )
614
+
615
+ mdpv_detailed_description_ade20k_dataset = dict(
616
+ type=MDPVPointDetailedCaptionDataset,
617
+ data_path=mdpv_detailed_caption_ade20k_data_path,
618
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
619
+ tokenizer=tokenizer,
620
+ image_processor=image_processor,
621
+ dataset_map_fn=mdpv_points_map_fn,
622
+ template_map_fn=dict(
623
+ type=template_map_fn_factory, template=prompt_template),
624
+ max_length=max_length,
625
+ pad_image_to_square=True,
626
+ debug=False,
627
+ repeats=1,
628
+ )
629
+
630
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
631
+ type=MDPVPointDetailedCaptionDataset,
632
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
633
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
634
+ tokenizer=tokenizer,
635
+ image_processor=image_processor,
636
+ dataset_map_fn=mdpv_points_map_fn,
637
+ template_map_fn=dict(
638
+ type=template_map_fn_factory, template=prompt_template),
639
+ max_length=max_length,
640
+ pad_image_to_square=True,
641
+ debug=False,
642
+ repeats=1,
643
+ )
644
+
645
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
646
+ type=MDPVPointDetailedCaptionDataset,
647
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
648
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
649
+ tokenizer=tokenizer,
650
+ image_processor=image_processor,
651
+ dataset_map_fn=mdpv_points_map_fn,
652
+ template_map_fn=dict(
653
+ type=template_map_fn_factory, template=prompt_template),
654
+ max_length=max_length,
655
+ pad_image_to_square=True,
656
+ debug=False,
657
+ repeats=1,
658
+ )
659
+
660
+ mdpv_detailed_description_vg_dataset = dict(
661
+ type=MDPVPointDetailedCaptionDataset,
662
+ data_path=mdpv_detailed_caption_vg_data_path,
663
+ image_folder=mdpv_detailed_caption_vg_image_path,
664
+ tokenizer=tokenizer,
665
+ image_processor=image_processor,
666
+ dataset_map_fn=mdpv_points_map_fn,
667
+ template_map_fn=dict(
668
+ type=template_map_fn_factory, template=prompt_template),
669
+ max_length=max_length,
670
+ pad_image_to_square=True,
671
+ debug=False,
672
+ repeats=1,
673
+ )
674
+
675
+ mdpv_brief_description_vg_dataset = dict(
676
+ type=MDPVPointBriefCaptionDataset,
677
+ data_path=mdpv_brief_caption_vg_data_path,
678
+ image_folder=mdpv_brief_caption_vg_image_path,
679
+ tokenizer=tokenizer,
680
+ image_processor=image_processor,
681
+ dataset_map_fn=mdpv_points_map_fn,
682
+ template_map_fn=dict(
683
+ type=template_map_fn_factory, template=prompt_template),
684
+ max_length=max_length,
685
+ pad_image_to_square=True,
686
+ debug=False,
687
+ repeats=1,
688
+ )
689
+
690
+ mdpv_brief_description_cocostuff10k_dataset = dict(
691
+ type=MDPVPointBriefCaptionDataset,
692
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
693
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
694
+ tokenizer=tokenizer,
695
+ image_processor=image_processor,
696
+ dataset_map_fn=mdpv_points_map_fn,
697
+ template_map_fn=dict(
698
+ type=template_map_fn_factory, template=prompt_template),
699
+ max_length=max_length,
700
+ pad_image_to_square=True,
701
+ debug=False,
702
+ repeats=1,
703
+ )
704
+
705
+ mdpv_brief_description_cocostuff164k_dataset = dict(
706
+ type=MDPVPointBriefCaptionDataset,
707
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
708
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
709
+ tokenizer=tokenizer,
710
+ image_processor=image_processor,
711
+ dataset_map_fn=mdpv_points_map_fn,
712
+ template_map_fn=dict(
713
+ type=template_map_fn_factory, template=prompt_template),
714
+ max_length=max_length,
715
+ pad_image_to_square=True,
716
+ debug=False,
717
+ repeats=1,
718
+ )
719
+
720
+ mdpv_brief_description_ade20k_dataset = dict(
721
+ type=MDPVPointBriefCaptionDataset,
722
+ data_path=mdpv_brief_caption_ade20k_data_path,
723
+ image_folder=mdpv_brief_caption_ade20k_image_path,
724
+ tokenizer=tokenizer,
725
+ image_processor=image_processor,
726
+ dataset_map_fn=mdpv_points_map_fn,
727
+ template_map_fn=dict(
728
+ type=template_map_fn_factory, template=prompt_template),
729
+ max_length=max_length,
730
+ pad_image_to_square=True,
731
+ debug=False,
732
+ repeats=1,
733
+ )
734
+
735
+ mdpv_brief_description_lvis_dataset = dict(
736
+ type=MDPVPointBriefCaptionDataset,
737
+ data_path=mdpv_brief_caption_lvis_data_path,
738
+ image_folder=mdpv_brief_caption_lvis_image_path,
739
+ tokenizer=tokenizer,
740
+ image_processor=image_processor,
741
+ dataset_map_fn=mdpv_points_map_fn,
742
+ template_map_fn=dict(
743
+ type=template_map_fn_factory, template=prompt_template),
744
+ max_length=max_length,
745
+ pad_image_to_square=True,
746
+ debug=False,
747
+ repeats=1,
748
+ )
749
+
750
+ mdpv_qa_vg_dataset = dict(
751
+ type=MDPVPointBriefCaptionDataset,
752
+ data_path=mdpv_qa_vg_data_path,
753
+ image_folder=mdpv_qa_vg_image_path,
754
+ tokenizer=tokenizer,
755
+ image_processor=image_processor,
756
+ dataset_map_fn=mdpv_points_map_fn,
757
+ template_map_fn=dict(
758
+ type=template_map_fn_factory, template=prompt_template),
759
+ max_length=max_length,
760
+ pad_image_to_square=True,
761
+ debug=False,
762
+ repeats=1,
763
+ )
764
+
765
+ mdpv_qa_ade20k_dataset = dict(
766
+ type=MDPVPointBriefCaptionDataset,
767
+ data_path=mdpv_qa_ade20k_data_path,
768
+ image_folder=mdpv_qa_ade20k_image_path,
769
+ tokenizer=tokenizer,
770
+ image_processor=image_processor,
771
+ dataset_map_fn=mdpv_points_map_fn,
772
+ template_map_fn=dict(
773
+ type=template_map_fn_factory, template=prompt_template),
774
+ max_length=max_length,
775
+ pad_image_to_square=True,
776
+ debug=False,
777
+ repeats=1,
778
+ )
779
+
780
+ mdpv_qa_lvis_dataset = dict(
781
+ type=MDPVPointBriefCaptionDataset,
782
+ data_path=mdpv_qa_lvis_data_path,
783
+ image_folder=mdpv_qa_lvis_image_path,
784
+ tokenizer=tokenizer,
785
+ image_processor=image_processor,
786
+ dataset_map_fn=mdpv_points_map_fn,
787
+ template_map_fn=dict(
788
+ type=template_map_fn_factory, template=prompt_template),
789
+ max_length=max_length,
790
+ pad_image_to_square=True,
791
+ debug=False,
792
+ repeats=1,
793
+ )
794
+
795
+ mdpv_qa_cocostuff10k_dataset = dict(
796
+ type=MDPVPointBriefCaptionDataset,
797
+ data_path=mdpv_qa_cocostuff10k_data_path,
798
+ image_folder=mdpv_qa_cocostuff10k_image_path,
799
+ tokenizer=tokenizer,
800
+ image_processor=image_processor,
801
+ dataset_map_fn=mdpv_points_map_fn,
802
+ template_map_fn=dict(
803
+ type=template_map_fn_factory, template=prompt_template),
804
+ max_length=max_length,
805
+ pad_image_to_square=True,
806
+ debug=False,
807
+ repeats=1,
808
+ )
809
+
810
+ mdpv_qa_cocostuff164k_dataset = dict(
811
+ type=MDPVPointBriefCaptionDataset,
812
+ data_path=mdpv_qa_cocostuff164k_data_path,
813
+ image_folder=mdpv_qa_cocostuff164k_image_path,
814
+ tokenizer=tokenizer,
815
+ image_processor=image_processor,
816
+ dataset_map_fn=mdpv_points_map_fn,
817
+ template_map_fn=dict(
818
+ type=template_map_fn_factory, template=prompt_template),
819
+ max_length=max_length,
820
+ pad_image_to_square=True,
821
+ debug=False,
822
+ repeats=1,
823
+ )
824
+
825
+ mdpv_multi_points_openpsg_dataset = dict(
826
+ type=MDPVPointBriefCaptionDataset,
827
+ data_path=mdpv_multi_points_openpsg_data_path,
828
+ image_folder=mdpv_multi_points_openpsg_image_path,
829
+ tokenizer=tokenizer,
830
+ image_processor=image_processor,
831
+ dataset_map_fn=mdpv_points_map_fn,
832
+ template_map_fn=dict(
833
+ type=template_map_fn_factory, template=prompt_template),
834
+ max_length=max_length,
835
+ pad_image_to_square=True,
836
+ debug=False,
837
+ repeats=1,
838
+ )
839
+
840
+ mdpv_multi_points_flicker30k_dataset = dict(
841
+ type=MDPVPointBriefCaptionDataset,
842
+ data_path=mdpv_multi_points_flicker30k_data_path,
843
+ image_folder=mdpv_multi_points_flicker30k_image_path,
844
+ tokenizer=tokenizer,
845
+ image_processor=image_processor,
846
+ dataset_map_fn=mdpv_points_map_fn,
847
+ template_map_fn=dict(
848
+ type=template_map_fn_factory, template=prompt_template),
849
+ max_length=max_length,
850
+ pad_image_to_square=True,
851
+ debug=False,
852
+ repeats=1,
853
+ )
854
+
855
+ train_dataset = dict(
856
+ type=CombineDataset,
857
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
858
+ glamm_grandf_dataset, glamm_psg_dataset,
859
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
860
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
861
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
862
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
863
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
864
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
865
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
866
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
867
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
868
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
869
+ mdpv_detailed_description_ade20k_dataset,
870
+ mdpv_detailed_description_cocostuff_10k_dataset,
871
+ mdpv_detailed_description_cocostuff_164k_dataset,
872
+ mdpv_detailed_description_vg_dataset,
873
+ mdpv_brief_description_lvis_dataset,
874
+ mdpv_brief_description_vg_dataset,
875
+ mdpv_brief_description_ade20k_dataset,
876
+ mdpv_brief_description_cocostuff10k_dataset,
877
+ mdpv_brief_description_cocostuff164k_dataset,
878
+ mdpv_qa_vg_dataset,
879
+ mdpv_qa_lvis_dataset,
880
+ mdpv_qa_ade20k_dataset,
881
+ mdpv_qa_cocostuff10k_dataset,
882
+ mdpv_qa_cocostuff164k_dataset,
883
+ mdpv_multi_points_flicker30k_dataset,
884
+ mdpv_multi_points_openpsg_dataset,],
885
+ )
886
+
887
+ train_dataloader = dict(
888
+ batch_size=batch_size,
889
+ num_workers=dataloader_num_workers,
890
+ dataset=train_dataset,
891
+ sampler=dict(
892
+ type=LengthGroupedSampler,
893
+ length_property='modality_length',
894
+ per_device_batch_size=batch_size * accumulative_counts),
895
+ collate_fn=dict(type=omg_llava_collate_fn))
896
+
897
+ #######################################################################
898
+ # PART 4 Scheduler & Optimizer #
899
+ #######################################################################
900
+ # optimizer
901
+ optim_wrapper = dict(
902
+ type=AmpOptimWrapper,
903
+ optimizer=dict(
904
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
905
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
906
+ accumulative_counts=accumulative_counts,
907
+ loss_scale='dynamic',
908
+ dtype='float16')
909
+
910
+ # learning policy
911
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
912
+ param_scheduler = [
913
+ dict(
914
+ type=LinearLR,
915
+ start_factor=1e-5,
916
+ by_epoch=True,
917
+ begin=0,
918
+ end=warmup_ratio * max_epochs,
919
+ convert_to_iter_based=True),
920
+ dict(
921
+ type=CosineAnnealingLR,
922
+ eta_min=0.0,
923
+ by_epoch=True,
924
+ begin=warmup_ratio * max_epochs,
925
+ end=max_epochs,
926
+ convert_to_iter_based=True)
927
+ ]
928
+
929
+ # train, val, test setting
930
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
931
+
932
+ #######################################################################
933
+ # PART 5 Runtime #
934
+ #######################################################################
935
+ # Log the dialogue periodically during the training process, optional
936
+ custom_hooks = [
937
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
938
+ dict(
939
+ type=EvaluateChatHook_withSpecialTokens,
940
+ tokenizer=tokenizer,
941
+ image_processor=image_processor,
942
+ every_n_iters=evaluation_freq,
943
+ evaluation_inputs=evaluation_inputs,
944
+ evaluation_images=evaluation_images,
945
+ system=SYSTEM,
946
+ prompt_template=prompt_template)
947
+ ]
948
+
949
+ # configure default hooks
950
+ default_hooks = dict(
951
+ # record the time of every iteration.
952
+ timer=dict(type=IterTimerHook),
953
+ # print log every 10 iterations.
954
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
955
+ # enable the parameter scheduler.
956
+ param_scheduler=dict(type=ParamSchedulerHook),
957
+ # save checkpoint per `save_steps`.
958
+ checkpoint=dict(
959
+ type=CheckpointHook,
960
+ by_epoch=False,
961
+ interval=save_steps,
962
+ max_keep_ckpts=save_total_limit),
963
+ # set sampler seed in distributed environment.
964
+ sampler_seed=dict(type=DistSamplerSeedHook),
965
+ )
966
+
967
+ # configure environment
968
+ env_cfg = dict(
969
+ # whether to enable cudnn benchmark
970
+ cudnn_benchmark=False,
971
+ # set multi process parameters
972
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
973
+ # set distributed parameters
974
+ dist_cfg=dict(backend='nccl'),
975
+ )
976
+
977
+ # set visualizer
978
+ visualizer = None
979
+
980
+ # set log level
981
+ log_level = 'INFO'
982
+
983
+ # load from which checkpoint
984
+ load_from = None
985
+
986
+ # whether to resume training from the loaded checkpoint
987
+ resume = False
988
+
989
+ # Defaults to use random seed and disable `deterministic`
990
+ randomness = dict(seed=None, deterministic=False)
991
+
992
+ # set log processor
993
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py ADDED
@@ -0,0 +1,1007 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset
24
+ from xtuner.dataset.samplers import LengthGroupedSampler
25
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
26
+ from xtuner.engine.runner import TrainLoop
27
+ from omg_llava.model import OMG_LLaVA
28
+ from xtuner.utils import PROMPT_TEMPLATE
29
+ from omg_llava.model import OpenCLIPBackbone_omgseg
30
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
31
+
32
+ from torch.nn import GroupNorm, ReLU
33
+
34
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
35
+ DiceLoss, MaskFormerFusionHead, FocalLoss
36
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
37
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
38
+
39
+ #######################################################################
40
+ # PART 1 Settings #
41
+ #######################################################################
42
+ # Model
43
+ llm_name_or_path = './pretrained/omg_llava/internlm2-7b' # Please change to your own path
44
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
45
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
46
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
47
+
48
+ # Data
49
+ data_root = './data/llava_data/'
50
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
51
+ image_folder = data_root + 'llava_images'
52
+
53
+ glamm_data_root = './data/glamm_data/'
54
+
55
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
56
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
57
+
58
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
59
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
60
+
61
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
62
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
63
+
64
+ psg_image_path = glamm_data_root + 'images/coco2017/'
65
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
66
+
67
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
68
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
69
+
70
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
71
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
72
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
73
+
74
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
75
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
76
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
77
+
78
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
79
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
80
+
81
+ paco_image_path = './data/glamm_data/images/coco2017/'
82
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
83
+
84
+ referring_refcoco_image_path = refcocog_image_path
85
+ referring_refcoco_data_path = "./data/ref_seg/"
86
+
87
+ referring_refcoco_plus_image_path = refcocog_image_path
88
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
89
+
90
+ referring_refcocog_image_path = refcocog_image_path
91
+ referring_refcocog_data_path = "./data/ref_seg/"
92
+
93
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
94
+ referring_refclef_data_path = "./data/ref_seg/"
95
+
96
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
97
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
98
+
99
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
100
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
101
+
102
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
103
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
104
+
105
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
106
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
107
+
108
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
109
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
110
+
111
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
112
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
113
+
114
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
115
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
116
+
117
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
118
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
119
+
120
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
121
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
122
+
123
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
124
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
125
+
126
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
127
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
128
+
129
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
130
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
131
+
132
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
133
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
134
+
135
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
136
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
137
+
138
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
139
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
140
+
141
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
142
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
143
+
144
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
145
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
146
+
147
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
148
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
149
+
150
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
151
+ max_length = int(2048 - (1024 / 64)**2 - 100)
152
+
153
+ # Scheduler & Optimizer
154
+ batch_size = 8 # per_device
155
+ accumulative_counts = 2
156
+ dataloader_num_workers = 4
157
+ max_epochs = 1
158
+ optim_type = AdamW
159
+ lr = 2e-4
160
+ betas = (0.9, 0.999)
161
+ weight_decay = 0
162
+ max_norm = 1 # grad clip
163
+ warmup_ratio = 0.03
164
+
165
+
166
+ # Save
167
+ save_steps = 2000
168
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
169
+
170
+ # Evaluate the generation performance during the training
171
+ evaluation_freq = 2000
172
+ SYSTEM = ''
173
+ evaluation_images = './work_dirs/test.jpg'
174
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
175
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
176
+
177
+ #######################################################################
178
+ # PART 2 Model & Tokenizer & Image Processor #
179
+ #######################################################################
180
+ tokenizer = dict(
181
+ type=AutoTokenizer.from_pretrained,
182
+ pretrained_model_name_or_path=llm_name_or_path,
183
+ trust_remote_code=True,
184
+ padding_side='right')
185
+
186
+ image_processor = dict(
187
+ type=CLIPImageProcessor,
188
+ do_resize=True,
189
+ size=1024,
190
+ resample=3,
191
+ do_center_crop=True,
192
+ crop_size=1024,
193
+ do_rescale=True,
194
+ do_normalize=True,
195
+ image_mean=[0.4814, 0.4578, 0.4082],
196
+ image_std=[0.2686, 0.2613, 0.2757],
197
+ do_convert_rgb=True
198
+ )
199
+
200
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
201
+ num_things_classes = 80
202
+ num_stuff_classes = 53
203
+ num_classes = num_things_classes + num_stuff_classes
204
+
205
+ omgseg_model = dict(
206
+ type=OMGSegVisualEncoder,
207
+ data_preprocessor=None,
208
+ pixel_shuffle_down_ratio=2,
209
+ backbone=dict(
210
+ type=OpenCLIPBackbone_omgseg,
211
+ model_name='convnext_large_d_320',
212
+ fix=True,
213
+ init_cfg=dict(
214
+ type='clip_pretrain',
215
+ checkpoint='laion2b_s29b_b131k_ft_soup'
216
+ )
217
+ ),
218
+ panoptic_head=dict(
219
+ type=Mask2FormerVideoSemSamHead,
220
+ sphere_cls=True,
221
+ ov_path=omg_ov_class_embed_path,
222
+ enable_box_query=False,
223
+ ov_classifier_name=class_embed,
224
+ logit=None,
225
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
226
+ strides=[4, 8, 16, 32],
227
+ feat_channels=256,
228
+ out_channels=256,
229
+ num_things_classes=num_things_classes,
230
+ num_stuff_classes=num_stuff_classes,
231
+ num_queries=300,
232
+ num_transformer_feat_level=3,
233
+ pixel_decoder=dict(
234
+ type=MSDeformAttnPixelDecoder,
235
+ num_outs=3,
236
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
237
+ act_cfg=dict(type=ReLU),
238
+ encoder=dict( # DeformableDetrTransformerEncoder
239
+ num_layers=6,
240
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
241
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
242
+ embed_dims=256,
243
+ num_heads=8,
244
+ num_levels=3,
245
+ num_points=4,
246
+ dropout=0.0,
247
+ batch_first=True),
248
+ ffn_cfg=dict(
249
+ embed_dims=256,
250
+ feedforward_channels=1024,
251
+ num_fcs=2,
252
+ ffn_drop=0.0,
253
+ act_cfg=dict(type=ReLU, inplace=True)))),
254
+ positional_encoding=dict(num_feats=128, normalize=True)),
255
+ enforce_decoder_input_project=False,
256
+ positional_encoding=dict(num_feats=128, normalize=True),
257
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
258
+ return_intermediate=True,
259
+ num_layers=9,
260
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
261
+ self_attn_cfg=dict( # MultiheadAttention
262
+ embed_dims=256,
263
+ num_heads=8,
264
+ dropout=0.0,
265
+ batch_first=True),
266
+ cross_attn_cfg=dict( # MultiheadAttention
267
+ embed_dims=256,
268
+ num_heads=8,
269
+ dropout=0.0,
270
+ batch_first=True),
271
+ ffn_cfg=dict(
272
+ embed_dims=256,
273
+ feedforward_channels=2048,
274
+ num_fcs=2,
275
+ ffn_drop=0.0,
276
+ act_cfg=dict(type='ReLU', inplace=True))),
277
+ init_cfg=None),
278
+ loss_cls=dict(
279
+ type=CrossEntropyLoss,
280
+ use_sigmoid=False,
281
+ loss_weight=2.0,
282
+ reduction='mean',
283
+ class_weight=[1.0] * 240 + [0.1]),
284
+ loss_mask=dict(
285
+ type=CrossEntropyLoss,
286
+ use_sigmoid=True,
287
+ reduction='mean',
288
+ loss_weight=5.0),
289
+ loss_dice=dict(
290
+ type=DiceLoss,
291
+ use_sigmoid=True,
292
+ activate=True,
293
+ reduction='mean',
294
+ naive_dice=True,
295
+ eps=1.0,
296
+ loss_weight=5.0),
297
+ loss_iou=dict(
298
+ type=FocalLoss,
299
+ use_sigmoid=True,
300
+ loss_weight=2.0,
301
+ reduction='mean')
302
+ ),
303
+ panoptic_fusion_head=dict(
304
+ type=MaskFormerFusionHead,
305
+ num_things_classes=num_things_classes,
306
+ num_stuff_classes=num_stuff_classes,
307
+ loss_panoptic=None,
308
+ init_cfg=None),
309
+ train_cfg=dict(
310
+ num_points=12544,
311
+ oversample_ratio=3.0,
312
+ importance_sample_ratio=0.75,
313
+ assigner=dict(
314
+ type=HungarianAssigner,
315
+ match_costs=[
316
+ # dict(type=FlexibleClassificationCost, weight=2.0),
317
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
318
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
319
+ ]),
320
+ sampler=dict(type=MaskPseudoSampler)),
321
+ test_cfg=dict(
322
+ panoptic_on=True,
323
+ # For now, the dataset does not support
324
+ # evaluating semantic segmentation metric.
325
+ semantic_on=False,
326
+ instance_on=True,
327
+ # max_per_image is for instance segmentation.
328
+ max_per_image=100,
329
+ iou_thr=0.8,
330
+ # In Mask2Former's panoptic postprocessing,
331
+ # it will filter mask area where score is less than 0.5 .
332
+ filter_low_score=True),
333
+ init_cfg=dict(
334
+ type='Pretrained',
335
+ checkpoint=omg_head_pretrain_pth_path,
336
+ )
337
+ )
338
+
339
+ model = dict(
340
+ type=OMG_LLaVA,
341
+ freeze_llm=True,
342
+ freeze_visual_encoder=True,
343
+ require_omg_decoder=False,
344
+ pretrained_pth=pretrained_pth,
345
+ text2vision_projector=True,
346
+ pixel_shuffle_ratio=2,
347
+ llm=dict(
348
+ type=AutoModelForCausalLM.from_pretrained,
349
+ pretrained_model_name_or_path=llm_name_or_path,
350
+ trust_remote_code=True,
351
+ torch_dtype=torch.float16,
352
+ quantization_config=dict(
353
+ type=BitsAndBytesConfig,
354
+ load_in_4bit=True,
355
+ load_in_8bit=False,
356
+ llm_int8_threshold=6.0,
357
+ llm_int8_has_fp16_weight=False,
358
+ bnb_4bit_compute_dtype=torch.float16,
359
+ bnb_4bit_use_double_quant=True,
360
+ bnb_4bit_quant_type='nf4')),
361
+ llm_lora=dict(
362
+ type=LoraConfig,
363
+ r=512,
364
+ lora_alpha=256,
365
+ lora_dropout=0.05,
366
+ bias='none',
367
+ task_type='CAUSAL_LM'),
368
+ visual_encoder=omgseg_model,
369
+ tokenizer=tokenizer,
370
+ )
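As a brief aside (not part of the diff), the block above is the usual QLoRA-style setup: the base LLM is loaded 4-bit NF4-quantized and LoRA adapters are trained on top while the visual encoder stays frozen. Under PEFT's standard formulation the adapter update is scaled by lora_alpha / r:

# Hedged sketch of the LoRA scaling implied by the settings above (standard PEFT behaviour).
r, lora_alpha = 512, 256
lora_scaling = lora_alpha / r   # 0.5: each adapter's output is multiplied by this factor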
371
+
372
+ #######################################################################
373
+ # PART 3 Dataset & Dataloader #
374
+ #######################################################################
375
+ debug=False
376
+ llava_dataset = dict(
377
+ type=LLaVADataset,
378
+ data_path=data_path,
379
+ image_folder=image_folder,
380
+ tokenizer=tokenizer,
381
+ image_processor=image_processor,
382
+ dataset_map_fn=llava_map_fn,
383
+ template_map_fn=dict(
384
+ type=template_map_fn_factory, template=prompt_template),
385
+ max_length=max_length,
386
+ pad_image_to_square=True)
387
+
388
+ glamm_refcocog_dataset = dict(
389
+ type=RefCOCOgGCGDataset,
390
+ data_path=refcocog_ann_file,
391
+ image_folder=refcocog_image_path,
392
+ tokenizer=tokenizer,
393
+ image_processor=image_processor,
394
+ dataset_map_fn=glamm_refcocog_map_fn,
395
+ template_map_fn=dict(
396
+ type=template_map_fn_factory, template=prompt_template),
397
+ max_length=max_length,
398
+ pad_image_to_square=True,
399
+ debug=False,
400
+ repeats=1,
401
+ )
402
+
403
+ glamm_grandf_dataset = dict(
404
+ type=GranDfGCGDataset,
405
+ data_path=grandf_ann_file,
406
+ image_folder=grandf_image_path,
407
+ tokenizer=tokenizer,
408
+ image_processor=image_processor,
409
+ dataset_map_fn=glamm_granf_map_fn,
410
+ template_map_fn=dict(
411
+ type=template_map_fn_factory, template=prompt_template),
412
+ max_length=max_length,
413
+ pad_image_to_square=True,
414
+ debug=debug,
415
+ repeats=10,
416
+ num_proc=32
417
+ )
418
+
419
+ glamm_psg_dataset = dict(
420
+ type=OpenPsgGCGDataset,
421
+ data_path=psg_ann_file,
422
+ image_folder=psg_image_path,
423
+ tokenizer=tokenizer,
424
+ image_processor=image_processor,
425
+ dataset_map_fn=glamm_openpsg_map_fn,
426
+ template_map_fn=dict(
427
+ type=template_map_fn_factory, template=prompt_template),
428
+ max_length=max_length,
429
+ pad_image_to_square=True,
430
+ debug=debug,
431
+ repeats=1,
432
+ num_proc=32
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ num_proc=32
449
+ )
450
+
451
+ semantic_seg_ade20k_dataset = dict(
452
+ type=ADE20kSemanticSegDataset,
453
+ data_path=ade20k_class_file,
454
+ image_folder=ade20k_image_path,
455
+ tokenizer=tokenizer,
456
+ image_processor=image_processor,
457
+ dataset_map_fn=semantic_seg_map_fn,
458
+ template_map_fn=dict(
459
+ type=template_map_fn_factory, template=prompt_template),
460
+ max_length=max_length,
461
+ pad_image_to_square=True,
462
+ debug=False,
463
+ repeats=1,
464
+ num_proc=32
465
+ )
466
+
467
+ semantic_seg_cocostuff_dataset = dict(
468
+ type=COCOStuffSemanticSegDataset,
469
+ data_path=cocostuff_class_file,
470
+ image_folder=cocostuff_image_path,
471
+ label_path=cocostuff_label_path,
472
+ tokenizer=tokenizer,
473
+ image_processor=image_processor,
474
+ dataset_map_fn=semantic_seg_map_fn,
475
+ template_map_fn=dict(
476
+ type=template_map_fn_factory, template=prompt_template),
477
+ max_length=max_length,
478
+ pad_image_to_square=True,
479
+ debug=False,
480
+ repeats=1,
481
+ num_proc=32
482
+ )
483
+
484
+ semantic_seg_mapillary_dataset = dict(
485
+ type=MapillarySemanticSegDataset,
486
+ data_path=mapillary_class_file,
487
+ image_folder=mapillary_image_path,
488
+ label_path=mapillary_label_path,
489
+ tokenizer=tokenizer,
490
+ image_processor=image_processor,
491
+ dataset_map_fn=semantic_seg_map_fn,
492
+ template_map_fn=dict(
493
+ type=template_map_fn_factory, template=prompt_template),
494
+ max_length=max_length,
495
+ pad_image_to_square=True,
496
+ debug=False,
497
+ repeats=1,
498
+ num_proc=32
499
+ )
500
+
501
+ semantic_seg_pascal_part_dataset = dict(
502
+ type=PascalPartSemanticSegDataset,
503
+ data_path=pascal_file,
504
+ image_folder=pascal_part_image_path,
505
+ tokenizer=tokenizer,
506
+ image_processor=image_processor,
507
+ dataset_map_fn=pascal_part_map_fn,
508
+ template_map_fn=dict(
509
+ type=template_map_fn_factory, template=prompt_template),
510
+ max_length=max_length,
511
+ pad_image_to_square=True,
512
+ debug=False,
513
+ repeats=1,
514
+ num_proc=32
515
+ )
516
+
517
+ semantic_seg_paco_dataset = dict(
518
+ type=PacoSemanticSegDataset,
519
+ data_path=paco_file,
520
+ image_folder=paco_image_path,
521
+ tokenizer=tokenizer,
522
+ image_processor=image_processor,
523
+ dataset_map_fn=pascal_part_map_fn,
524
+ template_map_fn=dict(
525
+ type=template_map_fn_factory, template=prompt_template),
526
+ max_length=max_length,
527
+ pad_image_to_square=True,
528
+ debug=False,
529
+ repeats=1,
530
+ num_proc=32
531
+ )
532
+
533
+ referring_seg_refcoco_dataset = dict(
534
+ type=RefcocoReferringSegDataset,
535
+ data_path=referring_refcoco_data_path,
536
+ image_folder=referring_refcoco_image_path,
537
+ tokenizer=tokenizer,
538
+ image_processor=image_processor,
539
+ dataset_map_fn=referring_seg_map_fn,
540
+ template_map_fn=dict(
541
+ type=template_map_fn_factory, template=prompt_template),
542
+ max_length=max_length,
543
+ pad_image_to_square=True,
544
+ debug=False,
545
+ repeats=1,
546
+ num_proc=32
547
+ )
548
+
549
+ referring_seg_refcoco_plus_dataset = dict(
550
+ type=Refcoco_plus_ReferringSegDataset,
551
+ data_path=referring_refcoco_plus_data_path,
552
+ image_folder=referring_refcoco_plus_image_path,
553
+ tokenizer=tokenizer,
554
+ image_processor=image_processor,
555
+ dataset_map_fn=referring_seg_map_fn,
556
+ template_map_fn=dict(
557
+ type=template_map_fn_factory, template=prompt_template),
558
+ max_length=max_length,
559
+ pad_image_to_square=True,
560
+ debug=False,
561
+ repeats=1,
562
+ num_proc=32
563
+ )
564
+
565
+ referring_seg_refcocog_dataset = dict(
566
+ type=Refcocog_ReferringSegDataset,
567
+ data_path=referring_refcocog_data_path,
568
+ image_folder=referring_refcocog_image_path,
569
+ tokenizer=tokenizer,
570
+ image_processor=image_processor,
571
+ dataset_map_fn=referring_seg_map_fn,
572
+ template_map_fn=dict(
573
+ type=template_map_fn_factory, template=prompt_template),
574
+ max_length=max_length,
575
+ pad_image_to_square=True,
576
+ debug=False,
577
+ repeats=1,
578
+ num_proc=32
579
+ )
580
+
581
+ referring_seg_refclef_dataset = dict(
582
+ type=Refclef_ReferringSegDataset,
583
+ data_path=referring_refclef_data_path,
584
+ image_folder=referring_refclef_image_path,
585
+ tokenizer=tokenizer,
586
+ image_processor=image_processor,
587
+ dataset_map_fn=referring_seg_map_fn,
588
+ template_map_fn=dict(
589
+ type=template_map_fn_factory, template=prompt_template),
590
+ max_length=max_length,
591
+ pad_image_to_square=True,
592
+ debug=False,
593
+ repeats=1,
594
+ num_proc=32
595
+ )
596
+
597
+ region_cap_osprey_dataset = dict(
598
+ type=OspreyRegionCaptionDataset,
599
+ data_path=region_cap_osprey_data_path,
600
+ image_folder=region_cap_osprey_image_path,
601
+ tokenizer=tokenizer,
602
+ image_processor=image_processor,
603
+ dataset_map_fn=osprey_region_caption_map_fn,
604
+ template_map_fn=dict(
605
+ type=template_map_fn_factory, template=prompt_template),
606
+ max_length=max_length,
607
+ pad_image_to_square=True,
608
+ debug=False,
609
+ repeats=1,
610
+ num_proc=32
611
+ )
612
+
613
+ region_conversation_osprey_dataset = dict(
614
+ type=OspreyRegionConversationDataset,
615
+ data_path=region_conversation_osprey_data_path,
616
+ image_folder=region_conversation_osprey_image_path,
617
+ tokenizer=tokenizer,
618
+ image_processor=image_processor,
619
+ dataset_map_fn=osprey_region_conversation_map_fn,
620
+ template_map_fn=dict(
621
+ type=template_map_fn_factory, template=prompt_template),
622
+ max_length=max_length,
623
+ pad_image_to_square=True,
624
+ debug=False,
625
+ repeats=1,
626
+ num_proc=32
627
+ )
628
+
629
+ mdpv_detailed_description_ade20k_dataset = dict(
630
+ type=MDPVPointDetailedCaptionDataset,
631
+ data_path=mdpv_detailed_caption_ade20k_data_path,
632
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
633
+ tokenizer=tokenizer,
634
+ image_processor=image_processor,
635
+ dataset_map_fn=mdpv_points_map_fn,
636
+ template_map_fn=dict(
637
+ type=template_map_fn_factory, template=prompt_template),
638
+ max_length=max_length,
639
+ pad_image_to_square=True,
640
+ debug=False,
641
+ repeats=1,
642
+ )
643
+
644
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
645
+ type=MDPVPointDetailedCaptionDataset,
646
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
647
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
648
+ tokenizer=tokenizer,
649
+ image_processor=image_processor,
650
+ dataset_map_fn=mdpv_points_map_fn,
651
+ template_map_fn=dict(
652
+ type=template_map_fn_factory, template=prompt_template),
653
+ max_length=max_length,
654
+ pad_image_to_square=True,
655
+ debug=False,
656
+ repeats=1,
657
+ )
658
+
659
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
660
+ type=MDPVPointDetailedCaptionDataset,
661
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
662
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
663
+ tokenizer=tokenizer,
664
+ image_processor=image_processor,
665
+ dataset_map_fn=mdpv_points_map_fn,
666
+ template_map_fn=dict(
667
+ type=template_map_fn_factory, template=prompt_template),
668
+ max_length=max_length,
669
+ pad_image_to_square=True,
670
+ debug=False,
671
+ repeats=1,
672
+ )
673
+
674
+ mdpv_detailed_description_vg_dataset = dict(
675
+ type=MDPVPointDetailedCaptionDataset,
676
+ data_path=mdpv_detailed_caption_vg_data_path,
677
+ image_folder=mdpv_detailed_caption_vg_image_path,
678
+ tokenizer=tokenizer,
679
+ image_processor=image_processor,
680
+ dataset_map_fn=mdpv_points_map_fn,
681
+ template_map_fn=dict(
682
+ type=template_map_fn_factory, template=prompt_template),
683
+ max_length=max_length,
684
+ pad_image_to_square=True,
685
+ debug=False,
686
+ repeats=1,
687
+ )
688
+
689
+ mdpv_brief_description_vg_dataset = dict(
690
+ type=MDPVPointBriefCaptionDataset,
691
+ data_path=mdpv_brief_caption_vg_data_path,
692
+ image_folder=mdpv_brief_caption_vg_image_path,
693
+ tokenizer=tokenizer,
694
+ image_processor=image_processor,
695
+ dataset_map_fn=mdpv_points_map_fn,
696
+ template_map_fn=dict(
697
+ type=template_map_fn_factory, template=prompt_template),
698
+ max_length=max_length,
699
+ pad_image_to_square=True,
700
+ debug=False,
701
+ repeats=1,
702
+ )
703
+
704
+ mdpv_brief_description_cocostuff10k_dataset = dict(
705
+ type=MDPVPointBriefCaptionDataset,
706
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
707
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
708
+ tokenizer=tokenizer,
709
+ image_processor=image_processor,
710
+ dataset_map_fn=mdpv_points_map_fn,
711
+ template_map_fn=dict(
712
+ type=template_map_fn_factory, template=prompt_template),
713
+ max_length=max_length,
714
+ pad_image_to_square=True,
715
+ debug=False,
716
+ repeats=1,
717
+ )
718
+
719
+ mdpv_brief_description_cocostuff164k_dataset = dict(
720
+ type=MDPVPointBriefCaptionDataset,
721
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
722
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
723
+ tokenizer=tokenizer,
724
+ image_processor=image_processor,
725
+ dataset_map_fn=mdpv_points_map_fn,
726
+ template_map_fn=dict(
727
+ type=template_map_fn_factory, template=prompt_template),
728
+ max_length=max_length,
729
+ pad_image_to_square=True,
730
+ debug=False,
731
+ repeats=1,
732
+ )
733
+
734
+ mdpv_brief_description_ade20k_dataset = dict(
735
+ type=MDPVPointBriefCaptionDataset,
736
+ data_path=mdpv_brief_caption_ade20k_data_path,
737
+ image_folder=mdpv_brief_caption_ade20k_image_path,
738
+ tokenizer=tokenizer,
739
+ image_processor=image_processor,
740
+ dataset_map_fn=mdpv_points_map_fn,
741
+ template_map_fn=dict(
742
+ type=template_map_fn_factory, template=prompt_template),
743
+ max_length=max_length,
744
+ pad_image_to_square=True,
745
+ debug=False,
746
+ repeats=1,
747
+ )
748
+
749
+ mdpv_brief_description_lvis_dataset = dict(
750
+ type=MDPVPointBriefCaptionDataset,
751
+ data_path=mdpv_brief_caption_lvis_data_path,
752
+ image_folder=mdpv_brief_caption_lvis_image_path,
753
+ tokenizer=tokenizer,
754
+ image_processor=image_processor,
755
+ dataset_map_fn=mdpv_points_map_fn,
756
+ template_map_fn=dict(
757
+ type=template_map_fn_factory, template=prompt_template),
758
+ max_length=max_length,
759
+ pad_image_to_square=True,
760
+ debug=False,
761
+ repeats=1,
762
+ )
763
+
764
+ mdpv_qa_vg_dataset = dict(
765
+ type=MDPVPointBriefCaptionDataset,
766
+ data_path=mdpv_qa_vg_data_path,
767
+ image_folder=mdpv_qa_vg_image_path,
768
+ tokenizer=tokenizer,
769
+ image_processor=image_processor,
770
+ dataset_map_fn=mdpv_points_map_fn,
771
+ template_map_fn=dict(
772
+ type=template_map_fn_factory, template=prompt_template),
773
+ max_length=max_length,
774
+ pad_image_to_square=True,
775
+ debug=False,
776
+ repeats=1,
777
+ )
778
+
779
+ mdpv_qa_ade20k_dataset = dict(
780
+ type=MDPVPointBriefCaptionDataset,
781
+ data_path=mdpv_qa_ade20k_data_path,
782
+ image_folder=mdpv_qa_ade20k_image_path,
783
+ tokenizer=tokenizer,
784
+ image_processor=image_processor,
785
+ dataset_map_fn=mdpv_points_map_fn,
786
+ template_map_fn=dict(
787
+ type=template_map_fn_factory, template=prompt_template),
788
+ max_length=max_length,
789
+ pad_image_to_square=True,
790
+ debug=False,
791
+ repeats=1,
792
+ )
793
+
794
+ mdpv_qa_lvis_dataset = dict(
795
+ type=MDPVPointBriefCaptionDataset,
796
+ data_path=mdpv_qa_lvis_data_path,
797
+ image_folder=mdpv_qa_lvis_image_path,
798
+ tokenizer=tokenizer,
799
+ image_processor=image_processor,
800
+ dataset_map_fn=mdpv_points_map_fn,
801
+ template_map_fn=dict(
802
+ type=template_map_fn_factory, template=prompt_template),
803
+ max_length=max_length,
804
+ pad_image_to_square=True,
805
+ debug=False,
806
+ repeats=1,
807
+ )
808
+
809
+ mdpv_qa_cocostuff10k_dataset = dict(
810
+ type=MDPVPointBriefCaptionDataset,
811
+ data_path=mdpv_qa_cocostuff10k_data_path,
812
+ image_folder=mdpv_qa_cocostuff10k_image_path,
813
+ tokenizer=tokenizer,
814
+ image_processor=image_processor,
815
+ dataset_map_fn=mdpv_points_map_fn,
816
+ template_map_fn=dict(
817
+ type=template_map_fn_factory, template=prompt_template),
818
+ max_length=max_length,
819
+ pad_image_to_square=True,
820
+ debug=False,
821
+ repeats=1,
822
+ )
823
+
824
+ mdpv_qa_cocostuff164k_dataset = dict(
825
+ type=MDPVPointBriefCaptionDataset,
826
+ data_path=mdpv_qa_cocostuff164k_data_path,
827
+ image_folder=mdpv_qa_cocostuff164k_image_path,
828
+ tokenizer=tokenizer,
829
+ image_processor=image_processor,
830
+ dataset_map_fn=mdpv_points_map_fn,
831
+ template_map_fn=dict(
832
+ type=template_map_fn_factory, template=prompt_template),
833
+ max_length=max_length,
834
+ pad_image_to_square=True,
835
+ debug=False,
836
+ repeats=1,
837
+ )
838
+
839
+ mdpv_multi_points_openpsg_dataset = dict(
840
+ type=MDPVPointBriefCaptionDataset,
841
+ data_path=mdpv_multi_points_openpsg_data_path,
842
+ image_folder=mdpv_multi_points_openpsg_image_path,
843
+ tokenizer=tokenizer,
844
+ image_processor=image_processor,
845
+ dataset_map_fn=mdpv_points_map_fn,
846
+ template_map_fn=dict(
847
+ type=template_map_fn_factory, template=prompt_template),
848
+ max_length=max_length,
849
+ pad_image_to_square=True,
850
+ debug=False,
851
+ repeats=1,
852
+ )
853
+
854
+ mdpv_multi_points_flicker30k_dataset = dict(
855
+ type=MDPVPointBriefCaptionDataset,
856
+ data_path=mdpv_multi_points_flicker30k_data_path,
857
+ image_folder=mdpv_multi_points_flicker30k_image_path,
858
+ tokenizer=tokenizer,
859
+ image_processor=image_processor,
860
+ dataset_map_fn=mdpv_points_map_fn,
861
+ template_map_fn=dict(
862
+ type=template_map_fn_factory, template=prompt_template),
863
+ max_length=max_length,
864
+ pad_image_to_square=True,
865
+ debug=False,
866
+ repeats=1,
867
+ )
868
+
869
+ train_dataset = dict(
870
+ type=CombineDataset,
871
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
872
+ glamm_grandf_dataset, glamm_psg_dataset,
873
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
874
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
875
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
876
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
877
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
878
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
879
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
880
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
881
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
882
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
883
+ mdpv_detailed_description_ade20k_dataset,
884
+ mdpv_detailed_description_cocostuff_10k_dataset,
885
+ mdpv_detailed_description_cocostuff_164k_dataset,
886
+ mdpv_detailed_description_vg_dataset,
887
+ mdpv_brief_description_lvis_dataset,
888
+ mdpv_brief_description_vg_dataset,
889
+ mdpv_brief_description_ade20k_dataset,
890
+ mdpv_brief_description_cocostuff10k_dataset,
891
+ mdpv_brief_description_cocostuff164k_dataset,
892
+ mdpv_qa_vg_dataset,
893
+ mdpv_qa_lvis_dataset,
894
+ mdpv_qa_ade20k_dataset,
895
+ mdpv_qa_cocostuff10k_dataset,
896
+ mdpv_qa_cocostuff164k_dataset,
897
+ mdpv_multi_points_flicker30k_dataset,
898
+ mdpv_multi_points_openpsg_dataset,],
899
+ )
900
+
901
+ train_dataloader = dict(
902
+ batch_size=batch_size,
903
+ num_workers=dataloader_num_workers,
904
+ dataset=train_dataset,
905
+ sampler=dict(
906
+ type=LengthGroupedSampler,
907
+ length_property='modality_length',
908
+ per_device_batch_size=batch_size * accumulative_counts),
909
+ collate_fn=dict(type=omg_llava_collate_fn))
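A quick arithmetic note (an editorial aside, not part of the config): the sampler groups batch_size * accumulative_counts samples per device, and on the 8-GPU setup the file name implies, each optimizer step sees an effective global batch of 128.

# Hedged sketch of the effective batch size; num_gpus is inferred from the "_8gpus" file name.
batch_size = 8               # per-device, from PART 1
accumulative_counts = 2      # gradient accumulation steps, from PART 1
num_gpus = 8                 # assumption based on the config name
per_device_grouped = batch_size * accumulative_counts    # 16, the LengthGroupedSampler group size
effective_batch = per_device_grouped * num_gpus           # 128 samples per optimizer step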
910
+
911
+ #######################################################################
912
+ # PART 4 Scheduler & Optimizer #
913
+ #######################################################################
914
+ # optimizer
915
+ optim_wrapper = dict(
916
+ type=AmpOptimWrapper,
917
+ optimizer=dict(
918
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
919
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
920
+ accumulative_counts=accumulative_counts,
921
+ loss_scale='dynamic',
922
+ dtype='float16')
923
+
924
+ # learning policy
925
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
926
+ param_scheduler = [
927
+ dict(
928
+ type=LinearLR,
929
+ start_factor=1e-5,
930
+ by_epoch=True,
931
+ begin=0,
932
+ end=warmup_ratio * max_epochs,
933
+ convert_to_iter_based=True),
934
+ dict(
935
+ type=CosineAnnealingLR,
936
+ eta_min=0.0,
937
+ by_epoch=True,
938
+ begin=warmup_ratio * max_epochs,
939
+ end=max_epochs,
940
+ convert_to_iter_based=True)
941
+ ]
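For orientation (not part of the config): with convert_to_iter_based=True, the epoch-fraction boundaries above are converted to iteration counts at runtime, so linear warmup covers roughly the first 3% of the single epoch before cosine decay runs to the end. A sketch with a placeholder iteration count:

# Hedged sketch; iters_per_epoch is a hypothetical placeholder -- the real value depends on the combined dataset size.
warmup_ratio, max_epochs = 0.03, 1
iters_per_epoch = 10_000                                           # made-up example value
warmup_iters = int(warmup_ratio * max_epochs * iters_per_epoch)    # 300 iterations of LinearLR
cosine_iters = max_epochs * iters_per_epoch - warmup_iters         # remaining iterations of CosineAnnealingLR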
942
+
943
+ # train, val, test setting
944
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
945
+
946
+ #######################################################################
947
+ # PART 5 Runtime #
948
+ #######################################################################
949
+ # Log the dialogue periodically during the training process, optional
950
+ custom_hooks = [
951
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
952
+ dict(
953
+ type=EvaluateChatHook_withSpecialTokens,
954
+ tokenizer=tokenizer,
955
+ image_processor=image_processor,
956
+ every_n_iters=evaluation_freq,
957
+ evaluation_inputs=evaluation_inputs,
958
+ evaluation_images=evaluation_images,
959
+ system=SYSTEM,
960
+ prompt_template=prompt_template)
961
+ ]
962
+
963
+ # configure default hooks
964
+ default_hooks = dict(
965
+ # record the time of every iteration.
966
+ timer=dict(type=IterTimerHook),
967
+ # print log every 10 iterations.
968
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
969
+ # enable the parameter scheduler.
970
+ param_scheduler=dict(type=ParamSchedulerHook),
971
+ # save checkpoint per `save_steps`.
972
+ checkpoint=dict(
973
+ type=CheckpointHook,
974
+ by_epoch=False,
975
+ interval=save_steps,
976
+ max_keep_ckpts=save_total_limit),
977
+ # set sampler seed in distributed environment.
978
+ sampler_seed=dict(type=DistSamplerSeedHook),
979
+ )
980
+
981
+ # configure environment
982
+ env_cfg = dict(
983
+ # whether to enable cudnn benchmark
984
+ cudnn_benchmark=False,
985
+ # set multi process parameters
986
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
987
+ # set distributed parameters
988
+ dist_cfg=dict(backend='nccl'),
989
+ )
990
+
991
+ # set visualizer
992
+ visualizer = None
993
+
994
+ # set log level
995
+ log_level = 'INFO'
996
+
997
+ # load from which checkpoint
998
+ load_from = None
999
+
1000
+ # whether to resume training from the loaded checkpoint
1001
+ resume = False
1002
+
1003
+ # Defaults to use random seed and disable `deterministic`
1004
+ randomness = dict(seed=None, deterministic=False)
1005
+
1006
+ # set log processor
1007
+ log_processor = dict(by_epoch=False)
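Usage note (an editorial aside, not part of the diff): a config like the one above is normally passed to XTuner's generic launcher rather than executed directly; an 8-GPU run would look roughly like the command below, although the repository may provide its own wrapper script.

NPROC_PER_NODE=8 xtuner train omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_8gpus_01.py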
omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus.py ADDED
@@ -0,0 +1,1028 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+ from peft import LoraConfig
+ from torch.optim import AdamW
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel)
+
+ from omg_llava.dataset import LLaVADataset
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\
+ DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\
+ glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\
+ glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\
+ glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\
+ glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn
+
+ from xtuner.dataset.samplers import LengthGroupedSampler
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
+ from xtuner.engine.runner import TrainLoop
+ from omg_llava.model import OMG_LLaVA
+ from xtuner.utils import PROMPT_TEMPLATE
+ from omg_llava.model import OpenCLIPBackbone_omgseg
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
+
+ from torch.nn import GroupNorm, ReLU
+
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
+ DiceLoss, MaskFormerFusionHead, FocalLoss
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
+
+ #######################################################################
+ # PART 1 Settings #
+ #######################################################################
+ # Model
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
+
+ # Data
+ data_root = './data/llava_data/'
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
+ image_folder = data_root + 'llava_images'
+
+ glamm_data_root = './data/glamm_data/'
+
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
+
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
+
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
+
+ psg_image_path = glamm_data_root + 'images/coco2017/'
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
+
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
+
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
+
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
84
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
85
+
86
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
87
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
88
+
89
+ paco_image_path = './data/glamm_data/images/coco2017/'
90
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
91
+
92
+ referring_refcoco_image_path = refcocog_image_path
93
+ referring_refcoco_data_path = "./data/ref_seg/"
94
+
95
+ referring_refcoco_plus_image_path = refcocog_image_path
96
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
97
+
98
+ referring_refcocog_image_path = refcocog_image_path
99
+ referring_refcocog_data_path = "./data/ref_seg/"
100
+
101
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
102
+ referring_refclef_data_path = "./data/ref_seg/"
103
+
104
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
105
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
106
+
107
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
109
+
110
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
111
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
114
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
115
+
116
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
117
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
118
+
119
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
120
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
123
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
126
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
130
+
131
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
133
+
134
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
135
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
136
+
137
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
138
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
139
+
140
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
141
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
144
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
145
+
146
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
147
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
148
+
149
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
150
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
151
+
152
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
153
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
154
+
155
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
156
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
157
+
158
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
159
+ max_length = int(2048 - (1024 / 64)**2 - 100)
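+ # Rough token budget (assuming an effective 64-px patch, i.e. the 32-px backbone
+ # stride times pixel_shuffle_down_ratio=2): a 1024-px image costs
+ # (1024 / 64)**2 = 256 tokens, and ~100 more are reserved for prompt/special
+ # tokens, leaving about 2048 - 256 - 100 = 1692 tokens of text context.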
160
+
161
+ # Scheduler & Optimizer
162
+ batch_size = 8 # per_device
163
+ accumulative_counts = 2
164
+ dataloader_num_workers = 0
165
+ max_epochs = 1
166
+ optim_type = AdamW
167
+ lr = 2e-4
168
+ betas = (0.9, 0.999)
169
+ weight_decay = 0
170
+ max_norm = 1 # grad clip
171
+ warmup_ratio = 0.03
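+ # Effective batch size is batch_size * accumulative_counts * num_gpus,
+ # i.e. 8 * 2 * 8 = 128 samples per optimizer step on the 8-GPU setup the
+ # filename assumes.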
172
+
173
+
174
+ # Save
175
+ save_steps = 2000
176
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
177
+
178
+ # Evaluate the generation performance during training
179
+ evaluation_freq = 2000
180
+ SYSTEM = ''
181
+ evaluation_images = './work_dirs/test.jpg'
182
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
183
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
184
+
185
+ #######################################################################
186
+ # PART 2 Model & Tokenizer & Image Processor #
187
+ #######################################################################
188
+ tokenizer = dict(
189
+ type=AutoTokenizer.from_pretrained,
190
+ pretrained_model_name_or_path=llm_name_or_path,
191
+ trust_remote_code=True,
192
+ padding_side='right')
193
+
194
+ image_processor = dict(
195
+ type=CLIPImageProcessor,
196
+ do_resize=True,
197
+ size=1024,
198
+ resample=3,
199
+ do_center_crop=True,
200
+ crop_size=1024,
201
+ do_rescale=True,
202
+ do_normalize=True,
203
+ image_mean=[0.4814, 0.4578, 0.4082],
204
+ image_std=[0.2686, 0.2613, 0.2757],
205
+ do_convert_rgb=True
206
+ )
207
+
208
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
209
+ num_things_classes = 80
210
+ num_stuff_classes = 53
211
+ num_classes = num_things_classes + num_stuff_classes
212
+
213
+ omgseg_model = dict(
214
+ type=OMGSegVisualEncoder,
215
+ data_preprocessor=None,
216
+ pixel_shuffle_down_ratio=2,
217
+ backbone=dict(
218
+ type=OpenCLIPBackbone_omgseg,
219
+ model_name='convnext_large_d_320',
220
+ fix=True,
221
+ init_cfg=dict(
222
+ type='clip_pretrain',
223
+ checkpoint='laion2b_s29b_b131k_ft_soup'
224
+ )
225
+ ),
226
+ panoptic_head=dict(
227
+ type=Mask2FormerVideoSemSamHead,
228
+ sphere_cls=True,
229
+ ov_path=omg_ov_class_embed_path,
230
+ enable_box_query=False,
231
+ ov_classifier_name=class_embed,
232
+ logit=None,
233
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
234
+ strides=[4, 8, 16, 32],
235
+ feat_channels=256,
236
+ out_channels=256,
237
+ num_things_classes=num_things_classes,
238
+ num_stuff_classes=num_stuff_classes,
239
+ num_queries=300,
240
+ num_transformer_feat_level=3,
241
+ pixel_decoder=dict(
242
+ type=MSDeformAttnPixelDecoder,
243
+ num_outs=3,
244
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
245
+ act_cfg=dict(type=ReLU),
246
+ encoder=dict( # DeformableDetrTransformerEncoder
247
+ num_layers=6,
248
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
249
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
250
+ embed_dims=256,
251
+ num_heads=8,
252
+ num_levels=3,
253
+ num_points=4,
254
+ dropout=0.0,
255
+ batch_first=True),
256
+ ffn_cfg=dict(
257
+ embed_dims=256,
258
+ feedforward_channels=1024,
259
+ num_fcs=2,
260
+ ffn_drop=0.0,
261
+ act_cfg=dict(type=ReLU, inplace=True)))),
262
+ positional_encoding=dict(num_feats=128, normalize=True)),
263
+ enforce_decoder_input_project=False,
264
+ positional_encoding=dict(num_feats=128, normalize=True),
265
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
266
+ return_intermediate=True,
267
+ num_layers=9,
268
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
269
+ self_attn_cfg=dict( # MultiheadAttention
270
+ embed_dims=256,
271
+ num_heads=8,
272
+ dropout=0.0,
273
+ batch_first=True),
274
+ cross_attn_cfg=dict( # MultiheadAttention
275
+ embed_dims=256,
276
+ num_heads=8,
277
+ dropout=0.0,
278
+ batch_first=True),
279
+ ffn_cfg=dict(
280
+ embed_dims=256,
281
+ feedforward_channels=2048,
282
+ num_fcs=2,
283
+ ffn_drop=0.0,
284
+ act_cfg=dict(type='ReLU', inplace=True))),
285
+ init_cfg=None),
286
+ loss_cls=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=False,
289
+ loss_weight=2.0,
290
+ reduction='mean',
291
+ class_weight=[1.0] * 240 + [0.1]),
292
+ loss_mask=dict(
293
+ type=CrossEntropyLoss,
294
+ use_sigmoid=True,
295
+ reduction='mean',
296
+ loss_weight=5.0),
297
+ loss_dice=dict(
298
+ type=DiceLoss,
299
+ use_sigmoid=True,
300
+ activate=True,
301
+ reduction='mean',
302
+ naive_dice=True,
303
+ eps=1.0,
304
+ loss_weight=5.0),
305
+ loss_iou=dict(
306
+ type=FocalLoss,
307
+ use_sigmoid=True,
308
+ loss_weight=2.0,
309
+ reduction='mean')
310
+ ),
311
+ panoptic_fusion_head=dict(
312
+ type=MaskFormerFusionHead,
313
+ num_things_classes=num_things_classes,
314
+ num_stuff_classes=num_stuff_classes,
315
+ loss_panoptic=None,
316
+ init_cfg=None),
317
+ train_cfg=dict(
318
+ num_points=12544,
319
+ oversample_ratio=3.0,
320
+ importance_sample_ratio=0.75,
321
+ assigner=dict(
322
+ type=HungarianAssigner,
323
+ match_costs=[
324
+ # dict(type=FlexibleClassificationCost, weight=2.0),
325
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
326
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
327
+ ]),
328
+ sampler=dict(type=MaskPseudoSampler)),
329
+ test_cfg=dict(
330
+ panoptic_on=True,
331
+ # For now, the dataset does not support
332
+ # evaluating semantic segmentation metric.
333
+ semantic_on=False,
334
+ instance_on=True,
335
+ # max_per_image is for instance segmentation.
336
+ max_per_image=100,
337
+ iou_thr=0.8,
338
+ # In Mask2Former's panoptic postprocessing,
339
+ # it filters out mask areas whose score is less than 0.5.
340
+ filter_low_score=True),
341
+ init_cfg=dict(
342
+ type='Pretrained',
343
+ checkpoint=omg_head_pretrain_pth_path,
344
+ )
345
+ )
346
+
347
+ model = dict(
348
+ type=OMG_LLaVA,
349
+ freeze_llm=True,
350
+ freeze_visual_encoder=True,
351
+ require_omg_decoder=False,
352
+ pretrained_pth=pretrained_pth,
353
+ text2vision_projector=True,
354
+ pixel_shuffle_ratio=2,
355
+ llm=dict(
356
+ type=AutoModelForCausalLM.from_pretrained,
357
+ pretrained_model_name_or_path=llm_name_or_path,
358
+ trust_remote_code=True,
359
+ torch_dtype=torch.float16,
360
+ quantization_config=dict(
361
+ type=BitsAndBytesConfig,
362
+ load_in_4bit=True,
363
+ load_in_8bit=False,
364
+ llm_int8_threshold=6.0,
365
+ llm_int8_has_fp16_weight=False,
366
+ bnb_4bit_compute_dtype=torch.float16,
367
+ bnb_4bit_use_double_quant=True,
368
+ bnb_4bit_quant_type='nf4')),
369
+ llm_lora=dict(
370
+ type=LoraConfig,
371
+ r=512,
372
+ lora_alpha=256,
373
+ lora_dropout=0.05,
374
+ bias='none',
375
+ task_type='CAUSAL_LM'),
376
+ visual_encoder=omgseg_model,
377
+ tokenizer=tokenizer,
378
+ )
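+ # QLoRA-style setup: the LLM is loaded in 4-bit NF4 and kept frozen, the OMG-Seg
+ # visual encoder is frozen as well, and the LoRA adapters (r=512, lora_alpha=256,
+ # i.e. a scaling factor of alpha/r = 0.5) together with the text/vision projectors
+ # are presumably the main trainable parameters in this stage.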
379
+
380
+ #######################################################################
381
+ # PART 3 Dataset & Dataloader #
382
+ #######################################################################
383
+ debug=False
384
+ llava_dataset = dict(
385
+ type=LLaVADataset,
386
+ data_path=data_path,
387
+ image_folder=image_folder,
388
+ tokenizer=tokenizer,
389
+ image_processor=image_processor,
390
+ dataset_map_fn=llava_map_fn,
391
+ template_map_fn=dict(
392
+ type=template_map_fn_factory, template=prompt_template),
393
+ max_length=max_length,
394
+ pad_image_to_square=True)
395
+
396
+ glamm_refcocog_dataset_given_description = dict(
397
+ type=DecoupledRefCOCOgGCGDataset,
398
+ data_path=refcocog_ann_file,
399
+ image_folder=refcocog_image_path,
400
+ tokenizer=tokenizer,
401
+ image_processor=image_processor,
402
+ dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn,
403
+ template_map_fn=dict(
404
+ type=template_map_fn_factory, template=prompt_template),
405
+ max_length=max_length,
406
+ pad_image_to_square=True,
407
+ debug=False,
408
+ repeats=1,
409
+ mode='given_description'
410
+ )
411
+
412
+ glamm_refcocog_dataset_given_objects = dict(
413
+ type=DecoupledRefCOCOgGCGDataset,
414
+ data_path=refcocog_ann_file,
415
+ image_folder=refcocog_image_path,
416
+ tokenizer=tokenizer,
417
+ image_processor=image_processor,
418
+ dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn,
419
+ template_map_fn=dict(
420
+ type=template_map_fn_factory, template=prompt_template),
421
+ max_length=max_length,
422
+ pad_image_to_square=True,
423
+ debug=False,
424
+ repeats=1,
425
+ mode='given_objects'
426
+ )
427
+
428
+ glamm_grandf_dataset_given_description = dict(
429
+ type=DecoupledGranDfGCGDataset,
430
+ data_path=grandf_ann_file,
431
+ image_folder=grandf_image_path,
432
+ tokenizer=tokenizer,
433
+ image_processor=image_processor,
434
+ dataset_map_fn=glamm_granf_decoupled_given_description_map_fn,
435
+ template_map_fn=dict(
436
+ type=template_map_fn_factory, template=prompt_template),
437
+ max_length=max_length,
438
+ pad_image_to_square=True,
439
+ debug=debug,
440
+ repeats=10,
441
+ mode='given_description'
442
+ )
443
+
444
+ glamm_grandf_dataset_given_objects = dict(
445
+ type=DecoupledGranDfGCGDataset,
446
+ data_path=grandf_ann_file,
447
+ image_folder=grandf_image_path,
448
+ tokenizer=tokenizer,
449
+ image_processor=image_processor,
450
+ dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn,
451
+ template_map_fn=dict(
452
+ type=template_map_fn_factory, template=prompt_template),
453
+ max_length=max_length,
454
+ pad_image_to_square=True,
455
+ debug=debug,
456
+ repeats=10,
457
+ mode='given_objects'
458
+ )
459
+
460
+ glamm_psg_dataset_given_description = dict(
461
+ type=DecoupledOpenPsgGCGDataset,
462
+ data_path=psg_ann_file,
463
+ image_folder=psg_image_path,
464
+ tokenizer=tokenizer,
465
+ image_processor=image_processor,
466
+ dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn,
467
+ template_map_fn=dict(
468
+ type=template_map_fn_factory, template=prompt_template),
469
+ max_length=max_length,
470
+ pad_image_to_square=True,
471
+ debug=debug,
472
+ repeats=1,
473
+ mode='given_description'
474
+ )
475
+
476
+ glamm_psg_dataset_given_objects = dict(
477
+ type=DecoupledOpenPsgGCGDataset,
478
+ data_path=psg_ann_file,
479
+ image_folder=psg_image_path,
480
+ tokenizer=tokenizer,
481
+ image_processor=image_processor,
482
+ dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn,
483
+ template_map_fn=dict(
484
+ type=template_map_fn_factory, template=prompt_template),
485
+ max_length=max_length,
486
+ pad_image_to_square=True,
487
+ debug=debug,
488
+ repeats=1,
489
+ mode='given_objects'
490
+ )
491
+
492
+ glamm_flickr_dataset_given_description = dict(
493
+ type=DecoupledFlickrGCGDataset,
494
+ data_path=flickr_ann_file,
495
+ image_folder=flickr_image_path,
496
+ tokenizer=tokenizer,
497
+ image_processor=image_processor,
498
+ dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn,
499
+ template_map_fn=dict(
500
+ type=template_map_fn_factory, template=prompt_template),
501
+ max_length=max_length,
502
+ pad_image_to_square=True,
503
+ debug=debug,
504
+ repeats=1,
505
+ mode='given_description'
506
+ )
507
+
508
+ glamm_flickr_dataset_given_objects = dict(
509
+ type=DecoupledFlickrGCGDataset,
510
+ data_path=flickr_ann_file,
511
+ image_folder=flickr_image_path,
512
+ tokenizer=tokenizer,
513
+ image_processor=image_processor,
514
+ dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn,
515
+ template_map_fn=dict(
516
+ type=template_map_fn_factory, template=prompt_template),
517
+ max_length=max_length,
518
+ pad_image_to_square=True,
519
+ debug=debug,
520
+ repeats=1,
521
+ mode='given_objects'
522
+ )
523
+
524
+ semantic_seg_ade20k_dataset = dict(
525
+ type=ADE20kSemanticSegDataset,
526
+ data_path=ade20k_class_file,
527
+ image_folder=ade20k_image_path,
528
+ tokenizer=tokenizer,
529
+ image_processor=image_processor,
530
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
531
+ template_map_fn=dict(
532
+ type=template_map_fn_factory, template=prompt_template),
533
+ max_length=max_length,
534
+ pad_image_to_square=True,
535
+ debug=False,
536
+ repeats=1,
537
+ gcg_format=True,
538
+ )
539
+
540
+ semantic_seg_cocostuff_dataset = dict(
541
+ type=COCOStuffSemanticSegDataset,
542
+ data_path=cocostuff_class_file,
543
+ image_folder=cocostuff_image_path,
544
+ label_path=cocostuff_label_path,
545
+ tokenizer=tokenizer,
546
+ image_processor=image_processor,
547
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
548
+ template_map_fn=dict(
549
+ type=template_map_fn_factory, template=prompt_template),
550
+ max_length=max_length,
551
+ pad_image_to_square=True,
552
+ debug=False,
553
+ repeats=1,
554
+ gcg_format=True,
555
+ )
556
+
557
+ referring_seg_refcoco_dataset = dict(
558
+ type=RefcocoReferringSegDataset,
559
+ data_path=referring_refcoco_data_path,
560
+ image_folder=referring_refcoco_image_path,
561
+ tokenizer=tokenizer,
562
+ image_processor=image_processor,
563
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
564
+ template_map_fn=dict(
565
+ type=template_map_fn_factory, template=prompt_template),
566
+ max_length=max_length,
567
+ pad_image_to_square=True,
568
+ debug=False,
569
+ repeats=1,
570
+ )
571
+
572
+ referring_seg_refcoco_plus_dataset = dict(
573
+ type=Refcoco_plus_ReferringSegDataset,
574
+ data_path=referring_refcoco_plus_data_path,
575
+ image_folder=referring_refcoco_plus_image_path,
576
+ tokenizer=tokenizer,
577
+ image_processor=image_processor,
578
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
579
+ template_map_fn=dict(
580
+ type=template_map_fn_factory, template=prompt_template),
581
+ max_length=max_length,
582
+ pad_image_to_square=True,
583
+ debug=False,
584
+ repeats=1,
585
+ )
586
+
587
+ referring_seg_refcocog_dataset = dict(
588
+ type=Refcocog_ReferringSegDataset,
589
+ data_path=referring_refcocog_data_path,
590
+ image_folder=referring_refcocog_image_path,
591
+ tokenizer=tokenizer,
592
+ image_processor=image_processor,
593
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
594
+ template_map_fn=dict(
595
+ type=template_map_fn_factory, template=prompt_template),
596
+ max_length=max_length,
597
+ pad_image_to_square=True,
598
+ debug=False,
599
+ repeats=1,
600
+ )
601
+
602
+ referring_seg_refclef_dataset = dict(
603
+ type=Refclef_ReferringSegDataset,
604
+ data_path=referring_refclef_data_path,
605
+ image_folder=referring_refclef_image_path,
606
+ tokenizer=tokenizer,
607
+ image_processor=image_processor,
608
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
609
+ template_map_fn=dict(
610
+ type=template_map_fn_factory, template=prompt_template),
611
+ max_length=max_length,
612
+ pad_image_to_square=True,
613
+ debug=False,
614
+ repeats=1,
615
+ )
616
+
617
+ region_cap_osprey_dataset = dict(
618
+ type=OspreyRegionCaptionDataset,
619
+ data_path=region_cap_osprey_data_path,
620
+ image_folder=region_cap_osprey_image_path,
621
+ tokenizer=tokenizer,
622
+ image_processor=image_processor,
623
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
624
+ template_map_fn=dict(
625
+ type=template_map_fn_factory, template=prompt_template),
626
+ max_length=max_length,
627
+ pad_image_to_square=True,
628
+ debug=False,
629
+ repeats=1,
630
+ )
631
+
632
+ region_conversation_osprey_dataset = dict(
633
+ type=OspreyRegionConversationDataset,
634
+ data_path=region_conversation_osprey_data_path,
635
+ image_folder=region_conversation_osprey_image_path,
636
+ tokenizer=tokenizer,
637
+ image_processor=image_processor,
638
+ dataset_map_fn=osprey_region_conversation_map_fn,
639
+ template_map_fn=dict(
640
+ type=template_map_fn_factory, template=prompt_template),
641
+ max_length=max_length,
642
+ pad_image_to_square=True,
643
+ debug=False,
644
+ repeats=1,
645
+ )
646
+
647
+ mdpv_detailed_description_ade20k_dataset = dict(
648
+ type=MDPVPointDetailedCaptionDataset,
649
+ data_path=mdpv_detailed_caption_ade20k_data_path,
650
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
651
+ tokenizer=tokenizer,
652
+ image_processor=image_processor,
653
+ dataset_map_fn=mdpv_points_map_fn,
654
+ template_map_fn=dict(
655
+ type=template_map_fn_factory, template=prompt_template),
656
+ max_length=max_length,
657
+ pad_image_to_square=True,
658
+ debug=False,
659
+ repeats=1,
660
+ )
661
+
662
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
663
+ type=MDPVPointDetailedCaptionDataset,
664
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
665
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
666
+ tokenizer=tokenizer,
667
+ image_processor=image_processor,
668
+ dataset_map_fn=mdpv_points_map_fn,
669
+ template_map_fn=dict(
670
+ type=template_map_fn_factory, template=prompt_template),
671
+ max_length=max_length,
672
+ pad_image_to_square=True,
673
+ debug=False,
674
+ repeats=1,
675
+ )
676
+
677
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
678
+ type=MDPVPointDetailedCaptionDataset,
679
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
680
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
681
+ tokenizer=tokenizer,
682
+ image_processor=image_processor,
683
+ dataset_map_fn=mdpv_points_map_fn,
684
+ template_map_fn=dict(
685
+ type=template_map_fn_factory, template=prompt_template),
686
+ max_length=max_length,
687
+ pad_image_to_square=True,
688
+ debug=False,
689
+ repeats=1,
690
+ )
691
+
692
+ mdpv_detailed_description_vg_dataset = dict(
693
+ type=MDPVPointDetailedCaptionDataset,
694
+ data_path=mdpv_detailed_caption_vg_data_path,
695
+ image_folder=mdpv_detailed_caption_vg_image_path,
696
+ tokenizer=tokenizer,
697
+ image_processor=image_processor,
698
+ dataset_map_fn=mdpv_points_map_fn,
699
+ template_map_fn=dict(
700
+ type=template_map_fn_factory, template=prompt_template),
701
+ max_length=max_length,
702
+ pad_image_to_square=True,
703
+ debug=False,
704
+ repeats=1,
705
+ )
706
+
707
+ mdpv_brief_description_vg_dataset = dict(
708
+ type=MDPVPointBriefCaptionDataset,
709
+ data_path=mdpv_brief_caption_vg_data_path,
710
+ image_folder=mdpv_brief_caption_vg_image_path,
711
+ tokenizer=tokenizer,
712
+ image_processor=image_processor,
713
+ dataset_map_fn=mdpv_points_map_fn,
714
+ template_map_fn=dict(
715
+ type=template_map_fn_factory, template=prompt_template),
716
+ max_length=max_length,
717
+ pad_image_to_square=True,
718
+ debug=False,
719
+ repeats=1,
720
+ )
721
+
722
+ mdpv_brief_description_cocostuff10k_dataset = dict(
723
+ type=MDPVPointBriefCaptionDataset,
724
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
725
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
726
+ tokenizer=tokenizer,
727
+ image_processor=image_processor,
728
+ dataset_map_fn=mdpv_points_map_fn,
729
+ template_map_fn=dict(
730
+ type=template_map_fn_factory, template=prompt_template),
731
+ max_length=max_length,
732
+ pad_image_to_square=True,
733
+ debug=False,
734
+ repeats=1,
735
+ )
736
+
737
+ mdpv_brief_description_cocostuff164k_dataset = dict(
738
+ type=MDPVPointBriefCaptionDataset,
739
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
740
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
741
+ tokenizer=tokenizer,
742
+ image_processor=image_processor,
743
+ dataset_map_fn=mdpv_points_map_fn,
744
+ template_map_fn=dict(
745
+ type=template_map_fn_factory, template=prompt_template),
746
+ max_length=max_length,
747
+ pad_image_to_square=True,
748
+ debug=False,
749
+ repeats=1,
750
+ )
751
+
752
+ mdpv_brief_description_ade20k_dataset = dict(
753
+ type=MDPVPointBriefCaptionDataset,
754
+ data_path=mdpv_brief_caption_ade20k_data_path,
755
+ image_folder=mdpv_brief_caption_ade20k_image_path,
756
+ tokenizer=tokenizer,
757
+ image_processor=image_processor,
758
+ dataset_map_fn=mdpv_points_map_fn,
759
+ template_map_fn=dict(
760
+ type=template_map_fn_factory, template=prompt_template),
761
+ max_length=max_length,
762
+ pad_image_to_square=True,
763
+ debug=False,
764
+ repeats=1,
765
+ )
766
+
767
+ mdpv_brief_description_lvis_dataset = dict(
768
+ type=MDPVPointBriefCaptionDataset,
769
+ data_path=mdpv_brief_caption_lvis_data_path,
770
+ image_folder=mdpv_brief_caption_lvis_image_path,
771
+ tokenizer=tokenizer,
772
+ image_processor=image_processor,
773
+ dataset_map_fn=mdpv_points_map_fn,
774
+ template_map_fn=dict(
775
+ type=template_map_fn_factory, template=prompt_template),
776
+ max_length=max_length,
777
+ pad_image_to_square=True,
778
+ debug=False,
779
+ repeats=1,
780
+ )
781
+
782
+ mdpv_qa_vg_dataset = dict(
783
+ type=MDPVPointBriefCaptionDataset,
784
+ data_path=mdpv_qa_vg_data_path,
785
+ image_folder=mdpv_qa_vg_image_path,
786
+ tokenizer=tokenizer,
787
+ image_processor=image_processor,
788
+ dataset_map_fn=mdpv_points_map_fn,
789
+ template_map_fn=dict(
790
+ type=template_map_fn_factory, template=prompt_template),
791
+ max_length=max_length,
792
+ pad_image_to_square=True,
793
+ debug=False,
794
+ repeats=1,
795
+ )
796
+
797
+ mdpv_qa_ade20k_dataset = dict(
798
+ type=MDPVPointBriefCaptionDataset,
799
+ data_path=mdpv_qa_ade20k_data_path,
800
+ image_folder=mdpv_qa_ade20k_image_path,
801
+ tokenizer=tokenizer,
802
+ image_processor=image_processor,
803
+ dataset_map_fn=mdpv_points_map_fn,
804
+ template_map_fn=dict(
805
+ type=template_map_fn_factory, template=prompt_template),
806
+ max_length=max_length,
807
+ pad_image_to_square=True,
808
+ debug=False,
809
+ repeats=1,
810
+ )
811
+
812
+ mdpv_qa_lvis_dataset = dict(
813
+ type=MDPVPointBriefCaptionDataset,
814
+ data_path=mdpv_qa_lvis_data_path,
815
+ image_folder=mdpv_qa_lvis_image_path,
816
+ tokenizer=tokenizer,
817
+ image_processor=image_processor,
818
+ dataset_map_fn=mdpv_points_map_fn,
819
+ template_map_fn=dict(
820
+ type=template_map_fn_factory, template=prompt_template),
821
+ max_length=max_length,
822
+ pad_image_to_square=True,
823
+ debug=False,
824
+ repeats=1,
825
+ )
826
+
827
+ mdpv_qa_cocostuff10k_dataset = dict(
828
+ type=MDPVPointBriefCaptionDataset,
829
+ data_path=mdpv_qa_cocostuff10k_data_path,
830
+ image_folder=mdpv_qa_cocostuff10k_image_path,
831
+ tokenizer=tokenizer,
832
+ image_processor=image_processor,
833
+ dataset_map_fn=mdpv_points_map_fn,
834
+ template_map_fn=dict(
835
+ type=template_map_fn_factory, template=prompt_template),
836
+ max_length=max_length,
837
+ pad_image_to_square=True,
838
+ debug=False,
839
+ repeats=1,
840
+ )
841
+
842
+ mdpv_qa_cocostuff164k_dataset = dict(
843
+ type=MDPVPointBriefCaptionDataset,
844
+ data_path=mdpv_qa_cocostuff164k_data_path,
845
+ image_folder=mdpv_qa_cocostuff164k_image_path,
846
+ tokenizer=tokenizer,
847
+ image_processor=image_processor,
848
+ dataset_map_fn=mdpv_points_map_fn,
849
+ template_map_fn=dict(
850
+ type=template_map_fn_factory, template=prompt_template),
851
+ max_length=max_length,
852
+ pad_image_to_square=True,
853
+ debug=False,
854
+ repeats=1,
855
+ )
856
+
857
+ mdpv_multi_points_openpsg_dataset = dict(
858
+ type=MDPVPointBriefCaptionDataset,
859
+ data_path=mdpv_multi_points_openpsg_data_path,
860
+ image_folder=mdpv_multi_points_openpsg_image_path,
861
+ tokenizer=tokenizer,
862
+ image_processor=image_processor,
863
+ dataset_map_fn=mdpv_points_map_fn,
864
+ template_map_fn=dict(
865
+ type=template_map_fn_factory, template=prompt_template),
866
+ max_length=max_length,
867
+ pad_image_to_square=True,
868
+ debug=False,
869
+ repeats=1,
870
+ )
871
+
872
+ mdpv_multi_points_flicker30k_dataset = dict(
873
+ type=MDPVPointBriefCaptionDataset,
874
+ data_path=mdpv_multi_points_flicker30k_data_path,
875
+ image_folder=mdpv_multi_points_flicker30k_image_path,
876
+ tokenizer=tokenizer,
877
+ image_processor=image_processor,
878
+ dataset_map_fn=mdpv_points_map_fn,
879
+ template_map_fn=dict(
880
+ type=template_map_fn_factory, template=prompt_template),
881
+ max_length=max_length,
882
+ pad_image_to_square=True,
883
+ debug=False,
884
+ repeats=1,
885
+ )
886
+
887
+ train_dataset = dict(
888
+ type=CombineDataset,
889
+ datasets_cfgs=[llava_dataset,
890
+ glamm_flickr_dataset_given_description, glamm_flickr_dataset_given_objects,
891
+ glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description,
892
+ glamm_psg_dataset_given_description, glamm_psg_dataset_given_objects,
893
+ glamm_grandf_dataset_given_description, glamm_grandf_dataset_given_objects,
894
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
895
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
896
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
897
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
898
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
899
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
900
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
901
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
902
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
903
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
904
+ mdpv_detailed_description_ade20k_dataset,
905
+ mdpv_detailed_description_cocostuff_10k_dataset,
906
+ mdpv_detailed_description_cocostuff_164k_dataset,
907
+ mdpv_detailed_description_vg_dataset,
908
+ mdpv_brief_description_lvis_dataset,
909
+ mdpv_brief_description_vg_dataset,
910
+ mdpv_brief_description_ade20k_dataset,
911
+ mdpv_brief_description_cocostuff10k_dataset,
912
+ mdpv_brief_description_cocostuff164k_dataset,
913
+ mdpv_qa_vg_dataset,
914
+ mdpv_qa_lvis_dataset,
915
+ mdpv_qa_ade20k_dataset,
916
+ mdpv_qa_cocostuff10k_dataset,
917
+ mdpv_qa_cocostuff164k_dataset,
918
+ mdpv_multi_points_flicker30k_dataset,
919
+ mdpv_multi_points_openpsg_dataset,],
920
+ )
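+ # Listing a dataset config multiple times in `datasets_cfgs` simply oversamples
+ # it: the semantic-seg and referring-seg entries above appear three times each
+ # (the "repeat 3x" comments), which weights those tasks more heavily in training.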
921
+
922
+ train_dataloader = dict(
923
+ batch_size=batch_size,
924
+ num_workers=dataloader_num_workers,
925
+ dataset=train_dataset,
926
+ sampler=dict(
927
+ type=LengthGroupedSampler,
928
+ length_property='modality_length',
929
+ per_device_batch_size=batch_size * accumulative_counts),
930
+ collate_fn=dict(type=omg_llava_collate_fn))
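+ # LengthGroupedSampler groups samples with similar `modality_length` so that each
+ # effective batch (batch_size * accumulative_counts per device) needs less padding;
+ # this should only affect sample ordering, not the overall data distribution.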
931
+
932
+ #######################################################################
933
+ # PART 4 Scheduler & Optimizer #
934
+ #######################################################################
935
+ # optimizer
936
+ optim_wrapper = dict(
937
+ type=AmpOptimWrapper,
938
+ optimizer=dict(
939
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
940
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
941
+ accumulative_counts=accumulative_counts,
942
+ loss_scale='dynamic',
943
+ dtype='float16')
944
+
945
+ # learning policy
946
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
947
+ param_scheduler = [
948
+ dict(
949
+ type=LinearLR,
950
+ start_factor=1e-5,
951
+ by_epoch=True,
952
+ begin=0,
953
+ end=warmup_ratio * max_epochs,
954
+ convert_to_iter_based=True),
955
+ dict(
956
+ type=CosineAnnealingLR,
957
+ eta_min=0.0,
958
+ by_epoch=True,
959
+ begin=warmup_ratio * max_epochs,
960
+ end=max_epochs,
961
+ convert_to_iter_based=True)
962
+ ]
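+ # Schedule sketch: LinearLR warms the learning rate up from lr * 1e-5 over the
+ # first warmup_ratio * max_epochs (3% of the single epoch), after which
+ # CosineAnnealingLR decays it to eta_min=0 by the end of training;
+ # convert_to_iter_based=True turns these epoch fractions into iteration counts.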
963
+
964
+ # train, val, test setting
965
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
966
+
967
+ #######################################################################
968
+ # PART 5 Runtime #
969
+ #######################################################################
970
+ # Log the dialogue periodically during the training process, optional
971
+ custom_hooks = [
972
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
973
+ dict(
974
+ type=EvaluateChatHook_withSpecialTokens,
975
+ tokenizer=tokenizer,
976
+ image_processor=image_processor,
977
+ every_n_iters=evaluation_freq,
978
+ evaluation_inputs=evaluation_inputs,
979
+ evaluation_images=evaluation_images,
980
+ system=SYSTEM,
981
+ prompt_template=prompt_template)
982
+ ]
983
+
984
+ # configure default hooks
985
+ default_hooks = dict(
986
+ # record the time of every iteration.
987
+ timer=dict(type=IterTimerHook),
988
+ # print log every 10 iterations.
989
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
990
+ # enable the parameter scheduler.
991
+ param_scheduler=dict(type=ParamSchedulerHook),
992
+ # save checkpoint per `save_steps`.
993
+ checkpoint=dict(
994
+ type=CheckpointHook,
995
+ by_epoch=False,
996
+ interval=save_steps,
997
+ max_keep_ckpts=save_total_limit),
998
+ # set sampler seed in distributed environment.
999
+ sampler_seed=dict(type=DistSamplerSeedHook),
1000
+ )
1001
+
1002
+ # configure environment
1003
+ env_cfg = dict(
1004
+ # whether to enable cudnn benchmark
1005
+ cudnn_benchmark=False,
1006
+ # set multi process parameters
1007
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
1008
+ # set distributed parameters
1009
+ dist_cfg=dict(backend='nccl'),
1010
+ )
1011
+
1012
+ # set visualizer
1013
+ visualizer = None
1014
+
1015
+ # set log level
1016
+ log_level = 'INFO'
1017
+
1018
+ # load from which checkpoint
1019
+ load_from = None
1020
+
1021
+ # whether to resume training from the loaded checkpoint
1022
+ resume = False
1023
+
1024
+ # Default to a random seed and disable `deterministic`
1025
+ randomness = dict(seed=None, deterministic=False)
1026
+
1027
+ # set log processor
1028
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_decoupleGCG_8gpus_debug.py ADDED
@@ -0,0 +1,1000 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn,\
26
+ DecoupledGranDfGCGDataset, DecoupledOpenPsgGCGDataset, DecoupledRefCOCOgGCGDataset, DecoupledFlickrGCGDataset,\
27
+ glamm_openpsg_decoupled_given_description_map_fn, glamm_openpsg_decoupled_given_objects_map_fn,\
28
+ glamm_flickr_decoupled_given_objects_map_fn, glamm_flickr_decoupled_given_description_map_fn,\
29
+ glamm_granf_decoupled_given_objects_map_fn, glamm_granf_decoupled_given_description_map_fn,\
30
+ glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn
31
+
32
+ from xtuner.dataset.samplers import LengthGroupedSampler
33
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
34
+ from xtuner.engine.runner import TrainLoop
35
+ from omg_llava.model import OMG_LLaVA
36
+ from xtuner.utils import PROMPT_TEMPLATE
37
+ from omg_llava.model import OpenCLIPBackbone_omgseg
38
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
39
+
40
+ from torch.nn import GroupNorm, ReLU
41
+
42
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
43
+ DiceLoss, MaskFormerFusionHead, FocalLoss
44
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
45
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
46
+
47
+ #######################################################################
48
+ # PART 1 Settings #
49
+ #######################################################################
50
+ # Model
51
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
52
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
53
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
54
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
55
+
56
+ # Data
57
+ data_root = './data/llava_data/'
58
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
59
+ image_folder = data_root + 'llava_images'
60
+
61
+ glamm_data_root = './data/glamm_data/'
62
+
63
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
64
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
65
+
66
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
67
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
68
+
69
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
70
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
71
+
72
+ psg_image_path = glamm_data_root + 'images/coco2017/'
73
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
74
+
75
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
76
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
77
+
78
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
79
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
80
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
81
+
82
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
83
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
84
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
85
+
86
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
87
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
88
+
89
+ paco_image_path = './data/glamm_data/images/coco2017/'
90
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
91
+
92
+ referring_refcoco_image_path = refcocog_image_path
93
+ referring_refcoco_data_path = "./data/ref_seg/"
94
+
95
+ referring_refcoco_plus_image_path = refcocog_image_path
96
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
97
+
98
+ referring_refcocog_image_path = refcocog_image_path
99
+ referring_refcocog_data_path = "./data/ref_seg/"
100
+
101
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
102
+ referring_refclef_data_path = "./data/ref_seg/"
103
+
104
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
105
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
106
+
107
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
109
+
110
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
111
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
114
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
115
+
116
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
117
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
118
+
119
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
120
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
123
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
126
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
130
+
131
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
133
+
134
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
135
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
136
+
137
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
138
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
139
+
140
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
141
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
144
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
145
+
146
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
147
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
148
+
149
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
150
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
151
+
152
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
153
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
154
+
155
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
156
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
157
+
158
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
159
+ max_length = int(2048 - (1024 / 64)**2 - 100)
160
+
161
+ # Scheduler & Optimizer
162
+ batch_size = 8 # per_device
163
+ accumulative_counts = 2
164
+ dataloader_num_workers = 0
165
+ max_epochs = 1
166
+ optim_type = AdamW
167
+ lr = 2e-4
168
+ betas = (0.9, 0.999)
169
+ weight_decay = 0
170
+ max_norm = 1 # grad clip
171
+ warmup_ratio = 0.03
172
+
173
+
174
+ # Save
175
+ save_steps = 2000
176
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
177
+
178
+ # Evaluate the generation performance during training
179
+ evaluation_freq = 2000
180
+ SYSTEM = ''
181
+ evaluation_images = './work_dirs/test.jpg'
182
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
183
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
184
+
185
+ #######################################################################
186
+ # PART 2 Model & Tokenizer & Image Processor #
187
+ #######################################################################
188
+ tokenizer = dict(
189
+ type=AutoTokenizer.from_pretrained,
190
+ pretrained_model_name_or_path=llm_name_or_path,
191
+ trust_remote_code=True,
192
+ padding_side='right')
193
+
194
+ image_processor = dict(
195
+ type=CLIPImageProcessor,
196
+ do_resize=True,
197
+ size=1024,
198
+ resample=3,
199
+ do_center_crop=True,
200
+ crop_size=1024,
201
+ do_rescale=True,
202
+ do_normalize=True,
203
+ image_mean=[0.4814, 0.4578, 0.4082],
204
+ image_std=[0.2686, 0.2613, 0.2757],
205
+ do_convert_rgb=True
206
+ )
207
+
208
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
209
+ num_things_classes = 80
210
+ num_stuff_classes = 53
211
+ num_classes = num_things_classes + num_stuff_classes
212
+
213
+ omgseg_model = dict(
214
+ type=OMGSegVisualEncoder,
215
+ data_preprocessor=None,
216
+ pixel_shuffle_down_ratio=2,
217
+ backbone=dict(
218
+ type=OpenCLIPBackbone_omgseg,
219
+ model_name='convnext_large_d_320',
220
+ fix=True,
221
+ init_cfg=dict(
222
+ type='clip_pretrain',
223
+ checkpoint='laion2b_s29b_b131k_ft_soup'
224
+ )
225
+ ),
226
+ panoptic_head=dict(
227
+ type=Mask2FormerVideoSemSamHead,
228
+ sphere_cls=True,
229
+ ov_path=omg_ov_class_embed_path,
230
+ enable_box_query=False,
231
+ ov_classifier_name=class_embed,
232
+ logit=None,
233
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
234
+ strides=[4, 8, 16, 32],
235
+ feat_channels=256,
236
+ out_channels=256,
237
+ num_things_classes=num_things_classes,
238
+ num_stuff_classes=num_stuff_classes,
239
+ num_queries=300,
240
+ num_transformer_feat_level=3,
241
+ pixel_decoder=dict(
242
+ type=MSDeformAttnPixelDecoder,
243
+ num_outs=3,
244
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
245
+ act_cfg=dict(type=ReLU),
246
+ encoder=dict( # DeformableDetrTransformerEncoder
247
+ num_layers=6,
248
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
249
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
250
+ embed_dims=256,
251
+ num_heads=8,
252
+ num_levels=3,
253
+ num_points=4,
254
+ dropout=0.0,
255
+ batch_first=True),
256
+ ffn_cfg=dict(
257
+ embed_dims=256,
258
+ feedforward_channels=1024,
259
+ num_fcs=2,
260
+ ffn_drop=0.0,
261
+ act_cfg=dict(type=ReLU, inplace=True)))),
262
+ positional_encoding=dict(num_feats=128, normalize=True)),
263
+ enforce_decoder_input_project=False,
264
+ positional_encoding=dict(num_feats=128, normalize=True),
265
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
266
+ return_intermediate=True,
267
+ num_layers=9,
268
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
269
+ self_attn_cfg=dict( # MultiheadAttention
270
+ embed_dims=256,
271
+ num_heads=8,
272
+ dropout=0.0,
273
+ batch_first=True),
274
+ cross_attn_cfg=dict( # MultiheadAttention
275
+ embed_dims=256,
276
+ num_heads=8,
277
+ dropout=0.0,
278
+ batch_first=True),
279
+ ffn_cfg=dict(
280
+ embed_dims=256,
281
+ feedforward_channels=2048,
282
+ num_fcs=2,
283
+ ffn_drop=0.0,
284
+ act_cfg=dict(type='ReLU', inplace=True))),
285
+ init_cfg=None),
286
+ loss_cls=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=False,
289
+ loss_weight=2.0,
290
+ reduction='mean',
291
+ class_weight=[1.0] * 240 + [0.1]),
292
+ loss_mask=dict(
293
+ type=CrossEntropyLoss,
294
+ use_sigmoid=True,
295
+ reduction='mean',
296
+ loss_weight=5.0),
297
+ loss_dice=dict(
298
+ type=DiceLoss,
299
+ use_sigmoid=True,
300
+ activate=True,
301
+ reduction='mean',
302
+ naive_dice=True,
303
+ eps=1.0,
304
+ loss_weight=5.0),
305
+ loss_iou=dict(
306
+ type=FocalLoss,
307
+ use_sigmoid=True,
308
+ loss_weight=2.0,
309
+ reduction='mean')
310
+ ),
311
+ panoptic_fusion_head=dict(
312
+ type=MaskFormerFusionHead,
313
+ num_things_classes=num_things_classes,
314
+ num_stuff_classes=num_stuff_classes,
315
+ loss_panoptic=None,
316
+ init_cfg=None),
317
+ train_cfg=dict(
318
+ num_points=12544,
319
+ oversample_ratio=3.0,
320
+ importance_sample_ratio=0.75,
321
+ assigner=dict(
322
+ type=HungarianAssigner,
323
+ match_costs=[
324
+ # dict(type=FlexibleClassificationCost, weight=2.0),
325
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
326
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
327
+ ]),
328
+ sampler=dict(type=MaskPseudoSampler)),
329
+ test_cfg=dict(
330
+ panoptic_on=True,
331
+ # For now, the dataset does not support
332
+ # evaluating semantic segmentation metric.
333
+ semantic_on=False,
334
+ instance_on=True,
335
+ # max_per_image is for instance segmentation.
336
+ max_per_image=100,
337
+ iou_thr=0.8,
338
+ # In Mask2Former's panoptic postprocessing,
339
+ # it filters out mask areas whose score is less than 0.5.
340
+ filter_low_score=True),
341
+ init_cfg=dict(
342
+ type='Pretrained',
343
+ checkpoint=omg_head_pretrain_pth_path,
344
+ )
345
+ )
346
+
347
+ model = dict(
348
+ type=OMG_LLaVA,
349
+ freeze_llm=True,
350
+ freeze_visual_encoder=True,
351
+ require_omg_decoder=False,
352
+ pretrained_pth=pretrained_pth,
353
+ text2vision_projector=True,
354
+ pixel_shuffle_ratio=2,
355
+ llm=dict(
356
+ type=AutoModelForCausalLM.from_pretrained,
357
+ pretrained_model_name_or_path=llm_name_or_path,
358
+ trust_remote_code=True,
359
+ torch_dtype=torch.float16,
360
+ quantization_config=dict(
361
+ type=BitsAndBytesConfig,
362
+ load_in_4bit=True,
363
+ load_in_8bit=False,
364
+ llm_int8_threshold=6.0,
365
+ llm_int8_has_fp16_weight=False,
366
+ bnb_4bit_compute_dtype=torch.float16,
367
+ bnb_4bit_use_double_quant=True,
368
+ bnb_4bit_quant_type='nf4')),
369
+ llm_lora=dict(
370
+ type=LoraConfig,
371
+ r=512,
372
+ lora_alpha=256,
373
+ lora_dropout=0.05,
374
+ bias='none',
375
+ task_type='CAUSAL_LM'),
376
+ visual_encoder=omgseg_model,
377
+ tokenizer=tokenizer,
378
+ )
379
+
380
+ #######################################################################
381
+ # PART 3 Dataset & Dataloader #
382
+ #######################################################################
383
+ debug=False
384
+ llava_dataset = dict(
385
+ type=LLaVADataset,
386
+ data_path=data_path,
387
+ image_folder=image_folder,
388
+ tokenizer=tokenizer,
389
+ image_processor=image_processor,
390
+ dataset_map_fn=llava_map_fn,
391
+ template_map_fn=dict(
392
+ type=template_map_fn_factory, template=prompt_template),
393
+ max_length=max_length,
394
+ pad_image_to_square=True)
395
+
396
+ glamm_refcocog_dataset_given_description = dict(
397
+ type=DecoupledRefCOCOgGCGDataset,
398
+ data_path=refcocog_ann_file,
399
+ image_folder=refcocog_image_path,
400
+ tokenizer=tokenizer,
401
+ image_processor=image_processor,
402
+ dataset_map_fn=glamm_refcocog_decoupled_given_description_map_fn,
403
+ template_map_fn=dict(
404
+ type=template_map_fn_factory, template=prompt_template),
405
+ max_length=max_length,
406
+ pad_image_to_square=True,
407
+ debug=False,
408
+ repeats=1,
409
+ mode='given_description'
410
+ )
411
+
412
+ glamm_refcocog_dataset_given_objects = dict(
413
+ type=DecoupledRefCOCOgGCGDataset,
414
+ data_path=refcocog_ann_file,
415
+ image_folder=refcocog_image_path,
416
+ tokenizer=tokenizer,
417
+ image_processor=image_processor,
418
+ dataset_map_fn=glamm_refcocog_decoupled_given_objects_map_fn,
419
+ template_map_fn=dict(
420
+ type=template_map_fn_factory, template=prompt_template),
421
+ max_length=max_length,
422
+ pad_image_to_square=True,
423
+ debug=False,
424
+ repeats=1,
425
+ mode='given_objects'
426
+ )
427
+
428
+ glamm_grandf_dataset_given_description = dict(
429
+ type=DecoupledGranDfGCGDataset,
430
+ data_path=grandf_ann_file,
431
+ image_folder=grandf_image_path,
432
+ tokenizer=tokenizer,
433
+ image_processor=image_processor,
434
+ dataset_map_fn=glamm_granf_decoupled_given_description_map_fn,
435
+ template_map_fn=dict(
436
+ type=template_map_fn_factory, template=prompt_template),
437
+ max_length=max_length,
438
+ pad_image_to_square=True,
439
+ debug=debug,
440
+ repeats=10,
441
+ mode='given_description'
442
+ )
443
+
444
+ glamm_grandf_dataset_given_objects = dict(
445
+ type=DecoupledGranDfGCGDataset,
446
+ data_path=grandf_ann_file,
447
+ image_folder=grandf_image_path,
448
+ tokenizer=tokenizer,
449
+ image_processor=image_processor,
450
+ dataset_map_fn=glamm_granf_decoupled_given_objects_map_fn,
451
+ template_map_fn=dict(
452
+ type=template_map_fn_factory, template=prompt_template),
453
+ max_length=max_length,
454
+ pad_image_to_square=True,
455
+ debug=debug,
456
+ repeats=10,
457
+ mode='given_objects'
458
+ )
459
+
460
+ glamm_psg_dataset_given_description = dict(
461
+ type=DecoupledOpenPsgGCGDataset,
462
+ data_path=psg_ann_file,
463
+ image_folder=psg_image_path,
464
+ tokenizer=tokenizer,
465
+ image_processor=image_processor,
466
+ dataset_map_fn=glamm_openpsg_decoupled_given_description_map_fn,
467
+ template_map_fn=dict(
468
+ type=template_map_fn_factory, template=prompt_template),
469
+ max_length=max_length,
470
+ pad_image_to_square=True,
471
+ debug=debug,
472
+ repeats=1,
473
+ mode='given_description'
474
+ )
475
+
476
+ glamm_psg_dataset_given_objects = dict(
477
+ type=DecoupledOpenPsgGCGDataset,
478
+ data_path=psg_ann_file,
479
+ image_folder=psg_image_path,
480
+ tokenizer=tokenizer,
481
+ image_processor=image_processor,
482
+ dataset_map_fn=glamm_openpsg_decoupled_given_objects_map_fn,
483
+ template_map_fn=dict(
484
+ type=template_map_fn_factory, template=prompt_template),
485
+ max_length=max_length,
486
+ pad_image_to_square=True,
487
+ debug=debug,
488
+ repeats=1,
489
+ mode='given_objects'
490
+ )
491
+
492
+ glamm_flickr_dataset_given_description = dict(
493
+ type=DecoupledFlickrGCGDataset,
494
+ data_path=flickr_ann_file,
495
+ image_folder=flickr_image_path,
496
+ tokenizer=tokenizer,
497
+ image_processor=image_processor,
498
+ dataset_map_fn=glamm_flickr_decoupled_given_description_map_fn,
499
+ template_map_fn=dict(
500
+ type=template_map_fn_factory, template=prompt_template),
501
+ max_length=max_length,
502
+ pad_image_to_square=True,
503
+ debug=debug,
504
+ repeats=1,
505
+ mode='given_description'
506
+ )
507
+
508
+ glamm_flickr_dataset_given_objects = dict(
509
+ type=DecoupledFlickrGCGDataset,
510
+ data_path=flickr_ann_file,
511
+ image_folder=flickr_image_path,
512
+ tokenizer=tokenizer,
513
+ image_processor=image_processor,
514
+ dataset_map_fn=glamm_flickr_decoupled_given_objects_map_fn,
515
+ template_map_fn=dict(
516
+ type=template_map_fn_factory, template=prompt_template),
517
+ max_length=max_length,
518
+ pad_image_to_square=True,
519
+ debug=debug,
520
+ repeats=1,
521
+ mode='given_objects'
522
+ )
523
+
524
+ semantic_seg_ade20k_dataset = dict(
525
+ type=ADE20kSemanticSegDataset,
526
+ data_path=ade20k_class_file,
527
+ image_folder=ade20k_image_path,
528
+ tokenizer=tokenizer,
529
+ image_processor=image_processor,
530
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
531
+ template_map_fn=dict(
532
+ type=template_map_fn_factory, template=prompt_template),
533
+ max_length=max_length,
534
+ pad_image_to_square=True,
535
+ debug=False,
536
+ repeats=1,
537
+ gcg_format=True,
538
+ )
539
+
540
+ semantic_seg_cocostuff_dataset = dict(
541
+ type=COCOStuffSemanticSegDataset,
542
+ data_path=cocostuff_class_file,
543
+ image_folder=cocostuff_image_path,
544
+ label_path=cocostuff_label_path,
545
+ tokenizer=tokenizer,
546
+ image_processor=image_processor,
547
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
548
+ template_map_fn=dict(
549
+ type=template_map_fn_factory, template=prompt_template),
550
+ max_length=max_length,
551
+ pad_image_to_square=True,
552
+ debug=False,
553
+ repeats=1,
554
+ gcg_format=True,
555
+ )
556
+
557
+ referring_seg_refcoco_dataset = dict(
558
+ type=RefcocoReferringSegDataset,
559
+ data_path=referring_refcoco_data_path,
560
+ image_folder=referring_refcoco_image_path,
561
+ tokenizer=tokenizer,
562
+ image_processor=image_processor,
563
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
564
+ template_map_fn=dict(
565
+ type=template_map_fn_factory, template=prompt_template),
566
+ max_length=max_length,
567
+ pad_image_to_square=True,
568
+ debug=False,
569
+ repeats=1,
570
+ )
571
+
572
+ referring_seg_refcoco_plus_dataset = dict(
573
+ type=Refcoco_plus_ReferringSegDataset,
574
+ data_path=referring_refcoco_plus_data_path,
575
+ image_folder=referring_refcoco_plus_image_path,
576
+ tokenizer=tokenizer,
577
+ image_processor=image_processor,
578
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
579
+ template_map_fn=dict(
580
+ type=template_map_fn_factory, template=prompt_template),
581
+ max_length=max_length,
582
+ pad_image_to_square=True,
583
+ debug=False,
584
+ repeats=1,
585
+ )
586
+
587
+ referring_seg_refcocog_dataset = dict(
588
+ type=Refcocog_ReferringSegDataset,
589
+ data_path=referring_refcocog_data_path,
590
+ image_folder=referring_refcocog_image_path,
591
+ tokenizer=tokenizer,
592
+ image_processor=image_processor,
593
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
594
+ template_map_fn=dict(
595
+ type=template_map_fn_factory, template=prompt_template),
596
+ max_length=max_length,
597
+ pad_image_to_square=True,
598
+ debug=False,
599
+ repeats=1,
600
+ )
601
+
602
+ referring_seg_refclef_dataset = dict(
603
+ type=Refclef_ReferringSegDataset,
604
+ data_path=referring_refclef_data_path,
605
+ image_folder=referring_refclef_image_path,
606
+ tokenizer=tokenizer,
607
+ image_processor=image_processor,
608
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
609
+ template_map_fn=dict(
610
+ type=template_map_fn_factory, template=prompt_template),
611
+ max_length=max_length,
612
+ pad_image_to_square=True,
613
+ debug=False,
614
+ repeats=1,
615
+ )
616
+
617
+ region_cap_osprey_dataset = dict(
618
+ type=OspreyRegionCaptionDataset,
619
+ data_path=region_cap_osprey_data_path,
620
+ image_folder=region_cap_osprey_image_path,
621
+ tokenizer=tokenizer,
622
+ image_processor=image_processor,
623
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
624
+ template_map_fn=dict(
625
+ type=template_map_fn_factory, template=prompt_template),
626
+ max_length=max_length,
627
+ pad_image_to_square=True,
628
+ debug=False,
629
+ repeats=1,
630
+ )
631
+
632
+ region_conversation_osprey_dataset = dict(
633
+ type=OspreyRegionConversationDataset,
634
+ data_path=region_conversation_osprey_data_path,
635
+ image_folder=region_conversation_osprey_image_path,
636
+ tokenizer=tokenizer,
637
+ image_processor=image_processor,
638
+ dataset_map_fn=osprey_region_conversation_map_fn,
639
+ template_map_fn=dict(
640
+ type=template_map_fn_factory, template=prompt_template),
641
+ max_length=max_length,
642
+ pad_image_to_square=True,
643
+ debug=False,
644
+ repeats=1,
645
+ )
646
+
647
+ mdpv_detailed_description_ade20k_dataset = dict(
648
+ type=MDPVPointDetailedCaptionDataset,
649
+ data_path=mdpv_detailed_caption_ade20k_data_path,
650
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
651
+ tokenizer=tokenizer,
652
+ image_processor=image_processor,
653
+ dataset_map_fn=mdpv_points_map_fn,
654
+ template_map_fn=dict(
655
+ type=template_map_fn_factory, template=prompt_template),
656
+ max_length=max_length,
657
+ pad_image_to_square=True,
658
+ debug=False,
659
+ repeats=1,
660
+ )
661
+
662
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
663
+ type=MDPVPointDetailedCaptionDataset,
664
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
665
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
666
+ tokenizer=tokenizer,
667
+ image_processor=image_processor,
668
+ dataset_map_fn=mdpv_points_map_fn,
669
+ template_map_fn=dict(
670
+ type=template_map_fn_factory, template=prompt_template),
671
+ max_length=max_length,
672
+ pad_image_to_square=True,
673
+ debug=False,
674
+ repeats=1,
675
+ )
676
+
677
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
678
+ type=MDPVPointDetailedCaptionDataset,
679
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
680
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
681
+ tokenizer=tokenizer,
682
+ image_processor=image_processor,
683
+ dataset_map_fn=mdpv_points_map_fn,
684
+ template_map_fn=dict(
685
+ type=template_map_fn_factory, template=prompt_template),
686
+ max_length=max_length,
687
+ pad_image_to_square=True,
688
+ debug=False,
689
+ repeats=1,
690
+ )
691
+
692
+ mdpv_detailed_description_vg_dataset = dict(
693
+ type=MDPVPointDetailedCaptionDataset,
694
+ data_path=mdpv_detailed_caption_vg_data_path,
695
+ image_folder=mdpv_detailed_caption_vg_image_path,
696
+ tokenizer=tokenizer,
697
+ image_processor=image_processor,
698
+ dataset_map_fn=mdpv_points_map_fn,
699
+ template_map_fn=dict(
700
+ type=template_map_fn_factory, template=prompt_template),
701
+ max_length=max_length,
702
+ pad_image_to_square=True,
703
+ debug=False,
704
+ repeats=1,
705
+ )
706
+
707
+ mdpv_brief_description_vg_dataset = dict(
708
+ type=MDPVPointBriefCaptionDataset,
709
+ data_path=mdpv_brief_caption_vg_data_path,
710
+ image_folder=mdpv_brief_caption_vg_image_path,
711
+ tokenizer=tokenizer,
712
+ image_processor=image_processor,
713
+ dataset_map_fn=mdpv_points_map_fn,
714
+ template_map_fn=dict(
715
+ type=template_map_fn_factory, template=prompt_template),
716
+ max_length=max_length,
717
+ pad_image_to_square=True,
718
+ debug=False,
719
+ repeats=1,
720
+ )
721
+
722
+ mdpv_brief_description_cocostuff10k_dataset = dict(
723
+ type=MDPVPointBriefCaptionDataset,
724
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
725
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
726
+ tokenizer=tokenizer,
727
+ image_processor=image_processor,
728
+ dataset_map_fn=mdpv_points_map_fn,
729
+ template_map_fn=dict(
730
+ type=template_map_fn_factory, template=prompt_template),
731
+ max_length=max_length,
732
+ pad_image_to_square=True,
733
+ debug=False,
734
+ repeats=1,
735
+ )
736
+
737
+ mdpv_brief_description_cocostuff164k_dataset = dict(
738
+ type=MDPVPointBriefCaptionDataset,
739
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
740
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
741
+ tokenizer=tokenizer,
742
+ image_processor=image_processor,
743
+ dataset_map_fn=mdpv_points_map_fn,
744
+ template_map_fn=dict(
745
+ type=template_map_fn_factory, template=prompt_template),
746
+ max_length=max_length,
747
+ pad_image_to_square=True,
748
+ debug=False,
749
+ repeats=1,
750
+ )
751
+
752
+ mdpv_brief_description_ade20k_dataset = dict(
753
+ type=MDPVPointBriefCaptionDataset,
754
+ data_path=mdpv_brief_caption_ade20k_data_path,
755
+ image_folder=mdpv_brief_caption_ade20k_image_path,
756
+ tokenizer=tokenizer,
757
+ image_processor=image_processor,
758
+ dataset_map_fn=mdpv_points_map_fn,
759
+ template_map_fn=dict(
760
+ type=template_map_fn_factory, template=prompt_template),
761
+ max_length=max_length,
762
+ pad_image_to_square=True,
763
+ debug=False,
764
+ repeats=1,
765
+ )
766
+
767
+ mdpv_brief_description_lvis_dataset = dict(
768
+ type=MDPVPointBriefCaptionDataset,
769
+ data_path=mdpv_brief_caption_lvis_data_path,
770
+ image_folder=mdpv_brief_caption_lvis_image_path,
771
+ tokenizer=tokenizer,
772
+ image_processor=image_processor,
773
+ dataset_map_fn=mdpv_points_map_fn,
774
+ template_map_fn=dict(
775
+ type=template_map_fn_factory, template=prompt_template),
776
+ max_length=max_length,
777
+ pad_image_to_square=True,
778
+ debug=False,
779
+ repeats=1,
780
+ )
781
+
782
+ mdpv_qa_vg_dataset = dict(
783
+ type=MDPVPointBriefCaptionDataset,
784
+ data_path=mdpv_qa_vg_data_path,
785
+ image_folder=mdpv_qa_vg_image_path,
786
+ tokenizer=tokenizer,
787
+ image_processor=image_processor,
788
+ dataset_map_fn=mdpv_points_map_fn,
789
+ template_map_fn=dict(
790
+ type=template_map_fn_factory, template=prompt_template),
791
+ max_length=max_length,
792
+ pad_image_to_square=True,
793
+ debug=False,
794
+ repeats=1,
795
+ )
796
+
797
+ mdpv_qa_ade20k_dataset = dict(
798
+ type=MDPVPointBriefCaptionDataset,
799
+ data_path=mdpv_qa_ade20k_data_path,
800
+ image_folder=mdpv_qa_ade20k_image_path,
801
+ tokenizer=tokenizer,
802
+ image_processor=image_processor,
803
+ dataset_map_fn=mdpv_points_map_fn,
804
+ template_map_fn=dict(
805
+ type=template_map_fn_factory, template=prompt_template),
806
+ max_length=max_length,
807
+ pad_image_to_square=True,
808
+ debug=False,
809
+ repeats=1,
810
+ )
811
+
812
+ mdpv_qa_lvis_dataset = dict(
813
+ type=MDPVPointBriefCaptionDataset,
814
+ data_path=mdpv_qa_lvis_data_path,
815
+ image_folder=mdpv_qa_lvis_image_path,
816
+ tokenizer=tokenizer,
817
+ image_processor=image_processor,
818
+ dataset_map_fn=mdpv_points_map_fn,
819
+ template_map_fn=dict(
820
+ type=template_map_fn_factory, template=prompt_template),
821
+ max_length=max_length,
822
+ pad_image_to_square=True,
823
+ debug=False,
824
+ repeats=1,
825
+ )
826
+
827
+ mdpv_qa_cocostuff10k_dataset = dict(
828
+ type=MDPVPointBriefCaptionDataset,
829
+ data_path=mdpv_qa_cocostuff10k_data_path,
830
+ image_folder=mdpv_qa_cocostuff10k_image_path,
831
+ tokenizer=tokenizer,
832
+ image_processor=image_processor,
833
+ dataset_map_fn=mdpv_points_map_fn,
834
+ template_map_fn=dict(
835
+ type=template_map_fn_factory, template=prompt_template),
836
+ max_length=max_length,
837
+ pad_image_to_square=True,
838
+ debug=False,
839
+ repeats=1,
840
+ )
841
+
842
+ mdpv_qa_cocostuff164k_dataset = dict(
843
+ type=MDPVPointBriefCaptionDataset,
844
+ data_path=mdpv_qa_cocostuff164k_data_path,
845
+ image_folder=mdpv_qa_cocostuff164k_image_path,
846
+ tokenizer=tokenizer,
847
+ image_processor=image_processor,
848
+ dataset_map_fn=mdpv_points_map_fn,
849
+ template_map_fn=dict(
850
+ type=template_map_fn_factory, template=prompt_template),
851
+ max_length=max_length,
852
+ pad_image_to_square=True,
853
+ debug=False,
854
+ repeats=1,
855
+ )
856
+
857
+ mdpv_multi_points_openpsg_dataset = dict(
858
+ type=MDPVPointBriefCaptionDataset,
859
+ data_path=mdpv_multi_points_openpsg_data_path,
860
+ image_folder=mdpv_multi_points_openpsg_image_path,
861
+ tokenizer=tokenizer,
862
+ image_processor=image_processor,
863
+ dataset_map_fn=mdpv_points_map_fn,
864
+ template_map_fn=dict(
865
+ type=template_map_fn_factory, template=prompt_template),
866
+ max_length=max_length,
867
+ pad_image_to_square=True,
868
+ debug=False,
869
+ repeats=1,
870
+ )
871
+
872
+ mdpv_multi_points_flicker30k_dataset = dict(
873
+ type=MDPVPointBriefCaptionDataset,
874
+ data_path=mdpv_multi_points_flicker30k_data_path,
875
+ image_folder=mdpv_multi_points_flicker30k_image_path,
876
+ tokenizer=tokenizer,
877
+ image_processor=image_processor,
878
+ dataset_map_fn=mdpv_points_map_fn,
879
+ template_map_fn=dict(
880
+ type=template_map_fn_factory, template=prompt_template),
881
+ max_length=max_length,
882
+ pad_image_to_square=True,
883
+ debug=False,
884
+ repeats=1,
885
+ )
886
+
887
+ train_dataset = dict(
888
+ type=CombineDataset,
889
+ datasets_cfgs=[
890
+ glamm_refcocog_dataset_given_objects, glamm_refcocog_dataset_given_description,
891
+ ],
892
+ )
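+ # CombineDataset presumably just concatenates the configs listed in
+ # `datasets_cfgs`; this particular config trains only on the two decoupled
+ # RefCOCOg GCG variants (given_objects / given_description) defined above.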
893
+
894
+ train_dataloader = dict(
895
+ batch_size=batch_size,
896
+ num_workers=dataloader_num_workers,
897
+ dataset=train_dataset,
898
+ sampler=dict(
899
+ type=LengthGroupedSampler,
900
+ length_property='modality_length',
901
+ per_device_batch_size=batch_size * accumulative_counts),
902
+ collate_fn=dict(type=omg_llava_collate_fn))
903
+
904
+ #######################################################################
905
+ # PART 4 Scheduler & Optimizer #
906
+ #######################################################################
907
+ # optimizer
908
+ optim_wrapper = dict(
909
+ type=AmpOptimWrapper,
910
+ optimizer=dict(
911
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
912
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
913
+ accumulative_counts=accumulative_counts,
914
+ loss_scale='dynamic',
915
+ dtype='float16')
916
+
917
+ # learning policy
918
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
919
+ param_scheduler = [
920
+ dict(
921
+ type=LinearLR,
922
+ start_factor=1e-5,
923
+ by_epoch=True,
924
+ begin=0,
925
+ end=warmup_ratio * max_epochs,
926
+ convert_to_iter_based=True),
927
+ dict(
928
+ type=CosineAnnealingLR,
929
+ eta_min=0.0,
930
+ by_epoch=True,
931
+ begin=warmup_ratio * max_epochs,
932
+ end=max_epochs,
933
+ convert_to_iter_based=True)
934
+ ]
935
+
936
+ # train, val, test setting
937
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
938
+
939
+ #######################################################################
940
+ # PART 5 Runtime #
941
+ #######################################################################
942
+ # Log sample dialogues periodically during training (optional)
943
+ custom_hooks = [
944
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
945
+ dict(
946
+ type=EvaluateChatHook_withSpecialTokens,
947
+ tokenizer=tokenizer,
948
+ image_processor=image_processor,
949
+ every_n_iters=evaluation_freq,
950
+ evaluation_inputs=evaluation_inputs,
951
+ evaluation_images=evaluation_images,
952
+ system=SYSTEM,
953
+ prompt_template=prompt_template)
954
+ ]
955
+
956
+ # configure default hooks
957
+ default_hooks = dict(
958
+ # record the time of every iteration.
959
+ timer=dict(type=IterTimerHook),
960
+ # print log every 10 iterations.
961
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
962
+ # enable the parameter scheduler.
963
+ param_scheduler=dict(type=ParamSchedulerHook),
964
+ # save a checkpoint every `save_steps` iterations.
965
+ checkpoint=dict(
966
+ type=CheckpointHook,
967
+ by_epoch=False,
968
+ interval=save_steps,
969
+ max_keep_ckpts=save_total_limit),
970
+ # set the sampler seed in the distributed environment.
971
+ sampler_seed=dict(type=DistSamplerSeedHook),
972
+ )
973
+
974
+ # configure environment
975
+ env_cfg = dict(
976
+ # whether to enable cudnn benchmark
977
+ cudnn_benchmark=False,
978
+ # set multi process parameters
979
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
980
+ # set distributed parameters
981
+ dist_cfg=dict(backend='nccl'),
982
+ )
983
+
984
+ # set visualizer
985
+ visualizer = None
986
+
987
+ # set log level
988
+ log_level = 'INFO'
989
+
990
+ # load from which checkpoint
991
+ load_from = None
992
+
993
+ # whether to resume training from the loaded checkpoint
994
+ resume = False
995
+
996
+ # Use a random seed by default and disable `deterministic`
997
+ randomness = dict(seed=None, deterministic=False)
998
+
999
+ # set log processor
1000
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_7b_finetune_stage1_1024image_uniSegFormat_8gpus.py ADDED
@@ -0,0 +1,951 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './pretrained/omg_llava/omg_llava_7b_pretrain_1024image_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
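+ # Rough token budget behind this formula, assuming the 1024-px input is
+ # downsampled by a factor of 64: a 16 x 16 = 256-token visual sequence plus a
+ # ~100-token reserve for special/template tokens leaves
+ # 2048 - 256 - 100 = 1692 tokens of text context.
+ assert max_length == 2048 - 256 - 100  # = 1692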
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
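+ # The mean/std above are the (rounded) OpenAI-CLIP normalization constants;
+ # resizing and cropping to 1024 keeps the full-resolution input that the
+ # ConvNeXt-based OMG-Seg visual encoder below is configured for.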
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
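+ # 80 "thing" + 53 "stuff" categories = 133 classes, matching the COCO
+ # panoptic label space used by the OMG-Seg head below.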
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating the semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it filters out mask areas whose score is less than 0.5.
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=False,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
816
+ glamm_grandf_dataset, glamm_psg_dataset,
817
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
818
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
819
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
820
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
821
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
822
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
823
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
824
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
825
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
826
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
827
+ mdpv_detailed_description_ade20k_dataset,
828
+ mdpv_detailed_description_cocostuff_10k_dataset,
829
+ mdpv_detailed_description_cocostuff_164k_dataset,
830
+ mdpv_detailed_description_vg_dataset,
831
+ mdpv_brief_description_lvis_dataset,
832
+ mdpv_brief_description_vg_dataset,
833
+ mdpv_brief_description_ade20k_dataset,
834
+ mdpv_brief_description_cocostuff10k_dataset,
835
+ mdpv_brief_description_cocostuff164k_dataset,
836
+ mdpv_qa_vg_dataset,
837
+ mdpv_qa_lvis_dataset,
838
+ mdpv_qa_ade20k_dataset,
839
+ mdpv_qa_cocostuff10k_dataset,
840
+ mdpv_qa_cocostuff164k_dataset,
841
+ mdpv_multi_points_flicker30k_dataset,
842
+ mdpv_multi_points_openpsg_dataset,],
843
+ )
844
+
845
+ train_dataloader = dict(
846
+ batch_size=batch_size,
847
+ num_workers=dataloader_num_workers,
848
+ dataset=train_dataset,
849
+ sampler=dict(
850
+ type=LengthGroupedSampler,
851
+ length_property='modality_length',
852
+ per_device_batch_size=batch_size * accumulative_counts),
853
+ collate_fn=dict(type=omg_llava_collate_fn))
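+ # With batch_size=8 and accumulative_counts=2 as set above, the
+ # LengthGroupedSampler buckets samples of similar `modality_length` into
+ # groups of 8 * 2 = 16 per device, which should keep padding low across each
+ # gradient-accumulation window (behavior assumed from the sampler arguments).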
854
+
855
+ #######################################################################
856
+ # PART 4 Scheduler & Optimizer #
857
+ #######################################################################
858
+ # optimizer
859
+ optim_wrapper = dict(
860
+ type=AmpOptimWrapper,
861
+ optimizer=dict(
862
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
863
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
864
+ accumulative_counts=accumulative_counts,
865
+ loss_scale='dynamic',
866
+ dtype='float16')
867
+
868
+ # learning policy
869
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
870
+ param_scheduler = [
871
+ dict(
872
+ type=LinearLR,
873
+ start_factor=1e-5,
874
+ by_epoch=True,
875
+ begin=0,
876
+ end=warmup_ratio * max_epochs,
877
+ convert_to_iter_based=True),
878
+ dict(
879
+ type=CosineAnnealingLR,
880
+ eta_min=0.0,
881
+ by_epoch=True,
882
+ begin=warmup_ratio * max_epochs,
883
+ end=max_epochs,
884
+ convert_to_iter_based=True)
885
+ ]
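+ # With warmup_ratio=0.03 and max_epochs=1 as set above, the LinearLR stage
+ # warms up over the first 0.03 epochs (~3% of iterations once
+ # convert_to_iter_based=True expands it), after which CosineAnnealingLR decays
+ # the learning rate from lr=2e-4 towards eta_min=0.0 for the rest of training.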
886
+
887
+ # train, val, test setting
888
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
889
+
890
+ #######################################################################
891
+ # PART 5 Runtime #
892
+ #######################################################################
893
+ # Log sample dialogues periodically during training (optional)
894
+ custom_hooks = [
895
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
896
+ dict(
897
+ type=EvaluateChatHook_withSpecialTokens,
898
+ tokenizer=tokenizer,
899
+ image_processor=image_processor,
900
+ every_n_iters=evaluation_freq,
901
+ evaluation_inputs=evaluation_inputs,
902
+ evaluation_images=evaluation_images,
903
+ system=SYSTEM,
904
+ prompt_template=prompt_template)
905
+ ]
906
+
907
+ # configure default hooks
908
+ default_hooks = dict(
909
+ # record the time of every iteration.
910
+ timer=dict(type=IterTimerHook),
911
+ # print log every 10 iterations.
912
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
913
+ # enable the parameter scheduler.
914
+ param_scheduler=dict(type=ParamSchedulerHook),
915
+ # save a checkpoint every `save_steps` iterations.
916
+ checkpoint=dict(
917
+ type=CheckpointHook,
918
+ by_epoch=False,
919
+ interval=save_steps,
920
+ max_keep_ckpts=save_total_limit),
921
+ # set the sampler seed in the distributed environment.
922
+ sampler_seed=dict(type=DistSamplerSeedHook),
923
+ )
924
+
925
+ # configure environment
926
+ env_cfg = dict(
927
+ # whether to enable cudnn benchmark
928
+ cudnn_benchmark=False,
929
+ # set multi process parameters
930
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
931
+ # set distributed parameters
932
+ dist_cfg=dict(backend='nccl'),
933
+ )
934
+
935
+ # set visualizer
936
+ visualizer = None
937
+
938
+ # set log level
939
+ log_level = 'INFO'
940
+
941
+ # load from which checkpoint
942
+ load_from = None
943
+
944
+ # whether to resume training from the loaded checkpoint
945
+ resume = False
946
+
947
+ # Default to a random seed and leave `deterministic` disabled
948
+ randomness = dict(seed=None, deterministic=False)
949
+
950
+ # set log processor
951
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/omg_llava_7b_finetune_stage2_1024image_8gpus.py ADDED
@@ -0,0 +1,994 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset
24
+ from xtuner.dataset.samplers import LengthGroupedSampler
25
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
26
+ from xtuner.engine.runner import TrainLoop
27
+ from omg_llava.model import OMG_LLaVA
28
+ from xtuner.utils import PROMPT_TEMPLATE
29
+ from omg_llava.model import OpenCLIPBackbone_omgseg
30
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
31
+
32
+ from torch.nn import GroupNorm, ReLU
33
+
34
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
35
+ DiceLoss, MaskFormerFusionHead, FocalLoss
36
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
37
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
38
+
39
+ #######################################################################
40
+ # PART 1 Settings #
41
+ #######################################################################
42
+ # Model
43
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
44
+ pretrained_pth = './work_dirs/omg_llava_7b_finetune_stage1_1024image_8gpus/iter_27600.pth' # noqa: E501
45
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
46
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
47
+
48
+ # Data
49
+ data_root = './data/llava_data/'
50
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
51
+ image_folder = data_root + 'llava_images'
52
+
53
+ glamm_data_root = './data/glamm_data/'
54
+
55
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
56
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
57
+
58
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
59
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
60
+
61
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
62
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
63
+
64
+ psg_image_path = glamm_data_root + 'images/coco2017/'
65
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
66
+
67
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
68
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
69
+
70
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
71
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
72
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
73
+
74
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
75
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
76
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
77
+
78
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
79
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
80
+
81
+ paco_image_path = './data/glamm_data/images/coco2017/'
82
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
83
+
84
+ referring_refcoco_image_path = refcocog_image_path
85
+ referring_refcoco_data_path = "./data/ref_seg/"
86
+
87
+ referring_refcoco_plus_image_path = refcocog_image_path
88
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
89
+
90
+ referring_refcocog_image_path = refcocog_image_path
91
+ referring_refcocog_data_path = "./data/ref_seg/"
92
+
93
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
94
+ referring_refclef_data_path = "./data/ref_seg/"
95
+
96
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
97
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
98
+
99
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
100
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
101
+
102
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
103
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
104
+
105
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
106
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
107
+
108
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
109
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
110
+
111
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
112
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
113
+
114
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
115
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
116
+
117
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
118
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
119
+
120
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
121
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
122
+
123
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
124
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
125
+
126
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
127
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
128
+
129
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
130
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
131
+
132
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
133
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
134
+
135
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
136
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
137
+
138
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
139
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
140
+
141
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
142
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
143
+
144
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
145
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
146
+
147
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
148
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
149
+
150
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
151
+ max_length = int(2048 - (1024 / 64)**2 - 100)
152
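+ # Rough budget behind max_length: (1024 / 64)**2 = 256 tokens, presumably the
+ # visual tokens produced from a 1024px image, plus ~100 tokens reserved for
+ # special/template tokens, leaving about 1692 text tokens of the 2048 window.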
+
153
+ # Scheduler & Optimizer
154
+ batch_size = 8 # per_device
155
+ accumulative_counts = 2
156
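+ # Effective global batch size is batch_size * accumulative_counts * num_gpus;
+ # assuming the 8 GPUs implied by the config name, that is 8 * 2 * 8 = 128.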
+ dataloader_num_workers = 4
157
+ max_epochs = 1
158
+ optim_type = AdamW
159
+ lr = 2e-4
160
+ betas = (0.9, 0.999)
161
+ weight_decay = 0
162
+ max_norm = 1 # grad clip
163
+ warmup_ratio = 0.03
164
+
165
+
166
+ # Save
167
+ save_steps = 2000
168
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
169
+
170
+ # Evaluate generation performance during training
171
+ evaluation_freq = 2000
172
+ SYSTEM = ''
173
+ evaluation_images = './work_dirs/test.jpg'
174
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
175
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
176
+
177
+ #######################################################################
178
+ # PART 2 Model & Tokenizer & Image Processor #
179
+ #######################################################################
180
+ tokenizer = dict(
181
+ type=AutoTokenizer.from_pretrained,
182
+ pretrained_model_name_or_path=llm_name_or_path,
183
+ trust_remote_code=True,
184
+ padding_side='right')
185
+
186
+ image_processor = dict(
187
+ type=CLIPImageProcessor,
188
+ do_resize=True,
189
+ size=1024,
190
+ resample=3,
191
+ do_center_crop=True,
192
+ crop_size=1024,
193
+ do_rescale=True,
194
+ do_normalize=True,
195
+ image_mean=[0.4814, 0.4578, 0.4082],
196
+ image_std=[0.2686, 0.2613, 0.2757],
197
+ do_convert_rgb=True
198
+ )
199
+
200
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
201
+ num_things_classes = 80
202
+ num_stuff_classes = 53
203
+ num_classes = num_things_classes + num_stuff_classes
204
+
205
+ omgseg_model = dict(
206
+ type=OMGSegVisualEncoder,
207
+ data_preprocessor=None,
208
+ pixel_shuffle_down_ratio=2,
209
+ backbone=dict(
210
+ type=OpenCLIPBackbone_omgseg,
211
+ model_name='convnext_large_d_320',
212
+ fix=True,
213
+ init_cfg=dict(
214
+ type='clip_pretrain',
215
+ checkpoint='laion2b_s29b_b131k_ft_soup'
216
+ )
217
+ ),
218
+ panoptic_head=dict(
219
+ type=Mask2FormerVideoSemSamHead,
220
+ sphere_cls=True,
221
+ ov_path=omg_ov_class_embed_path,
222
+ enable_box_query=False,
223
+ ov_classifier_name=class_embed,
224
+ logit=None,
225
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
226
+ strides=[4, 8, 16, 32],
227
+ feat_channels=256,
228
+ out_channels=256,
229
+ num_things_classes=num_things_classes,
230
+ num_stuff_classes=num_stuff_classes,
231
+ num_queries=300,
232
+ num_transformer_feat_level=3,
233
+ pixel_decoder=dict(
234
+ type=MSDeformAttnPixelDecoder,
235
+ num_outs=3,
236
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
237
+ act_cfg=dict(type=ReLU),
238
+ encoder=dict( # DeformableDetrTransformerEncoder
239
+ num_layers=6,
240
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
241
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
242
+ embed_dims=256,
243
+ num_heads=8,
244
+ num_levels=3,
245
+ num_points=4,
246
+ dropout=0.0,
247
+ batch_first=True),
248
+ ffn_cfg=dict(
249
+ embed_dims=256,
250
+ feedforward_channels=1024,
251
+ num_fcs=2,
252
+ ffn_drop=0.0,
253
+ act_cfg=dict(type=ReLU, inplace=True)))),
254
+ positional_encoding=dict(num_feats=128, normalize=True)),
255
+ enforce_decoder_input_project=False,
256
+ positional_encoding=dict(num_feats=128, normalize=True),
257
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
258
+ return_intermediate=True,
259
+ num_layers=9,
260
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
261
+ self_attn_cfg=dict( # MultiheadAttention
262
+ embed_dims=256,
263
+ num_heads=8,
264
+ dropout=0.0,
265
+ batch_first=True),
266
+ cross_attn_cfg=dict( # MultiheadAttention
267
+ embed_dims=256,
268
+ num_heads=8,
269
+ dropout=0.0,
270
+ batch_first=True),
271
+ ffn_cfg=dict(
272
+ embed_dims=256,
273
+ feedforward_channels=2048,
274
+ num_fcs=2,
275
+ ffn_drop=0.0,
276
+ act_cfg=dict(type='ReLU', inplace=True))),
277
+ init_cfg=None),
278
+ loss_cls=dict(
279
+ type=CrossEntropyLoss,
280
+ use_sigmoid=False,
281
+ loss_weight=2.0,
282
+ reduction='mean',
283
+ class_weight=[1.0] * 240 + [0.1]),
284
+ loss_mask=dict(
285
+ type=CrossEntropyLoss,
286
+ use_sigmoid=True,
287
+ reduction='mean',
288
+ loss_weight=5.0),
289
+ loss_dice=dict(
290
+ type=DiceLoss,
291
+ use_sigmoid=True,
292
+ activate=True,
293
+ reduction='mean',
294
+ naive_dice=True,
295
+ eps=1.0,
296
+ loss_weight=5.0),
297
+ loss_iou=dict(
298
+ type=FocalLoss,
299
+ use_sigmoid=True,
300
+ loss_weight=2.0,
301
+ reduction='mean')
302
+ ),
303
+ panoptic_fusion_head=dict(
304
+ type=MaskFormerFusionHead,
305
+ num_things_classes=num_things_classes,
306
+ num_stuff_classes=num_stuff_classes,
307
+ loss_panoptic=None,
308
+ init_cfg=None),
309
+ train_cfg=dict(
310
+ num_points=12544,
311
+ oversample_ratio=3.0,
312
+ importance_sample_ratio=0.75,
313
+ assigner=dict(
314
+ type=HungarianAssigner,
315
+ match_costs=[
316
+ # dict(type=FlexibleClassificationCost, weight=2.0),
317
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
318
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
319
+ ]),
320
+ sampler=dict(type=MaskPseudoSampler)),
321
+ test_cfg=dict(
322
+ panoptic_on=True,
323
+ # For now, the dataset does not support
324
+ # evaluating semantic segmentation metric.
325
+ semantic_on=False,
326
+ instance_on=True,
327
+ # max_per_image is for instance segmentation.
328
+ max_per_image=100,
329
+ iou_thr=0.8,
330
+ # In Mask2Former's panoptic postprocessing,
331
+ # it will filter mask area where score is less than 0.5 .
332
+ filter_low_score=True),
333
+ init_cfg=dict(
334
+ type='Pretrained',
335
+ checkpoint=omg_head_pretrain_pth_path,
336
+ )
337
+ )
338
+
339
+ model = dict(
340
+ type=OMG_LLaVA,
341
+ freeze_llm=True,
342
+ freeze_visual_encoder=True,
343
+ require_omg_decoder=True,
344
+ freeze_llm_with_lora=False,
345
+ pretrained_pth=pretrained_pth,
346
+ text2vision_projector=True,
347
+ pixel_shuffle_ratio=2,
348
+ llm=dict(
349
+ type=AutoModelForCausalLM.from_pretrained,
350
+ pretrained_model_name_or_path=llm_name_or_path,
351
+ trust_remote_code=True,
352
+ torch_dtype=torch.float16,
353
+ quantization_config=dict(
354
+ type=BitsAndBytesConfig,
355
+ load_in_4bit=True,
356
+ load_in_8bit=False,
357
+ llm_int8_threshold=6.0,
358
+ llm_int8_has_fp16_weight=False,
359
+ bnb_4bit_compute_dtype=torch.float16,
360
+ bnb_4bit_use_double_quant=True,
361
+ bnb_4bit_quant_type='nf4')),
362
+ llm_lora=dict(
363
+ type=LoraConfig,
364
+ r=512,
365
+ lora_alpha=256,
366
+ lora_dropout=0.05,
367
+ bias='none',
368
+ task_type='CAUSAL_LM'),
369
+ visual_encoder=omgseg_model,
370
+ tokenizer=tokenizer,
371
+ )
372
+
373
+ #######################################################################
374
+ # PART 3 Dataset & Dataloader #
375
+ #######################################################################
376
+ debug=False
377
+ llava_dataset = dict(
378
+ type=LLaVADataset,
379
+ data_path=data_path,
380
+ image_folder=image_folder,
381
+ tokenizer=tokenizer,
382
+ image_processor=image_processor,
383
+ dataset_map_fn=llava_map_fn,
384
+ template_map_fn=dict(
385
+ type=template_map_fn_factory, template=prompt_template),
386
+ max_length=max_length,
387
+ pad_image_to_square=True)
388
+
389
+ glamm_refcocog_dataset = dict(
390
+ type=RefCOCOgGCGDataset,
391
+ data_path=refcocog_ann_file,
392
+ image_folder=refcocog_image_path,
393
+ tokenizer=tokenizer,
394
+ image_processor=image_processor,
395
+ dataset_map_fn=glamm_refcocog_map_fn,
396
+ template_map_fn=dict(
397
+ type=template_map_fn_factory, template=prompt_template),
398
+ max_length=max_length,
399
+ pad_image_to_square=True,
400
+ debug=False,
401
+ repeats=1,
402
+ )
403
+
404
+ glamm_grandf_dataset = dict(
405
+ type=GranDfGCGDataset,
406
+ data_path=grandf_ann_file,
407
+ image_folder=grandf_image_path,
408
+ tokenizer=tokenizer,
409
+ image_processor=image_processor,
410
+ dataset_map_fn=glamm_granf_map_fn,
411
+ template_map_fn=dict(
412
+ type=template_map_fn_factory, template=prompt_template),
413
+ max_length=max_length,
414
+ pad_image_to_square=True,
415
+ debug=debug,
416
+ repeats=10,
417
+ )
418
+
419
+ glamm_psg_dataset = dict(
420
+ type=OpenPsgGCGDataset,
421
+ data_path=psg_ann_file,
422
+ image_folder=psg_image_path,
423
+ tokenizer=tokenizer,
424
+ image_processor=image_processor,
425
+ dataset_map_fn=glamm_openpsg_map_fn,
426
+ template_map_fn=dict(
427
+ type=template_map_fn_factory, template=prompt_template),
428
+ max_length=max_length,
429
+ pad_image_to_square=True,
430
+ debug=debug,
431
+ repeats=1,
432
+ )
433
+
434
+ glamm_flickr_dataset = dict(
435
+ type=FlickrGCGDataset,
436
+ data_path=flickr_ann_file,
437
+ image_folder=flickr_image_path,
438
+ tokenizer=tokenizer,
439
+ image_processor=image_processor,
440
+ dataset_map_fn=glamm_flickr_map_fn,
441
+ template_map_fn=dict(
442
+ type=template_map_fn_factory, template=prompt_template),
443
+ max_length=max_length,
444
+ pad_image_to_square=True,
445
+ debug=debug,
446
+ repeats=1,
447
+ )
448
+
449
+ semantic_seg_ade20k_dataset = dict(
450
+ type=ADE20kSemanticSegDataset,
451
+ data_path=ade20k_class_file,
452
+ image_folder=ade20k_image_path,
453
+ tokenizer=tokenizer,
454
+ image_processor=image_processor,
455
+ dataset_map_fn=semantic_seg_map_fn,
456
+ template_map_fn=dict(
457
+ type=template_map_fn_factory, template=prompt_template),
458
+ max_length=max_length,
459
+ pad_image_to_square=True,
460
+ debug=False,
461
+ repeats=1,
462
+ )
463
+
464
+ semantic_seg_cocostuff_dataset = dict(
465
+ type=COCOStuffSemanticSegDataset,
466
+ data_path=cocostuff_class_file,
467
+ image_folder=cocostuff_image_path,
468
+ label_path=cocostuff_label_path,
469
+ tokenizer=tokenizer,
470
+ image_processor=image_processor,
471
+ dataset_map_fn=semantic_seg_map_fn,
472
+ template_map_fn=dict(
473
+ type=template_map_fn_factory, template=prompt_template),
474
+ max_length=max_length,
475
+ pad_image_to_square=True,
476
+ debug=False,
477
+ repeats=1,
478
+ )
479
+
480
+ semantic_seg_mapillary_dataset = dict(
481
+ type=MapillarySemanticSegDataset,
482
+ data_path=mapillary_class_file,
483
+ image_folder=mapillary_image_path,
484
+ label_path=mapillary_label_path,
485
+ tokenizer=tokenizer,
486
+ image_processor=image_processor,
487
+ dataset_map_fn=semantic_seg_map_fn,
488
+ template_map_fn=dict(
489
+ type=template_map_fn_factory, template=prompt_template),
490
+ max_length=max_length,
491
+ pad_image_to_square=True,
492
+ debug=False,
493
+ repeats=1,
494
+ )
495
+
496
+ semantic_seg_pascal_part_dataset = dict(
497
+ type=PascalPartSemanticSegDataset,
498
+ data_path=pascal_file,
499
+ image_folder=pascal_part_image_path,
500
+ tokenizer=tokenizer,
501
+ image_processor=image_processor,
502
+ dataset_map_fn=pascal_part_map_fn,
503
+ template_map_fn=dict(
504
+ type=template_map_fn_factory, template=prompt_template),
505
+ max_length=max_length,
506
+ pad_image_to_square=True,
507
+ debug=False,
508
+ repeats=1,
509
+ )
510
+
511
+ semantic_seg_paco_dataset = dict(
512
+ type=PacoSemanticSegDataset,
513
+ data_path=paco_file,
514
+ image_folder=paco_image_path,
515
+ tokenizer=tokenizer,
516
+ image_processor=image_processor,
517
+ dataset_map_fn=pascal_part_map_fn,
518
+ template_map_fn=dict(
519
+ type=template_map_fn_factory, template=prompt_template),
520
+ max_length=max_length,
521
+ pad_image_to_square=True,
522
+ debug=False,
523
+ repeats=1,
524
+ )
525
+
526
+ referring_seg_refcoco_dataset = dict(
527
+ type=RefcocoReferringSegDataset,
528
+ data_path=referring_refcoco_data_path,
529
+ image_folder=referring_refcoco_image_path,
530
+ tokenizer=tokenizer,
531
+ image_processor=image_processor,
532
+ dataset_map_fn=referring_seg_map_fn,
533
+ template_map_fn=dict(
534
+ type=template_map_fn_factory, template=prompt_template),
535
+ max_length=max_length,
536
+ pad_image_to_square=True,
537
+ debug=False,
538
+ repeats=1,
539
+ )
540
+
541
+ referring_seg_refcoco_plus_dataset = dict(
542
+ type=Refcoco_plus_ReferringSegDataset,
543
+ data_path=referring_refcoco_plus_data_path,
544
+ image_folder=referring_refcoco_plus_image_path,
545
+ tokenizer=tokenizer,
546
+ image_processor=image_processor,
547
+ dataset_map_fn=referring_seg_map_fn,
548
+ template_map_fn=dict(
549
+ type=template_map_fn_factory, template=prompt_template),
550
+ max_length=max_length,
551
+ pad_image_to_square=True,
552
+ debug=False,
553
+ repeats=1,
554
+ )
555
+
556
+ referring_seg_refcocog_dataset = dict(
557
+ type=Refcocog_ReferringSegDataset,
558
+ data_path=referring_refcocog_data_path,
559
+ image_folder=referring_refcocog_image_path,
560
+ tokenizer=tokenizer,
561
+ image_processor=image_processor,
562
+ dataset_map_fn=referring_seg_map_fn,
563
+ template_map_fn=dict(
564
+ type=template_map_fn_factory, template=prompt_template),
565
+ max_length=max_length,
566
+ pad_image_to_square=True,
567
+ debug=False,
568
+ repeats=1,
569
+ )
570
+
571
+ referring_seg_refclef_dataset = dict(
572
+ type=Refclef_ReferringSegDataset,
573
+ data_path=referring_refclef_data_path,
574
+ image_folder=referring_refclef_image_path,
575
+ tokenizer=tokenizer,
576
+ image_processor=image_processor,
577
+ dataset_map_fn=referring_seg_map_fn,
578
+ template_map_fn=dict(
579
+ type=template_map_fn_factory, template=prompt_template),
580
+ max_length=max_length,
581
+ pad_image_to_square=True,
582
+ debug=False,
583
+ repeats=1,
584
+ )
585
+
586
+ region_cap_osprey_dataset = dict(
587
+ type=OspreyRegionCaptionDataset,
588
+ data_path=region_cap_osprey_data_path,
589
+ image_folder=region_cap_osprey_image_path,
590
+ tokenizer=tokenizer,
591
+ image_processor=image_processor,
592
+ dataset_map_fn=osprey_region_caption_map_fn,
593
+ template_map_fn=dict(
594
+ type=template_map_fn_factory, template=prompt_template),
595
+ max_length=max_length,
596
+ pad_image_to_square=True,
597
+ debug=False,
598
+ repeats=1,
599
+ )
600
+
601
+ region_conversation_osprey_dataset = dict(
602
+ type=OspreyRegionConversationDataset,
603
+ data_path=region_conversation_osprey_data_path,
604
+ image_folder=region_conversation_osprey_image_path,
605
+ tokenizer=tokenizer,
606
+ image_processor=image_processor,
607
+ dataset_map_fn=osprey_region_conversation_map_fn,
608
+ template_map_fn=dict(
609
+ type=template_map_fn_factory, template=prompt_template),
610
+ max_length=max_length,
611
+ pad_image_to_square=True,
612
+ debug=False,
613
+ repeats=1,
614
+ )
615
+
616
+ mdpv_detailed_description_ade20k_dataset = dict(
617
+ type=MDPVPointDetailedCaptionDataset,
618
+ data_path=mdpv_detailed_caption_ade20k_data_path,
619
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
620
+ tokenizer=tokenizer,
621
+ image_processor=image_processor,
622
+ dataset_map_fn=mdpv_points_map_fn,
623
+ template_map_fn=dict(
624
+ type=template_map_fn_factory, template=prompt_template),
625
+ max_length=max_length,
626
+ pad_image_to_square=True,
627
+ debug=False,
628
+ repeats=1,
629
+ )
630
+
631
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
632
+ type=MDPVPointDetailedCaptionDataset,
633
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
634
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
635
+ tokenizer=tokenizer,
636
+ image_processor=image_processor,
637
+ dataset_map_fn=mdpv_points_map_fn,
638
+ template_map_fn=dict(
639
+ type=template_map_fn_factory, template=prompt_template),
640
+ max_length=max_length,
641
+ pad_image_to_square=True,
642
+ debug=False,
643
+ repeats=1,
644
+ )
645
+
646
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
647
+ type=MDPVPointDetailedCaptionDataset,
648
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
649
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
650
+ tokenizer=tokenizer,
651
+ image_processor=image_processor,
652
+ dataset_map_fn=mdpv_points_map_fn,
653
+ template_map_fn=dict(
654
+ type=template_map_fn_factory, template=prompt_template),
655
+ max_length=max_length,
656
+ pad_image_to_square=True,
657
+ debug=False,
658
+ repeats=1,
659
+ )
660
+
661
+ mdpv_detailed_description_vg_dataset = dict(
662
+ type=MDPVPointDetailedCaptionDataset,
663
+ data_path=mdpv_detailed_caption_vg_data_path,
664
+ image_folder=mdpv_detailed_caption_vg_image_path,
665
+ tokenizer=tokenizer,
666
+ image_processor=image_processor,
667
+ dataset_map_fn=mdpv_points_map_fn,
668
+ template_map_fn=dict(
669
+ type=template_map_fn_factory, template=prompt_template),
670
+ max_length=max_length,
671
+ pad_image_to_square=True,
672
+ debug=False,
673
+ repeats=1,
674
+ )
675
+
676
+ mdpv_brief_description_vg_dataset = dict(
677
+ type=MDPVPointBriefCaptionDataset,
678
+ data_path=mdpv_brief_caption_vg_data_path,
679
+ image_folder=mdpv_brief_caption_vg_image_path,
680
+ tokenizer=tokenizer,
681
+ image_processor=image_processor,
682
+ dataset_map_fn=mdpv_points_map_fn,
683
+ template_map_fn=dict(
684
+ type=template_map_fn_factory, template=prompt_template),
685
+ max_length=max_length,
686
+ pad_image_to_square=True,
687
+ debug=False,
688
+ repeats=1,
689
+ )
690
+
691
+ mdpv_brief_description_cocostuff10k_dataset = dict(
692
+ type=MDPVPointBriefCaptionDataset,
693
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
694
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
695
+ tokenizer=tokenizer,
696
+ image_processor=image_processor,
697
+ dataset_map_fn=mdpv_points_map_fn,
698
+ template_map_fn=dict(
699
+ type=template_map_fn_factory, template=prompt_template),
700
+ max_length=max_length,
701
+ pad_image_to_square=True,
702
+ debug=False,
703
+ repeats=1,
704
+ )
705
+
706
+ mdpv_brief_description_cocostuff164k_dataset = dict(
707
+ type=MDPVPointBriefCaptionDataset,
708
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
709
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
710
+ tokenizer=tokenizer,
711
+ image_processor=image_processor,
712
+ dataset_map_fn=mdpv_points_map_fn,
713
+ template_map_fn=dict(
714
+ type=template_map_fn_factory, template=prompt_template),
715
+ max_length=max_length,
716
+ pad_image_to_square=True,
717
+ debug=False,
718
+ repeats=1,
719
+ )
720
+
721
+ mdpv_brief_description_ade20k_dataset = dict(
722
+ type=MDPVPointBriefCaptionDataset,
723
+ data_path=mdpv_brief_caption_ade20k_data_path,
724
+ image_folder=mdpv_brief_caption_ade20k_image_path,
725
+ tokenizer=tokenizer,
726
+ image_processor=image_processor,
727
+ dataset_map_fn=mdpv_points_map_fn,
728
+ template_map_fn=dict(
729
+ type=template_map_fn_factory, template=prompt_template),
730
+ max_length=max_length,
731
+ pad_image_to_square=True,
732
+ debug=False,
733
+ repeats=1,
734
+ )
735
+
736
+ mdpv_brief_description_lvis_dataset = dict(
737
+ type=MDPVPointBriefCaptionDataset,
738
+ data_path=mdpv_brief_caption_lvis_data_path,
739
+ image_folder=mdpv_brief_caption_lvis_image_path,
740
+ tokenizer=tokenizer,
741
+ image_processor=image_processor,
742
+ dataset_map_fn=mdpv_points_map_fn,
743
+ template_map_fn=dict(
744
+ type=template_map_fn_factory, template=prompt_template),
745
+ max_length=max_length,
746
+ pad_image_to_square=True,
747
+ debug=False,
748
+ repeats=1,
749
+ )
750
+
751
+ mdpv_qa_vg_dataset = dict(
752
+ type=MDPVPointBriefCaptionDataset,
753
+ data_path=mdpv_qa_vg_data_path,
754
+ image_folder=mdpv_qa_vg_image_path,
755
+ tokenizer=tokenizer,
756
+ image_processor=image_processor,
757
+ dataset_map_fn=mdpv_points_map_fn,
758
+ template_map_fn=dict(
759
+ type=template_map_fn_factory, template=prompt_template),
760
+ max_length=max_length,
761
+ pad_image_to_square=True,
762
+ debug=False,
763
+ repeats=1,
764
+ )
765
+
766
+ mdpv_qa_ade20k_dataset = dict(
767
+ type=MDPVPointBriefCaptionDataset,
768
+ data_path=mdpv_qa_ade20k_data_path,
769
+ image_folder=mdpv_qa_ade20k_image_path,
770
+ tokenizer=tokenizer,
771
+ image_processor=image_processor,
772
+ dataset_map_fn=mdpv_points_map_fn,
773
+ template_map_fn=dict(
774
+ type=template_map_fn_factory, template=prompt_template),
775
+ max_length=max_length,
776
+ pad_image_to_square=True,
777
+ debug=False,
778
+ repeats=1,
779
+ )
780
+
781
+ mdpv_qa_lvis_dataset = dict(
782
+ type=MDPVPointBriefCaptionDataset,
783
+ data_path=mdpv_qa_lvis_data_path,
784
+ image_folder=mdpv_qa_lvis_image_path,
785
+ tokenizer=tokenizer,
786
+ image_processor=image_processor,
787
+ dataset_map_fn=mdpv_points_map_fn,
788
+ template_map_fn=dict(
789
+ type=template_map_fn_factory, template=prompt_template),
790
+ max_length=max_length,
791
+ pad_image_to_square=True,
792
+ debug=False,
793
+ repeats=1,
794
+ )
795
+
796
+ mdpv_qa_cocostuff10k_dataset = dict(
797
+ type=MDPVPointBriefCaptionDataset,
798
+ data_path=mdpv_qa_cocostuff10k_data_path,
799
+ image_folder=mdpv_qa_cocostuff10k_image_path,
800
+ tokenizer=tokenizer,
801
+ image_processor=image_processor,
802
+ dataset_map_fn=mdpv_points_map_fn,
803
+ template_map_fn=dict(
804
+ type=template_map_fn_factory, template=prompt_template),
805
+ max_length=max_length,
806
+ pad_image_to_square=True,
807
+ debug=False,
808
+ repeats=1,
809
+ )
810
+
811
+ mdpv_qa_cocostuff164k_dataset = dict(
812
+ type=MDPVPointBriefCaptionDataset,
813
+ data_path=mdpv_qa_cocostuff164k_data_path,
814
+ image_folder=mdpv_qa_cocostuff164k_image_path,
815
+ tokenizer=tokenizer,
816
+ image_processor=image_processor,
817
+ dataset_map_fn=mdpv_points_map_fn,
818
+ template_map_fn=dict(
819
+ type=template_map_fn_factory, template=prompt_template),
820
+ max_length=max_length,
821
+ pad_image_to_square=True,
822
+ debug=False,
823
+ repeats=1,
824
+ )
825
+
826
+ mdpv_multi_points_openpsg_dataset = dict(
827
+ type=MDPVPointBriefCaptionDataset,
828
+ data_path=mdpv_multi_points_openpsg_data_path,
829
+ image_folder=mdpv_multi_points_openpsg_image_path,
830
+ tokenizer=tokenizer,
831
+ image_processor=image_processor,
832
+ dataset_map_fn=mdpv_points_map_fn,
833
+ template_map_fn=dict(
834
+ type=template_map_fn_factory, template=prompt_template),
835
+ max_length=max_length,
836
+ pad_image_to_square=True,
837
+ debug=False,
838
+ repeats=1,
839
+ )
840
+
841
+ mdpv_multi_points_flicker30k_dataset = dict(
842
+ type=MDPVPointBriefCaptionDataset,
843
+ data_path=mdpv_multi_points_flicker30k_data_path,
844
+ image_folder=mdpv_multi_points_flicker30k_image_path,
845
+ tokenizer=tokenizer,
846
+ image_processor=image_processor,
847
+ dataset_map_fn=mdpv_points_map_fn,
848
+ template_map_fn=dict(
849
+ type=template_map_fn_factory, template=prompt_template),
850
+ max_length=max_length,
851
+ pad_image_to_square=True,
852
+ debug=False,
853
+ repeats=1,
854
+ )
855
+
856
+ train_dataset = dict(
857
+ type=CombineDataset,
858
+ datasets_cfgs=[llava_dataset, glamm_flickr_dataset, glamm_refcocog_dataset,
859
+ glamm_grandf_dataset, glamm_psg_dataset,
860
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset, # repeat 3x
861
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
862
+ semantic_seg_ade20k_dataset, semantic_seg_cocostuff_dataset,
863
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
864
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
865
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
866
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
867
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
868
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
869
+ region_cap_osprey_dataset, region_conversation_osprey_dataset,
870
+ mdpv_detailed_description_ade20k_dataset,
871
+ mdpv_detailed_description_cocostuff_10k_dataset,
872
+ mdpv_detailed_description_cocostuff_164k_dataset,
873
+ mdpv_detailed_description_vg_dataset,
874
+ mdpv_brief_description_lvis_dataset,
875
+ mdpv_brief_description_vg_dataset,
876
+ mdpv_brief_description_ade20k_dataset,
877
+ mdpv_brief_description_cocostuff10k_dataset,
878
+ mdpv_brief_description_cocostuff164k_dataset,
879
+ mdpv_qa_vg_dataset,
880
+ mdpv_qa_lvis_dataset,
881
+ mdpv_qa_ade20k_dataset,
882
+ mdpv_qa_cocostuff10k_dataset,
883
+ mdpv_qa_cocostuff164k_dataset,
884
+ mdpv_multi_points_flicker30k_dataset,
885
+ mdpv_multi_points_openpsg_dataset,],
886
+ )
887
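+ # CombineDataset simply concatenates all dataset configs listed above, so the
+ # semantic-segmentation and referring-segmentation entries that appear three
+ # times (see the "repeat 3x" comments) are oversampled roughly 3x per epoch.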
+
888
+ train_dataloader = dict(
889
+ batch_size=batch_size,
890
+ num_workers=dataloader_num_workers,
891
+ dataset=train_dataset,
892
+ sampler=dict(
893
+ type=LengthGroupedSampler,
894
+ length_property='modality_length',
895
+ per_device_batch_size=batch_size * accumulative_counts),
896
+ collate_fn=dict(type=omg_llava_collate_fn))
897
+
898
+ #######################################################################
899
+ # PART 4 Scheduler & Optimizer #
900
+ #######################################################################
901
+ # optimizer
902
+ optim_wrapper = dict(
903
+ type=AmpOptimWrapper,
904
+ optimizer=dict(
905
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
906
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
907
+ accumulative_counts=accumulative_counts,
908
+ loss_scale='dynamic',
909
+ dtype='float16')
910
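+ # AmpOptimWrapper runs forward/backward in float16 with dynamic loss scaling,
+ # clips gradients to max_norm=1, and applies the optimizer step once every
+ # `accumulative_counts` iterations.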
+
911
+ # learning policy
912
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
913
+ param_scheduler = [
914
+ dict(
915
+ type=LinearLR,
916
+ start_factor=1e-5,
917
+ by_epoch=True,
918
+ begin=0,
919
+ end=warmup_ratio * max_epochs,
920
+ convert_to_iter_based=True),
921
+ dict(
922
+ type=CosineAnnealingLR,
923
+ eta_min=0.0,
924
+ by_epoch=True,
925
+ begin=warmup_ratio * max_epochs,
926
+ end=max_epochs,
927
+ convert_to_iter_based=True)
928
+ ]
929
+
930
+ # train, val, test setting
931
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
932
+
933
+ #######################################################################
934
+ # PART 5 Runtime #
935
+ #######################################################################
936
+ # Log the dialogue periodically during the training process, optional
937
+ custom_hooks = [
938
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
939
+ dict(
940
+ type=EvaluateChatHook_withSpecialTokens,
941
+ tokenizer=tokenizer,
942
+ image_processor=image_processor,
943
+ every_n_iters=evaluation_freq,
944
+ evaluation_inputs=evaluation_inputs,
945
+ evaluation_images=evaluation_images,
946
+ system=SYSTEM,
947
+ prompt_template=prompt_template)
948
+ ]
949
+
950
+ # configure default hooks
951
+ default_hooks = dict(
952
+ # record the time of every iteration.
953
+ timer=dict(type=IterTimerHook),
954
+ # print log every 10 iterations.
955
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
956
+ # enable the parameter scheduler.
957
+ param_scheduler=dict(type=ParamSchedulerHook),
958
+ # save checkpoint per `save_steps`.
959
+ checkpoint=dict(
960
+ type=CheckpointHook,
961
+ by_epoch=False,
962
+ interval=save_steps,
963
+ max_keep_ckpts=save_total_limit),
964
+ # set sampler seed in distributed evrionment.
965
+ sampler_seed=dict(type=DistSamplerSeedHook),
966
+ )
967
+
968
+ # configure environment
969
+ env_cfg = dict(
970
+ # whether to enable cudnn benchmark
971
+ cudnn_benchmark=False,
972
+ # set multi process parameters
973
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
974
+ # set distributed parameters
975
+ dist_cfg=dict(backend='nccl'),
976
+ )
977
+
978
+ # set visualizer
979
+ visualizer = None
980
+
981
+ # set log level
982
+ log_level = 'INFO'
983
+
984
+ # load from which checkpoint
985
+ load_from = None
986
+
987
+ # whether to resume training from the loaded checkpoint
988
+ resume = False
989
+
990
+ # Defaults to use random seed and disable `deterministic`
991
+ randomness = dict(seed=None, deterministic=False)
992
+
993
+ # set log processor
994
+ log_processor = dict(by_epoch=False)
omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py ADDED
@@ -0,0 +1,925 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth'
47
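+ # Unlike the stage-1/stage-2 configs above, this specific-task finetune appears
+ # to start from an already finetuned OMG-LLaVA checkpoint rather than a
+ # stage-1 iteration checkpoint.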
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate generation performance during training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metric.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter mask area where score is less than 0.5 .
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=True,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[glamm_flickr_dataset, glamm_refcocog_dataset,
816
+ glamm_grandf_dataset, glamm_psg_dataset,],
817
+ )
818
+
819
+ train_dataloader = dict(
820
+ batch_size=batch_size,
821
+ num_workers=dataloader_num_workers,
822
+ dataset=train_dataset,
823
+ sampler=dict(
824
+ type=LengthGroupedSampler,
825
+ length_property='modality_length',
826
+ per_device_batch_size=batch_size * accumulative_counts),
827
+ collate_fn=dict(type=omg_llava_collate_fn))
828
+
829
+ #######################################################################
830
+ # PART 4 Scheduler & Optimizer #
831
+ #######################################################################
832
+ # optimizer
833
+ optim_wrapper = dict(
834
+ type=AmpOptimWrapper,
835
+ optimizer=dict(
836
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
837
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
838
+ accumulative_counts=accumulative_counts,
839
+ loss_scale='dynamic',
840
+ dtype='float16')
841
+
842
+ # learning policy
843
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
844
+ param_scheduler = [
845
+ dict(
846
+ type=LinearLR,
847
+ start_factor=1e-5,
848
+ by_epoch=True,
849
+ begin=0,
850
+ end=warmup_ratio * max_epochs,
851
+ convert_to_iter_based=True),
852
+ dict(
853
+ type=CosineAnnealingLR,
854
+ eta_min=0.0,
855
+ by_epoch=True,
856
+ begin=warmup_ratio * max_epochs,
857
+ end=max_epochs,
858
+ convert_to_iter_based=True)
859
+ ]
860
+
861
+ # train, val, test setting
862
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
863
+
864
+ #######################################################################
865
+ # PART 5 Runtime #
866
+ #######################################################################
867
+ # Log the dialogue periodically during the training process, optional
868
+ custom_hooks = [
869
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
870
+ dict(
871
+ type=EvaluateChatHook_withSpecialTokens,
872
+ tokenizer=tokenizer,
873
+ image_processor=image_processor,
874
+ every_n_iters=evaluation_freq,
875
+ evaluation_inputs=evaluation_inputs,
876
+ evaluation_images=evaluation_images,
877
+ system=SYSTEM,
878
+ prompt_template=prompt_template)
879
+ ]
880
+
881
+ # configure default hooks
882
+ default_hooks = dict(
883
+ # record the time of every iteration.
884
+ timer=dict(type=IterTimerHook),
885
+ # print log every 10 iterations.
886
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
887
+ # enable the parameter scheduler.
888
+ param_scheduler=dict(type=ParamSchedulerHook),
889
+ # save a checkpoint every `save_steps` iterations.
890
+ checkpoint=dict(
891
+ type=CheckpointHook,
892
+ by_epoch=False,
893
+ interval=save_steps,
894
+ max_keep_ckpts=save_total_limit),
895
+ # set sampler seed in distributed environment.
896
+ sampler_seed=dict(type=DistSamplerSeedHook),
897
+ )
898
+
899
+ # configure environment
900
+ env_cfg = dict(
901
+ # whether to enable cudnn benchmark
902
+ cudnn_benchmark=False,
903
+ # set multi process parameters
904
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
905
+ # set distributed parameters
906
+ dist_cfg=dict(backend='nccl'),
907
+ )
908
+
909
+ # set visualizer
910
+ visualizer = None
911
+
912
+ # set log level
913
+ log_level = 'INFO'
914
+
915
+ # load from which checkpoint
916
+ load_from = None
917
+
918
+ # whether to resume training from the loaded checkpoint
919
+ resume = False
920
+
921
+ # Defaults to use random seed and disable `deterministic`
922
+ randomness = dict(seed=None, deterministic=False)
923
+
924
+ # set log processor
925
+ log_processor = dict(by_epoch=False)
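
For readers adapting the GCG finetune config that ends here: it is an executable mmengine config, so it can be loaded and inspected in Python before a run (training itself is typically launched with the xtuner CLI, e.g. `NPROC_PER_NODE=8 xtuner train <config> --deepspeed deepspeed_zero2`). The snippet below is only an illustrative sketch: the file path is hypothetical, and it assumes omg_llava, xtuner and mmdet are installed so the config's module-level imports resolve.

from mmengine.config import Config

# Hypothetical path; point this at wherever the config file is saved.
cfg = Config.fromfile(
    'omg_llava/configs/finetune/specific_tasks_finetune/finetune_gcg.py')

# Quantities worth sanity-checking before launching: the text-context budget
# left after reserving image tokens, the per-GPU batch the optimizer
# effectively sees with gradient accumulation, and the number of GCG datasets
# concatenated by CombineDataset.
print('max_length:', cfg.max_length)
print('effective per-GPU batch:', cfg.batch_size * cfg.accumulative_counts)
print('combined GCG datasets:', len(cfg.train_dataset['datasets_cfgs']))
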
omg_llava/configs/finetune/specific_tasks_finetune/finetune_refseg.py ADDED
@@ -0,0 +1,929 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset,\
16
+ CombineDataset, glamm_refcocog_map_fn, glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn,\
17
+ ADE20kSemanticSegDataset, COCOStuffSemanticSegDataset, semantic_seg_map_fn, MapillarySemanticSegDataset,\
18
+ PascalPartSemanticSegDataset, pascal_part_map_fn, PacoSemanticSegDataset,\
19
+ RefcocoReferringSegDataset, referring_seg_map_fn, Refcoco_plus_ReferringSegDataset,\
20
+ Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset,\
21
+ OspreyRegionCaptionDataset, osprey_region_caption_map_fn,\
22
+ OspreyRegionConversationDataset, osprey_region_conversation_map_fn,\
23
+ MDPVPointDetailedCaptionDataset, mdpv_points_map_fn, MDPVPointBriefCaptionDataset,\
24
+ semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
25
+ referring_seg_gcg_format_map_fn, osprey_region_caption_gcg_format_map_fn
26
+ from xtuner.dataset.samplers import LengthGroupedSampler
27
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
28
+ from xtuner.engine.runner import TrainLoop
29
+ from omg_llava.model import OMG_LLaVA
30
+ from xtuner.utils import PROMPT_TEMPLATE
31
+ from omg_llava.model import OpenCLIPBackbone_omgseg
32
+ from omg_llava.model import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
33
+
34
+ from torch.nn import GroupNorm, ReLU
35
+
36
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
37
+ DiceLoss, MaskFormerFusionHead, FocalLoss
38
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
39
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
40
+
41
+ #######################################################################
42
+ # PART 1 Settings #
43
+ #######################################################################
44
+ # Model
45
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
46
+ pretrained_pth = './work_dirs/omg_llava_1024x_2stage_finetune_1_clear_reratio_rmqcache_uniformSegFormat_8gpus.pth'
47
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
48
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
49
+
50
+ # Data
51
+ data_root = './data/llava_data/'
52
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
53
+ image_folder = data_root + 'llava_images'
54
+
55
+ glamm_data_root = './data/glamm_data/'
56
+
57
+ refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
58
+ refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'
59
+
60
+ grandf_image_path = glamm_data_root + 'images/grandf/train/'
61
+ grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'
62
+
63
+ flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
64
+ flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'
65
+
66
+ psg_image_path = glamm_data_root + 'images/coco2017/'
67
+ psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'
68
+
69
+ ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
70
+ ade20k_class_file = './omg_llava/dataset/utils/ade20k_classes.json'
71
+
72
+ cocostuff_image_path = './data/glamm_data/images/coco2017/train2017/'
73
+ cocostuff_class_file = './omg_llava/dataset/utils/cocostuff_classes.txt'
74
+ cocostuff_label_path = './data/semantic_seg/coco_stuff/stuffthingmaps_trainval2017/train2017/'
75
+
76
+ mapillary_image_path = './data/semantic_seg/mapillary/training/images/'
77
+ mapillary_class_file = './data/semantic_seg/mapillary/config_v2.0.json'
78
+ mapillary_label_path = './data/semantic_seg/mapillary/training/v2.0/labels/'
79
+
80
+ pascal_part_image_path = './data/semantic_seg/pascal_part/VOCdevkit/VOC2010/JPEGImages/'
81
+ pascal_file = './data/semantic_seg/pascal_part/train.json'
82
+
83
+ paco_image_path = './data/glamm_data/images/coco2017/'
84
+ paco_file = './data/semantic_seg/paco_lvis/paco_lvis_v1_train.json'
85
+
86
+ referring_refcoco_image_path = refcocog_image_path
87
+ referring_refcoco_data_path = "./data/ref_seg/"
88
+
89
+ referring_refcoco_plus_image_path = refcocog_image_path
90
+ referring_refcoco_plus_data_path = "./data/ref_seg/"
91
+
92
+ referring_refcocog_image_path = refcocog_image_path
93
+ referring_refcocog_data_path = "./data/ref_seg/"
94
+
95
+ referring_refclef_image_path = "./data/ref_seg/saiapr_tc-12/"
96
+ referring_refclef_data_path = "./data/ref_seg/"
97
+
98
+ region_cap_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
99
+ region_cap_osprey_data_path = "./data/region_caption/osprey/osprey_detail_description.json"
100
+
101
+ region_conversation_osprey_image_path = glamm_data_root + 'images/coco2014/train2014/'
102
+ region_conversation_osprey_data_path = "./data/region_caption/osprey/osprey_conversation.json"
103
+
104
+ mdpv_detailed_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
105
+ mdpv_detailed_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_detailed_caption_point.json'
106
+
107
+ mdpv_detailed_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
108
+ mdpv_detailed_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_detailed_caption_point.json'
109
+
110
+ mdpv_detailed_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
111
+ mdpv_detailed_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_detailed_caption_point.json'
112
+
113
+ mdpv_detailed_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
114
+ mdpv_detailed_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_detailed_caption_point.json'
115
+
116
+ mdpv_brief_caption_cocostuff_10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
117
+ mdpv_brief_caption_cocostuff_10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_brief_caption_point.json'
118
+
119
+ mdpv_brief_caption_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
120
+ mdpv_brief_caption_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_brief_caption_point.json'
121
+
122
+ mdpv_brief_caption_cocostuff_164k_image_path = './data/glamm_data/images/coco2017/train2017'
123
+ mdpv_brief_caption_cocostuff_164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_brief_caption_point.json'
124
+
125
+ mdpv_brief_caption_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
126
+ mdpv_brief_caption_vg_data_path = './data/mdpv_point/gpt4v_vg_brief_caption_point.json'
127
+
128
+ mdpv_brief_caption_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
129
+ mdpv_brief_caption_lvis_data_path = './data/mdpv_point/gpt4v_lvis_brief_caption_point.json'
130
+
131
+ mdpv_qa_vg_image_path = './data/llava_data/llava_images/vg/VG_100K'
132
+ mdpv_qa_vg_data_path = './data/mdpv_point/gpt4v_vg_QA_point.json'
133
+
134
+ mdpv_qa_ade20k_image_path = './data/semantic_seg/ADEChallengeData2016/images/training/'
135
+ mdpv_qa_ade20k_data_path = './data/mdpv_point/gpt4v_ade20k_QA_point.json'
136
+
137
+ mdpv_qa_cocostuff164k_image_path = './data/glamm_data/images/coco2017/train2017'
138
+ mdpv_qa_cocostuff164k_data_path = './data/mdpv_point/gpt4v_cocostuff_164k_QA_point.json'
139
+
140
+ mdpv_qa_lvis_image_path = './data/glamm_data/images/coco2017/train2017'
141
+ mdpv_qa_lvis_data_path = './data/mdpv_point/gpt4v_lvis_QA_point.json'
142
+
143
+ mdpv_qa_cocostuff10k_image_path = glamm_data_root + 'images/coco2014/train2014/'
144
+ mdpv_qa_cocostuff10k_data_path = './data/mdpv_point/gpt4v_cocostuff_10k_QA_point.json'
145
+
146
+ mdpv_multi_points_flicker30k_image_path = './data/glamm_data/images/flickr30k/Flickr30K/'
147
+ mdpv_multi_points_flicker30k_data_path = './data/mdpv_point/Flicker30K_multi_points_to_caption.json'
148
+
149
+ mdpv_multi_points_openpsg_image_path = glamm_data_root + 'images/coco2017/train2017'
150
+ mdpv_multi_points_openpsg_data_path = './data/mdpv_point/OpenPsgGCG_train_multi_points_to_caption.json'
151
+
152
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
153
+ max_length = int(2048 - (1024 / 64)**2 - 100)  # 2048 context minus (1024/64)^2 = 256 (the image token budget) and a 100-token margin
154
+
155
+ # Scheduler & Optimizer
156
+ batch_size = 8 # per_device
157
+ accumulative_counts = 2
158
+ dataloader_num_workers = 4
159
+ max_epochs = 1
160
+ optim_type = AdamW
161
+ lr = 2e-4
162
+ betas = (0.9, 0.999)
163
+ weight_decay = 0
164
+ max_norm = 1 # grad clip
165
+ warmup_ratio = 0.03
166
+
167
+
168
+ # Save
169
+ save_steps = 2000
170
+ save_total_limit = 4 # Maximum checkpoints to keep (-1 means unlimited)
171
+
172
+ # Evaluate the generation performance during the training
173
+ evaluation_freq = 2000
174
+ SYSTEM = ''
175
+ evaluation_images = './work_dirs/test.jpg'
176
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture',
177
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.']
178
+
179
+ #######################################################################
180
+ # PART 2 Model & Tokenizer & Image Processor #
181
+ #######################################################################
182
+ tokenizer = dict(
183
+ type=AutoTokenizer.from_pretrained,
184
+ pretrained_model_name_or_path=llm_name_or_path,
185
+ trust_remote_code=True,
186
+ padding_side='right')
187
+
188
+ image_processor = dict(
189
+ type=CLIPImageProcessor,
190
+ do_resize=True,
191
+ size=1024,
192
+ resample=3,
193
+ do_center_crop=True,
194
+ crop_size=1024,
195
+ do_rescale=True,
196
+ do_normalize=True,
197
+ image_mean=[0.4814, 0.4578, 0.4082],
198
+ image_std=[0.2686, 0.2613, 0.2757],
199
+ do_convert_rgb=True
200
+ )
201
+
202
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
203
+ num_things_classes = 80
204
+ num_stuff_classes = 53
205
+ num_classes = num_things_classes + num_stuff_classes
206
+
207
+ omgseg_model = dict(
208
+ type=OMGSegVisualEncoder,
209
+ data_preprocessor=None,
210
+ pixel_shuffle_down_ratio=2,
211
+ backbone=dict(
212
+ type=OpenCLIPBackbone_omgseg,
213
+ model_name='convnext_large_d_320',
214
+ fix=True,
215
+ init_cfg=dict(
216
+ type='clip_pretrain',
217
+ checkpoint='laion2b_s29b_b131k_ft_soup'
218
+ )
219
+ ),
220
+ panoptic_head=dict(
221
+ type=Mask2FormerVideoSemSamHead,
222
+ sphere_cls=True,
223
+ ov_path=omg_ov_class_embed_path,
224
+ enable_box_query=False,
225
+ ov_classifier_name=class_embed,
226
+ logit=None,
227
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
228
+ strides=[4, 8, 16, 32],
229
+ feat_channels=256,
230
+ out_channels=256,
231
+ num_things_classes=num_things_classes,
232
+ num_stuff_classes=num_stuff_classes,
233
+ num_queries=300,
234
+ num_transformer_feat_level=3,
235
+ pixel_decoder=dict(
236
+ type=MSDeformAttnPixelDecoder,
237
+ num_outs=3,
238
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
239
+ act_cfg=dict(type=ReLU),
240
+ encoder=dict( # DeformableDetrTransformerEncoder
241
+ num_layers=6,
242
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
243
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
244
+ embed_dims=256,
245
+ num_heads=8,
246
+ num_levels=3,
247
+ num_points=4,
248
+ dropout=0.0,
249
+ batch_first=True),
250
+ ffn_cfg=dict(
251
+ embed_dims=256,
252
+ feedforward_channels=1024,
253
+ num_fcs=2,
254
+ ffn_drop=0.0,
255
+ act_cfg=dict(type=ReLU, inplace=True)))),
256
+ positional_encoding=dict(num_feats=128, normalize=True)),
257
+ enforce_decoder_input_project=False,
258
+ positional_encoding=dict(num_feats=128, normalize=True),
259
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
260
+ return_intermediate=True,
261
+ num_layers=9,
262
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
263
+ self_attn_cfg=dict( # MultiheadAttention
264
+ embed_dims=256,
265
+ num_heads=8,
266
+ dropout=0.0,
267
+ batch_first=True),
268
+ cross_attn_cfg=dict( # MultiheadAttention
269
+ embed_dims=256,
270
+ num_heads=8,
271
+ dropout=0.0,
272
+ batch_first=True),
273
+ ffn_cfg=dict(
274
+ embed_dims=256,
275
+ feedforward_channels=2048,
276
+ num_fcs=2,
277
+ ffn_drop=0.0,
278
+ act_cfg=dict(type='ReLU', inplace=True))),
279
+ init_cfg=None),
280
+ loss_cls=dict(
281
+ type=CrossEntropyLoss,
282
+ use_sigmoid=False,
283
+ loss_weight=2.0,
284
+ reduction='mean',
285
+ class_weight=[1.0] * 240 + [0.1]),
286
+ loss_mask=dict(
287
+ type=CrossEntropyLoss,
288
+ use_sigmoid=True,
289
+ reduction='mean',
290
+ loss_weight=5.0),
291
+ loss_dice=dict(
292
+ type=DiceLoss,
293
+ use_sigmoid=True,
294
+ activate=True,
295
+ reduction='mean',
296
+ naive_dice=True,
297
+ eps=1.0,
298
+ loss_weight=5.0),
299
+ loss_iou=dict(
300
+ type=FocalLoss,
301
+ use_sigmoid=True,
302
+ loss_weight=2.0,
303
+ reduction='mean')
304
+ ),
305
+ panoptic_fusion_head=dict(
306
+ type=MaskFormerFusionHead,
307
+ num_things_classes=num_things_classes,
308
+ num_stuff_classes=num_stuff_classes,
309
+ loss_panoptic=None,
310
+ init_cfg=None),
311
+ train_cfg=dict(
312
+ num_points=12544,
313
+ oversample_ratio=3.0,
314
+ importance_sample_ratio=0.75,
315
+ assigner=dict(
316
+ type=HungarianAssigner,
317
+ match_costs=[
318
+ # dict(type=FlexibleClassificationCost, weight=2.0),
319
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
320
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
321
+ ]),
322
+ sampler=dict(type=MaskPseudoSampler)),
323
+ test_cfg=dict(
324
+ panoptic_on=True,
325
+ # For now, the dataset does not support
326
+ # evaluating semantic segmentation metrics.
327
+ semantic_on=False,
328
+ instance_on=True,
329
+ # max_per_image is for instance segmentation.
330
+ max_per_image=100,
331
+ iou_thr=0.8,
332
+ # In Mask2Former's panoptic postprocessing,
333
+ # it will filter mask area where score is less than 0.5 .
334
+ filter_low_score=True),
335
+ init_cfg=dict(
336
+ type='Pretrained',
337
+ checkpoint=omg_head_pretrain_pth_path,
338
+ )
339
+ )
340
+
341
+ model = dict(
342
+ type=OMG_LLaVA,
343
+ freeze_llm=True,
344
+ freeze_visual_encoder=True,
345
+ require_omg_decoder=True,
346
+ pretrained_pth=pretrained_pth,
347
+ text2vision_projector=True,
348
+ pixel_shuffle_ratio=2,
349
+ llm=dict(
350
+ type=AutoModelForCausalLM.from_pretrained,
351
+ pretrained_model_name_or_path=llm_name_or_path,
352
+ trust_remote_code=True,
353
+ torch_dtype=torch.float16,
354
+ quantization_config=dict(
355
+ type=BitsAndBytesConfig,
356
+ load_in_4bit=True,
357
+ load_in_8bit=False,
358
+ llm_int8_threshold=6.0,
359
+ llm_int8_has_fp16_weight=False,
360
+ bnb_4bit_compute_dtype=torch.float16,
361
+ bnb_4bit_use_double_quant=True,
362
+ bnb_4bit_quant_type='nf4')),
363
+ llm_lora=dict(
364
+ type=LoraConfig,
365
+ r=512,
366
+ lora_alpha=256,
367
+ lora_dropout=0.05,
368
+ bias='none',
369
+ task_type='CAUSAL_LM'),
370
+ visual_encoder=omgseg_model,
371
+ tokenizer=tokenizer,
372
+ )
373
+
374
+ #######################################################################
375
+ # PART 3 Dataset & Dataloader #
376
+ #######################################################################
377
+ debug=False
378
+ llava_dataset = dict(
379
+ type=LLaVADataset,
380
+ data_path=data_path,
381
+ image_folder=image_folder,
382
+ tokenizer=tokenizer,
383
+ image_processor=image_processor,
384
+ dataset_map_fn=llava_map_fn,
385
+ template_map_fn=dict(
386
+ type=template_map_fn_factory, template=prompt_template),
387
+ max_length=max_length,
388
+ pad_image_to_square=True)
389
+
390
+ glamm_refcocog_dataset = dict(
391
+ type=RefCOCOgGCGDataset,
392
+ data_path=refcocog_ann_file,
393
+ image_folder=refcocog_image_path,
394
+ tokenizer=tokenizer,
395
+ image_processor=image_processor,
396
+ dataset_map_fn=glamm_refcocog_map_fn,
397
+ template_map_fn=dict(
398
+ type=template_map_fn_factory, template=prompt_template),
399
+ max_length=max_length,
400
+ pad_image_to_square=True,
401
+ debug=False,
402
+ repeats=1,
403
+ )
404
+
405
+ glamm_grandf_dataset = dict(
406
+ type=GranDfGCGDataset,
407
+ data_path=grandf_ann_file,
408
+ image_folder=grandf_image_path,
409
+ tokenizer=tokenizer,
410
+ image_processor=image_processor,
411
+ dataset_map_fn=glamm_granf_map_fn,
412
+ template_map_fn=dict(
413
+ type=template_map_fn_factory, template=prompt_template),
414
+ max_length=max_length,
415
+ pad_image_to_square=True,
416
+ debug=debug,
417
+ repeats=10,
418
+ )
419
+
420
+ glamm_psg_dataset = dict(
421
+ type=OpenPsgGCGDataset,
422
+ data_path=psg_ann_file,
423
+ image_folder=psg_image_path,
424
+ tokenizer=tokenizer,
425
+ image_processor=image_processor,
426
+ dataset_map_fn=glamm_openpsg_map_fn,
427
+ template_map_fn=dict(
428
+ type=template_map_fn_factory, template=prompt_template),
429
+ max_length=max_length,
430
+ pad_image_to_square=True,
431
+ debug=debug,
432
+ repeats=1,
433
+ )
434
+
435
+ glamm_flickr_dataset = dict(
436
+ type=FlickrGCGDataset,
437
+ data_path=flickr_ann_file,
438
+ image_folder=flickr_image_path,
439
+ tokenizer=tokenizer,
440
+ image_processor=image_processor,
441
+ dataset_map_fn=glamm_flickr_map_fn,
442
+ template_map_fn=dict(
443
+ type=template_map_fn_factory, template=prompt_template),
444
+ max_length=max_length,
445
+ pad_image_to_square=True,
446
+ debug=debug,
447
+ repeats=1,
448
+ )
449
+
450
+ semantic_seg_ade20k_dataset = dict(
451
+ type=ADE20kSemanticSegDataset,
452
+ data_path=ade20k_class_file,
453
+ image_folder=ade20k_image_path,
454
+ tokenizer=tokenizer,
455
+ image_processor=image_processor,
456
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
457
+ template_map_fn=dict(
458
+ type=template_map_fn_factory, template=prompt_template),
459
+ max_length=max_length,
460
+ pad_image_to_square=True,
461
+ debug=False,
462
+ repeats=1,
463
+ gcg_format=True,
464
+ )
465
+
466
+ semantic_seg_cocostuff_dataset = dict(
467
+ type=COCOStuffSemanticSegDataset,
468
+ data_path=cocostuff_class_file,
469
+ image_folder=cocostuff_image_path,
470
+ label_path=cocostuff_label_path,
471
+ tokenizer=tokenizer,
472
+ image_processor=image_processor,
473
+ dataset_map_fn=semantic_seg_gcg_format_map_fn,
474
+ template_map_fn=dict(
475
+ type=template_map_fn_factory, template=prompt_template),
476
+ max_length=max_length,
477
+ pad_image_to_square=True,
478
+ debug=False,
479
+ repeats=1,
480
+ gcg_format=True,
481
+ )
482
+
483
+ referring_seg_refcoco_dataset = dict(
484
+ type=RefcocoReferringSegDataset,
485
+ data_path=referring_refcoco_data_path,
486
+ image_folder=referring_refcoco_image_path,
487
+ tokenizer=tokenizer,
488
+ image_processor=image_processor,
489
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
490
+ template_map_fn=dict(
491
+ type=template_map_fn_factory, template=prompt_template),
492
+ max_length=max_length,
493
+ pad_image_to_square=True,
494
+ debug=False,
495
+ repeats=1,
496
+ )
497
+
498
+ referring_seg_refcoco_plus_dataset = dict(
499
+ type=Refcoco_plus_ReferringSegDataset,
500
+ data_path=referring_refcoco_plus_data_path,
501
+ image_folder=referring_refcoco_plus_image_path,
502
+ tokenizer=tokenizer,
503
+ image_processor=image_processor,
504
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
505
+ template_map_fn=dict(
506
+ type=template_map_fn_factory, template=prompt_template),
507
+ max_length=max_length,
508
+ pad_image_to_square=True,
509
+ debug=False,
510
+ repeats=1,
511
+ )
512
+
513
+ referring_seg_refcocog_dataset = dict(
514
+ type=Refcocog_ReferringSegDataset,
515
+ data_path=referring_refcocog_data_path,
516
+ image_folder=referring_refcocog_image_path,
517
+ tokenizer=tokenizer,
518
+ image_processor=image_processor,
519
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
520
+ template_map_fn=dict(
521
+ type=template_map_fn_factory, template=prompt_template),
522
+ max_length=max_length,
523
+ pad_image_to_square=True,
524
+ debug=False,
525
+ repeats=1,
526
+ )
527
+
528
+ referring_seg_refclef_dataset = dict(
529
+ type=Refclef_ReferringSegDataset,
530
+ data_path=referring_refclef_data_path,
531
+ image_folder=referring_refclef_image_path,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ dataset_map_fn=referring_seg_gcg_format_map_fn,
535
+ template_map_fn=dict(
536
+ type=template_map_fn_factory, template=prompt_template),
537
+ max_length=max_length,
538
+ pad_image_to_square=True,
539
+ debug=False,
540
+ repeats=1,
541
+ )
542
+
543
+ region_cap_osprey_dataset = dict(
544
+ type=OspreyRegionCaptionDataset,
545
+ data_path=region_cap_osprey_data_path,
546
+ image_folder=region_cap_osprey_image_path,
547
+ tokenizer=tokenizer,
548
+ image_processor=image_processor,
549
+ dataset_map_fn=osprey_region_caption_gcg_format_map_fn,
550
+ template_map_fn=dict(
551
+ type=template_map_fn_factory, template=prompt_template),
552
+ max_length=max_length,
553
+ pad_image_to_square=True,
554
+ debug=False,
555
+ repeats=1,
556
+ )
557
+
558
+ region_conversation_osprey_dataset = dict(
559
+ type=OspreyRegionConversationDataset,
560
+ data_path=region_conversation_osprey_data_path,
561
+ image_folder=region_conversation_osprey_image_path,
562
+ tokenizer=tokenizer,
563
+ image_processor=image_processor,
564
+ dataset_map_fn=osprey_region_conversation_map_fn,
565
+ template_map_fn=dict(
566
+ type=template_map_fn_factory, template=prompt_template),
567
+ max_length=max_length,
568
+ pad_image_to_square=True,
569
+ debug=False,
570
+ repeats=1,
571
+ )
572
+
573
+ mdpv_detailed_description_ade20k_dataset = dict(
574
+ type=MDPVPointDetailedCaptionDataset,
575
+ data_path=mdpv_detailed_caption_ade20k_data_path,
576
+ image_folder=mdpv_detailed_caption_ade20k_image_path,
577
+ tokenizer=tokenizer,
578
+ image_processor=image_processor,
579
+ dataset_map_fn=mdpv_points_map_fn,
580
+ template_map_fn=dict(
581
+ type=template_map_fn_factory, template=prompt_template),
582
+ max_length=max_length,
583
+ pad_image_to_square=True,
584
+ debug=False,
585
+ repeats=1,
586
+ )
587
+
588
+ mdpv_detailed_description_cocostuff_10k_dataset = dict(
589
+ type=MDPVPointDetailedCaptionDataset,
590
+ data_path=mdpv_detailed_caption_cocostuff_10k_data_path,
591
+ image_folder=mdpv_detailed_caption_cocostuff_10k_image_path,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ dataset_map_fn=mdpv_points_map_fn,
595
+ template_map_fn=dict(
596
+ type=template_map_fn_factory, template=prompt_template),
597
+ max_length=max_length,
598
+ pad_image_to_square=True,
599
+ debug=False,
600
+ repeats=1,
601
+ )
602
+
603
+ mdpv_detailed_description_cocostuff_164k_dataset = dict(
604
+ type=MDPVPointDetailedCaptionDataset,
605
+ data_path=mdpv_detailed_caption_cocostuff_164k_data_path,
606
+ image_folder=mdpv_detailed_caption_cocostuff_164k_image_path,
607
+ tokenizer=tokenizer,
608
+ image_processor=image_processor,
609
+ dataset_map_fn=mdpv_points_map_fn,
610
+ template_map_fn=dict(
611
+ type=template_map_fn_factory, template=prompt_template),
612
+ max_length=max_length,
613
+ pad_image_to_square=True,
614
+ debug=False,
615
+ repeats=1,
616
+ )
617
+
618
+ mdpv_detailed_description_vg_dataset = dict(
619
+ type=MDPVPointDetailedCaptionDataset,
620
+ data_path=mdpv_detailed_caption_vg_data_path,
621
+ image_folder=mdpv_detailed_caption_vg_image_path,
622
+ tokenizer=tokenizer,
623
+ image_processor=image_processor,
624
+ dataset_map_fn=mdpv_points_map_fn,
625
+ template_map_fn=dict(
626
+ type=template_map_fn_factory, template=prompt_template),
627
+ max_length=max_length,
628
+ pad_image_to_square=True,
629
+ debug=False,
630
+ repeats=1,
631
+ )
632
+
633
+ mdpv_brief_description_vg_dataset = dict(
634
+ type=MDPVPointBriefCaptionDataset,
635
+ data_path=mdpv_brief_caption_vg_data_path,
636
+ image_folder=mdpv_brief_caption_vg_image_path,
637
+ tokenizer=tokenizer,
638
+ image_processor=image_processor,
639
+ dataset_map_fn=mdpv_points_map_fn,
640
+ template_map_fn=dict(
641
+ type=template_map_fn_factory, template=prompt_template),
642
+ max_length=max_length,
643
+ pad_image_to_square=True,
644
+ debug=False,
645
+ repeats=1,
646
+ )
647
+
648
+ mdpv_brief_description_cocostuff10k_dataset = dict(
649
+ type=MDPVPointBriefCaptionDataset,
650
+ data_path=mdpv_brief_caption_cocostuff_10k_data_path,
651
+ image_folder=mdpv_brief_caption_cocostuff_10k_image_path,
652
+ tokenizer=tokenizer,
653
+ image_processor=image_processor,
654
+ dataset_map_fn=mdpv_points_map_fn,
655
+ template_map_fn=dict(
656
+ type=template_map_fn_factory, template=prompt_template),
657
+ max_length=max_length,
658
+ pad_image_to_square=True,
659
+ debug=False,
660
+ repeats=1,
661
+ )
662
+
663
+ mdpv_brief_description_cocostuff164k_dataset = dict(
664
+ type=MDPVPointBriefCaptionDataset,
665
+ data_path=mdpv_brief_caption_cocostuff_164k_data_path,
666
+ image_folder=mdpv_brief_caption_cocostuff_164k_image_path,
667
+ tokenizer=tokenizer,
668
+ image_processor=image_processor,
669
+ dataset_map_fn=mdpv_points_map_fn,
670
+ template_map_fn=dict(
671
+ type=template_map_fn_factory, template=prompt_template),
672
+ max_length=max_length,
673
+ pad_image_to_square=True,
674
+ debug=False,
675
+ repeats=1,
676
+ )
677
+
678
+ mdpv_brief_description_ade20k_dataset = dict(
679
+ type=MDPVPointBriefCaptionDataset,
680
+ data_path=mdpv_brief_caption_ade20k_data_path,
681
+ image_folder=mdpv_brief_caption_ade20k_image_path,
682
+ tokenizer=tokenizer,
683
+ image_processor=image_processor,
684
+ dataset_map_fn=mdpv_points_map_fn,
685
+ template_map_fn=dict(
686
+ type=template_map_fn_factory, template=prompt_template),
687
+ max_length=max_length,
688
+ pad_image_to_square=True,
689
+ debug=False,
690
+ repeats=1,
691
+ )
692
+
693
+ mdpv_brief_description_lvis_dataset = dict(
694
+ type=MDPVPointBriefCaptionDataset,
695
+ data_path=mdpv_brief_caption_lvis_data_path,
696
+ image_folder=mdpv_brief_caption_lvis_image_path,
697
+ tokenizer=tokenizer,
698
+ image_processor=image_processor,
699
+ dataset_map_fn=mdpv_points_map_fn,
700
+ template_map_fn=dict(
701
+ type=template_map_fn_factory, template=prompt_template),
702
+ max_length=max_length,
703
+ pad_image_to_square=True,
704
+ debug=False,
705
+ repeats=1,
706
+ )
707
+
708
+ mdpv_qa_vg_dataset = dict(
709
+ type=MDPVPointBriefCaptionDataset,
710
+ data_path=mdpv_qa_vg_data_path,
711
+ image_folder=mdpv_qa_vg_image_path,
712
+ tokenizer=tokenizer,
713
+ image_processor=image_processor,
714
+ dataset_map_fn=mdpv_points_map_fn,
715
+ template_map_fn=dict(
716
+ type=template_map_fn_factory, template=prompt_template),
717
+ max_length=max_length,
718
+ pad_image_to_square=True,
719
+ debug=False,
720
+ repeats=1,
721
+ )
722
+
723
+ mdpv_qa_ade20k_dataset = dict(
724
+ type=MDPVPointBriefCaptionDataset,
725
+ data_path=mdpv_qa_ade20k_data_path,
726
+ image_folder=mdpv_qa_ade20k_image_path,
727
+ tokenizer=tokenizer,
728
+ image_processor=image_processor,
729
+ dataset_map_fn=mdpv_points_map_fn,
730
+ template_map_fn=dict(
731
+ type=template_map_fn_factory, template=prompt_template),
732
+ max_length=max_length,
733
+ pad_image_to_square=True,
734
+ debug=False,
735
+ repeats=1,
736
+ )
737
+
738
+ mdpv_qa_lvis_dataset = dict(
739
+ type=MDPVPointBriefCaptionDataset,
740
+ data_path=mdpv_qa_lvis_data_path,
741
+ image_folder=mdpv_qa_lvis_image_path,
742
+ tokenizer=tokenizer,
743
+ image_processor=image_processor,
744
+ dataset_map_fn=mdpv_points_map_fn,
745
+ template_map_fn=dict(
746
+ type=template_map_fn_factory, template=prompt_template),
747
+ max_length=max_length,
748
+ pad_image_to_square=True,
749
+ debug=False,
750
+ repeats=1,
751
+ )
752
+
753
+ mdpv_qa_cocostuff10k_dataset = dict(
754
+ type=MDPVPointBriefCaptionDataset,
755
+ data_path=mdpv_qa_cocostuff10k_data_path,
756
+ image_folder=mdpv_qa_cocostuff10k_image_path,
757
+ tokenizer=tokenizer,
758
+ image_processor=image_processor,
759
+ dataset_map_fn=mdpv_points_map_fn,
760
+ template_map_fn=dict(
761
+ type=template_map_fn_factory, template=prompt_template),
762
+ max_length=max_length,
763
+ pad_image_to_square=True,
764
+ debug=False,
765
+ repeats=1,
766
+ )
767
+
768
+ mdpv_qa_cocostuff164k_dataset = dict(
769
+ type=MDPVPointBriefCaptionDataset,
770
+ data_path=mdpv_qa_cocostuff164k_data_path,
771
+ image_folder=mdpv_qa_cocostuff164k_image_path,
772
+ tokenizer=tokenizer,
773
+ image_processor=image_processor,
774
+ dataset_map_fn=mdpv_points_map_fn,
775
+ template_map_fn=dict(
776
+ type=template_map_fn_factory, template=prompt_template),
777
+ max_length=max_length,
778
+ pad_image_to_square=True,
779
+ debug=False,
780
+ repeats=1,
781
+ )
782
+
783
+ mdpv_multi_points_openpsg_dataset = dict(
784
+ type=MDPVPointBriefCaptionDataset,
785
+ data_path=mdpv_multi_points_openpsg_data_path,
786
+ image_folder=mdpv_multi_points_openpsg_image_path,
787
+ tokenizer=tokenizer,
788
+ image_processor=image_processor,
789
+ dataset_map_fn=mdpv_points_map_fn,
790
+ template_map_fn=dict(
791
+ type=template_map_fn_factory, template=prompt_template),
792
+ max_length=max_length,
793
+ pad_image_to_square=True,
794
+ debug=False,
795
+ repeats=1,
796
+ )
797
+
798
+ mdpv_multi_points_flicker30k_dataset = dict(
799
+ type=MDPVPointBriefCaptionDataset,
800
+ data_path=mdpv_multi_points_flicker30k_data_path,
801
+ image_folder=mdpv_multi_points_flicker30k_image_path,
802
+ tokenizer=tokenizer,
803
+ image_processor=image_processor,
804
+ dataset_map_fn=mdpv_points_map_fn,
805
+ template_map_fn=dict(
806
+ type=template_map_fn_factory, template=prompt_template),
807
+ max_length=max_length,
808
+ pad_image_to_square=True,
809
+ debug=False,
810
+ repeats=1,
811
+ )
812
+
813
+ train_dataset = dict(
814
+ type=CombineDataset,
815
+ datasets_cfgs=[referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
816
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset, # repeat 3x
817
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
818
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,
819
+ referring_seg_refcoco_dataset, referring_seg_refcoco_plus_dataset,
820
+ referring_seg_refcocog_dataset, referring_seg_refclef_dataset,],
821
+ )
822
+
823
+ train_dataloader = dict(
824
+ batch_size=batch_size,
825
+ num_workers=dataloader_num_workers,
826
+ dataset=train_dataset,
827
+ sampler=dict(
828
+ type=LengthGroupedSampler,
829
+ length_property='modality_length',
830
+ per_device_batch_size=batch_size * accumulative_counts),
831
+ collate_fn=dict(type=omg_llava_collate_fn))
832
+
833
+ #######################################################################
834
+ # PART 4 Scheduler & Optimizer #
835
+ #######################################################################
836
+ # optimizer
837
+ optim_wrapper = dict(
838
+ type=AmpOptimWrapper,
839
+ optimizer=dict(
840
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
841
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
842
+ accumulative_counts=accumulative_counts,
843
+ loss_scale='dynamic',
844
+ dtype='float16')
845
+
846
+ # learning policy
847
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
848
+ param_scheduler = [
849
+ dict(
850
+ type=LinearLR,
851
+ start_factor=1e-5,
852
+ by_epoch=True,
853
+ begin=0,
854
+ end=warmup_ratio * max_epochs,
855
+ convert_to_iter_based=True),
856
+ dict(
857
+ type=CosineAnnealingLR,
858
+ eta_min=0.0,
859
+ by_epoch=True,
860
+ begin=warmup_ratio * max_epochs,
861
+ end=max_epochs,
862
+ convert_to_iter_based=True)
863
+ ]
864
+
865
+ # train, val, test setting
866
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
867
+
868
+ #######################################################################
869
+ # PART 5 Runtime #
870
+ #######################################################################
871
+ # Log the dialogue periodically during the training process, optional
872
+ custom_hooks = [
873
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
874
+ dict(
875
+ type=EvaluateChatHook_withSpecialTokens,
876
+ tokenizer=tokenizer,
877
+ image_processor=image_processor,
878
+ every_n_iters=evaluation_freq,
879
+ evaluation_inputs=evaluation_inputs,
880
+ evaluation_images=evaluation_images,
881
+ system=SYSTEM,
882
+ prompt_template=prompt_template)
883
+ ]
884
+
885
+ # configure default hooks
886
+ default_hooks = dict(
887
+ # record the time of every iteration.
888
+ timer=dict(type=IterTimerHook),
889
+ # print log every 10 iterations.
890
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
891
+ # enable the parameter scheduler.
892
+ param_scheduler=dict(type=ParamSchedulerHook),
893
+ # save a checkpoint every `save_steps` iterations.
894
+ checkpoint=dict(
895
+ type=CheckpointHook,
896
+ by_epoch=False,
897
+ interval=save_steps,
898
+ max_keep_ckpts=save_total_limit),
899
+ # set sampler seed in distributed environment.
900
+ sampler_seed=dict(type=DistSamplerSeedHook),
901
+ )
902
+
903
+ # configure environment
904
+ env_cfg = dict(
905
+ # whether to enable cudnn benchmark
906
+ cudnn_benchmark=False,
907
+ # set multi process parameters
908
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
909
+ # set distributed parameters
910
+ dist_cfg=dict(backend='nccl'),
911
+ )
912
+
913
+ # set visualizer
914
+ visualizer = None
915
+
916
+ # set log level
917
+ log_level = 'INFO'
918
+
919
+ # load from which checkpoint
920
+ load_from = None
921
+
922
+ # whether to resume training from the loaded checkpoint
923
+ resume = False
924
+
925
+ # Defaults to use random seed and disable `deterministic`
926
+ randomness = dict(seed=None, deterministic=False)
927
+
928
+ # set log processor
929
+ log_processor = dict(by_epoch=False)
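
Two details of the referring-segmentation finetune config above are easy to miss. First, instead of setting repeats=3 on each dataset, its CombineDataset entry lists the four referring-seg dataset configs (RefCOCO, RefCOCO+, RefCOCOg, RefCLEF) three times, which effectively triples their contribution per epoch. Second, the optimizer-level batch size follows from batch_size and accumulative_counts. The sketch below is plain Python with no project dependencies; the 8-GPU world size is an assumption carried over from the 8-GPU configs elsewhere in this folder.

# Values copied from the Scheduler & Optimizer block of this config.
batch_size = 8             # per-device batch
accumulative_counts = 2    # gradient accumulation steps
num_gpus = 8               # assumed world size (not stated in this file)

per_gpu_effective = batch_size * accumulative_counts   # 16 samples per optimizer step per GPU
global_effective = per_gpu_effective * num_gpus         # 128 samples per optimizer step overall
print(per_gpu_effective, global_effective)

# Listing each referring-seg dataset config three times in CombineDataset
# concatenates three copies, i.e. it behaves like repeats=3 per dataset.
ref_datasets = ['refcoco', 'refcoco+', 'refcocog', 'refclef']
combined_cfg_list = ref_datasets * 3
print(len(combined_cfg_list))   # 12 dataset entries
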
omg_llava/configs/pretrain/ablation_projector/ablation_projector_baseline.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
34
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 2
47
+ dataloader_num_workers = 4
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+ omgseg_model = dict(
96
+ type=OMGSegVisualEncoder,
97
+ data_preprocessor=None,
98
+ pixel_shuffle_down_ratio=2,
99
+ backbone=dict(
100
+ type=OpenCLIPBackbone_omgseg,
101
+ model_name='convnext_large_d_320',
102
+ fix=True,
103
+ init_cfg=dict(
104
+ type='clip_pretrain',
105
+ checkpoint='laion2b_s29b_b131k_ft_soup'
106
+ )
107
+ ),
108
+ panoptic_head=dict(
109
+ type=Mask2FormerVideoSemSamHead,
110
+ sphere_cls=True,
111
+ ov_path=omg_ov_class_embed_path,
112
+ enable_box_query=False,
113
+ ov_classifier_name=class_embed,
114
+ logit=None,
115
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
116
+ strides=[4, 8, 16, 32],
117
+ feat_channels=256,
118
+ out_channels=256,
119
+ num_things_classes=num_things_classes,
120
+ num_stuff_classes=num_stuff_classes,
121
+ num_queries=300,
122
+ num_transformer_feat_level=3,
123
+ pixel_decoder=dict(
124
+ type=MSDeformAttnPixelDecoder,
125
+ num_outs=3,
126
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
127
+ act_cfg=dict(type=ReLU),
128
+ encoder=dict( # DeformableDetrTransformerEncoder
129
+ num_layers=6,
130
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
131
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
132
+ embed_dims=256,
133
+ num_heads=8,
134
+ num_levels=3,
135
+ num_points=4,
136
+ dropout=0.0,
137
+ batch_first=True),
138
+ ffn_cfg=dict(
139
+ embed_dims=256,
140
+ feedforward_channels=1024,
141
+ num_fcs=2,
142
+ ffn_drop=0.0,
143
+ act_cfg=dict(type=ReLU, inplace=True)))),
144
+ positional_encoding=dict(num_feats=128, normalize=True)),
145
+ enforce_decoder_input_project=False,
146
+ positional_encoding=dict(num_feats=128, normalize=True),
147
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
148
+ return_intermediate=True,
149
+ num_layers=9,
150
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
151
+ self_attn_cfg=dict( # MultiheadAttention
152
+ embed_dims=256,
153
+ num_heads=8,
154
+ dropout=0.0,
155
+ batch_first=True),
156
+ cross_attn_cfg=dict( # MultiheadAttention
157
+ embed_dims=256,
158
+ num_heads=8,
159
+ dropout=0.0,
160
+ batch_first=True),
161
+ ffn_cfg=dict(
162
+ embed_dims=256,
163
+ feedforward_channels=2048,
164
+ num_fcs=2,
165
+ ffn_drop=0.0,
166
+ act_cfg=dict(type='ReLU', inplace=True))),
167
+ init_cfg=None),
168
+ loss_cls=dict(
169
+ type=CrossEntropyLoss,
170
+ use_sigmoid=False,
171
+ loss_weight=2.0,
172
+ reduction='mean',
173
+ class_weight=[1.0] * 240 + [0.1]),
174
+ loss_mask=dict(
175
+ type=CrossEntropyLoss,
176
+ use_sigmoid=True,
177
+ reduction='mean',
178
+ loss_weight=5.0),
179
+ loss_dice=dict(
180
+ type=DiceLoss,
181
+ use_sigmoid=True,
182
+ activate=True,
183
+ reduction='mean',
184
+ naive_dice=True,
185
+ eps=1.0,
186
+ loss_weight=5.0),
187
+ loss_iou=dict(
188
+ type=FocalLoss,
189
+ use_sigmoid=True,
190
+ loss_weight=2.0,
191
+ reduction='mean')
192
+ ),
193
+ panoptic_fusion_head=dict(
194
+ type=MaskFormerFusionHead,
195
+ num_things_classes=num_things_classes,
196
+ num_stuff_classes=num_stuff_classes,
197
+ loss_panoptic=None,
198
+ init_cfg=None),
199
+ train_cfg=dict(
200
+ num_points=12544,
201
+ oversample_ratio=3.0,
202
+ importance_sample_ratio=0.75,
203
+ assigner=dict(
204
+ type=HungarianAssigner,
205
+ match_costs=[
206
+ # dict(type=FlexibleClassificationCost, weight=2.0),
207
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
208
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
209
+ ]),
210
+ sampler=dict(type=MaskPseudoSampler)),
211
+ test_cfg=dict(
212
+ panoptic_on=True,
213
+ # For now, the dataset does not support
214
+ # evaluating semantic segmentation metric.
215
+ semantic_on=False,
216
+ instance_on=True,
217
+ # max_per_image is for instance segmentation.
218
+ max_per_image=100,
219
+ iou_thr=0.8,
220
+ # In Mask2Former's panoptic postprocessing,
221
+ # it will filter mask area where score is less than 0.5 .
222
+ filter_low_score=True),
223
+ init_cfg=dict(
224
+ type='Pretrained',
225
+ checkpoint=omg_head_pretrain_pth_path,
226
+ )
227
+ )
228
+
229
+ model = dict(
230
+ type=OMG_LLaVA,
231
+ freeze_llm=True,
232
+ freeze_visual_encoder=True,
233
+ text2vision_projector=True,
234
+ keep_omg_decoder_frozen=True,
235
+ add_seg_pretrain=True,
236
+ pixel_shuffle_ratio=2,
237
+ visual_prompt_proj=False,
238
+ add_cross_attn_layer=False,
239
+ llm=dict(
240
+ type=AutoModelForCausalLM.from_pretrained,
241
+ pretrained_model_name_or_path=llm_name_or_path,
242
+ trust_remote_code=True,
243
+ torch_dtype=torch.float16,
244
+ quantization_config=dict(
245
+ type=BitsAndBytesConfig,
246
+ load_in_4bit=True,
247
+ load_in_8bit=False,
248
+ llm_int8_threshold=6.0,
249
+ llm_int8_has_fp16_weight=False,
250
+ bnb_4bit_compute_dtype=torch.float16,
251
+ bnb_4bit_use_double_quant=True,
252
+ bnb_4bit_quant_type='nf4')),
253
+ visual_encoder=omgseg_model,
254
+ tokenizer=tokenizer,
255
+ )
256
+
257
+ #######################################################################
258
+ # PART 3 Dataset & Dataloader #
259
+ #######################################################################
260
+ llava_dataset = dict(
261
+ type=LLaVADataset,
262
+ data_path=data_path,
263
+ image_folder=image_folder,
264
+ tokenizer=tokenizer,
265
+ image_processor=image_processor,
266
+ dataset_map_fn=llava_map_fn,
267
+ template_map_fn=dict(
268
+ type=template_map_fn_factory, template=prompt_template),
269
+ max_length=max_length,
270
+ pad_image_to_square=True,
271
+ debug=False,
272
+ )
273
+
274
+ train_dataloader = dict(
275
+ batch_size=batch_size,
276
+ num_workers=dataloader_num_workers,
277
+ dataset=llava_dataset,
278
+ sampler=dict(type=DefaultSampler, shuffle=True),
279
+ collate_fn=dict(type=omg_llava_collate_fn))
280
+
281
+ #######################################################################
282
+ # PART 4 Scheduler & Optimizer #
283
+ #######################################################################
284
+ # optimizer
285
+ optim_wrapper = dict(
286
+ type=AmpOptimWrapper,
287
+ optimizer=dict(
288
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
289
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
290
+ accumulative_counts=accumulative_counts,
291
+ loss_scale='dynamic',
292
+ dtype='float16')
293
+
294
+ # learning policy
295
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
296
+ param_scheduler = [
297
+ dict(
298
+ type=LinearLR,
299
+ start_factor=1e-5,
300
+ by_epoch=True,
301
+ begin=0,
302
+ end=warmup_ratio * max_epochs,
303
+ convert_to_iter_based=True),
304
+ dict(
305
+ type=CosineAnnealingLR,
306
+ eta_min=0.0,
307
+ by_epoch=True,
308
+ begin=warmup_ratio * max_epochs,
309
+ end=max_epochs,
310
+ convert_to_iter_based=True)
311
+ ]
312
+
313
+ # train, val, test setting
314
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
315
+
316
+ #######################################################################
317
+ # PART 5 Runtime #
318
+ #######################################################################
319
+ # Log the dialogue periodically during the training process, optional
320
+ custom_hooks = [
321
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
322
+ dict(
323
+ type=EvaluateChatHook_withSpecialTokens,
324
+ tokenizer=tokenizer,
325
+ image_processor=image_processor,
326
+ every_n_iters=evaluation_freq,
327
+ evaluation_inputs=evaluation_inputs,
328
+ evaluation_images=evaluation_images,
329
+ system=SYSTEM,
330
+ prompt_template=prompt_template)
331
+ ]
332
+
333
+ # configure default hooks
334
+ default_hooks = dict(
335
+ # record the time of every iteration.
336
+ timer=dict(type=IterTimerHook),
337
+ # print log every 10 iterations.
338
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
339
+ # enable the parameter scheduler.
340
+ param_scheduler=dict(type=ParamSchedulerHook),
341
+ # save checkpoint per `save_steps`.
342
+ checkpoint=dict(
343
+ type=CheckpointHook,
344
+ by_epoch=False,
345
+ interval=save_steps,
346
+ max_keep_ckpts=save_total_limit),
347
+ # set sampler seed in distributed environment.
348
+ sampler_seed=dict(type=DistSamplerSeedHook),
349
+ )
350
+
351
+ # configure environment
352
+ env_cfg = dict(
353
+ # whether to enable cudnn benchmark
354
+ cudnn_benchmark=False,
355
+ # set multi process parameters
356
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
357
+ # set distributed parameters
358
+ dist_cfg=dict(backend='nccl'),
359
+ )
360
+
361
+ # set visualizer
362
+ visualizer = None
363
+
364
+ # set log level
365
+ log_level = 'INFO'
366
+
367
+ # load from which checkpoint
368
+ load_from = None
369
+
370
+ # whether to resume training from the loaded checkpoint
371
+ resume = False
372
+
373
+ # Defaults to use random seed and disable `deterministic`
374
+ randomness = dict(seed=None, deterministic=False)
375
+
376
+ # set log processor
377
+ log_processor = dict(by_epoch=False)
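An arithmetic note on the `max_length = int(2048 - (1024 / 64)**2)` line in the config above. This is an editorial sketch of where the 1792-token budget comes from; treating 64 as the effective downsampling stride (ConvNeXt stride 32 times `pixel_shuffle_down_ratio=2`) is an assumption, not something stated in the config:

# Editorial sketch of the max_length arithmetic used throughout these configs.
image_size = 1024
effective_stride = 64                                   # assumed: backbone stride 32 * pixel shuffle 2
visual_tokens = (image_size // effective_stride) ** 2   # 16 * 16 = 256 image tokens
context_window = 2048
max_length = context_window - visual_tokens             # 1792 tokens left for text
print(visual_tokens, max_length)                        # 256 1792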
omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
34
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 2
47
+ dataloader_num_workers = 4
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+ omgseg_model = dict(
96
+ type=OMGSegVisualEncoder,
97
+ data_preprocessor=None,
98
+ pixel_shuffle_down_ratio=2,
99
+ backbone=dict(
100
+ type=OpenCLIPBackbone_omgseg,
101
+ model_name='convnext_large_d_320',
102
+ fix=True,
103
+ init_cfg=dict(
104
+ type='clip_pretrain',
105
+ checkpoint='laion2b_s29b_b131k_ft_soup'
106
+ )
107
+ ),
108
+ panoptic_head=dict(
109
+ type=Mask2FormerVideoSemSamHead,
110
+ sphere_cls=True,
111
+ ov_path=omg_ov_class_embed_path,
112
+ enable_box_query=False,
113
+ ov_classifier_name=class_embed,
114
+ logit=None,
115
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
116
+ strides=[4, 8, 16, 32],
117
+ feat_channels=256,
118
+ out_channels=256,
119
+ num_things_classes=num_things_classes,
120
+ num_stuff_classes=num_stuff_classes,
121
+ num_queries=300,
122
+ num_transformer_feat_level=3,
123
+ pixel_decoder=dict(
124
+ type=MSDeformAttnPixelDecoder,
125
+ num_outs=3,
126
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
127
+ act_cfg=dict(type=ReLU),
128
+ encoder=dict( # DeformableDetrTransformerEncoder
129
+ num_layers=6,
130
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
131
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
132
+ embed_dims=256,
133
+ num_heads=8,
134
+ num_levels=3,
135
+ num_points=4,
136
+ dropout=0.0,
137
+ batch_first=True),
138
+ ffn_cfg=dict(
139
+ embed_dims=256,
140
+ feedforward_channels=1024,
141
+ num_fcs=2,
142
+ ffn_drop=0.0,
143
+ act_cfg=dict(type=ReLU, inplace=True)))),
144
+ positional_encoding=dict(num_feats=128, normalize=True)),
145
+ enforce_decoder_input_project=False,
146
+ positional_encoding=dict(num_feats=128, normalize=True),
147
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
148
+ return_intermediate=True,
149
+ num_layers=9,
150
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
151
+ self_attn_cfg=dict( # MultiheadAttention
152
+ embed_dims=256,
153
+ num_heads=8,
154
+ dropout=0.0,
155
+ batch_first=True),
156
+ cross_attn_cfg=dict( # MultiheadAttention
157
+ embed_dims=256,
158
+ num_heads=8,
159
+ dropout=0.0,
160
+ batch_first=True),
161
+ ffn_cfg=dict(
162
+ embed_dims=256,
163
+ feedforward_channels=2048,
164
+ num_fcs=2,
165
+ ffn_drop=0.0,
166
+ act_cfg=dict(type='ReLU', inplace=True))),
167
+ init_cfg=None),
168
+ loss_cls=dict(
169
+ type=CrossEntropyLoss,
170
+ use_sigmoid=False,
171
+ loss_weight=2.0,
172
+ reduction='mean',
173
+ class_weight=[1.0] * 240 + [0.1]),
174
+ loss_mask=dict(
175
+ type=CrossEntropyLoss,
176
+ use_sigmoid=True,
177
+ reduction='mean',
178
+ loss_weight=5.0),
179
+ loss_dice=dict(
180
+ type=DiceLoss,
181
+ use_sigmoid=True,
182
+ activate=True,
183
+ reduction='mean',
184
+ naive_dice=True,
185
+ eps=1.0,
186
+ loss_weight=5.0),
187
+ loss_iou=dict(
188
+ type=FocalLoss,
189
+ use_sigmoid=True,
190
+ loss_weight=2.0,
191
+ reduction='mean')
192
+ ),
193
+ panoptic_fusion_head=dict(
194
+ type=MaskFormerFusionHead,
195
+ num_things_classes=num_things_classes,
196
+ num_stuff_classes=num_stuff_classes,
197
+ loss_panoptic=None,
198
+ init_cfg=None),
199
+ train_cfg=dict(
200
+ num_points=12544,
201
+ oversample_ratio=3.0,
202
+ importance_sample_ratio=0.75,
203
+ assigner=dict(
204
+ type=HungarianAssigner,
205
+ match_costs=[
206
+ # dict(type=FlexibleClassificationCost, weight=2.0),
207
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
208
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
209
+ ]),
210
+ sampler=dict(type=MaskPseudoSampler)),
211
+ test_cfg=dict(
212
+ panoptic_on=True,
213
+ # For now, the dataset does not support
214
+ # evaluating semantic segmentation metric.
215
+ semantic_on=False,
216
+ instance_on=True,
217
+ # max_per_image is for instance segmentation.
218
+ max_per_image=100,
219
+ iou_thr=0.8,
220
+ # In Mask2Former's panoptic postprocessing,
221
+ # it will filter mask area where score is less than 0.5 .
222
+ filter_low_score=True),
223
+ init_cfg=dict(
224
+ type='Pretrained',
225
+ checkpoint=omg_head_pretrain_pth_path,
226
+ )
227
+ )
228
+
229
+ model = dict(
230
+ type=OMG_LLaVA,
231
+ freeze_llm=True,
232
+ freeze_visual_encoder=True,
233
+ text2vision_projector=True,
234
+ keep_omg_decoder_frozen=True,
235
+ add_seg_pretrain=True,
236
+ pixel_shuffle_ratio=2,
237
+ visual_prompt_proj=False,
238
+ add_cross_attn_layer=True,
239
+ llm=dict(
240
+ type=AutoModelForCausalLM.from_pretrained,
241
+ pretrained_model_name_or_path=llm_name_or_path,
242
+ trust_remote_code=True,
243
+ torch_dtype=torch.float16,
244
+ quantization_config=dict(
245
+ type=BitsAndBytesConfig,
246
+ load_in_4bit=True,
247
+ load_in_8bit=False,
248
+ llm_int8_threshold=6.0,
249
+ llm_int8_has_fp16_weight=False,
250
+ bnb_4bit_compute_dtype=torch.float16,
251
+ bnb_4bit_use_double_quant=True,
252
+ bnb_4bit_quant_type='nf4')),
253
+ visual_encoder=omgseg_model,
254
+ tokenizer=tokenizer,
255
+ )
256
+
257
+ #######################################################################
258
+ # PART 3 Dataset & Dataloader #
259
+ #######################################################################
260
+ llava_dataset = dict(
261
+ type=LLaVADataset,
262
+ data_path=data_path,
263
+ image_folder=image_folder,
264
+ tokenizer=tokenizer,
265
+ image_processor=image_processor,
266
+ dataset_map_fn=llava_map_fn,
267
+ template_map_fn=dict(
268
+ type=template_map_fn_factory, template=prompt_template),
269
+ max_length=max_length,
270
+ pad_image_to_square=True,
271
+ debug=False,
272
+ )
273
+
274
+ train_dataloader = dict(
275
+ batch_size=batch_size,
276
+ num_workers=dataloader_num_workers,
277
+ dataset=llava_dataset,
278
+ sampler=dict(type=DefaultSampler, shuffle=True),
279
+ collate_fn=dict(type=omg_llava_collate_fn))
280
+
281
+ #######################################################################
282
+ # PART 4 Scheduler & Optimizer #
283
+ #######################################################################
284
+ # optimizer
285
+ optim_wrapper = dict(
286
+ type=AmpOptimWrapper,
287
+ optimizer=dict(
288
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
289
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
290
+ accumulative_counts=accumulative_counts,
291
+ loss_scale='dynamic',
292
+ dtype='float16')
293
+
294
+ # learning policy
295
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
296
+ param_scheduler = [
297
+ dict(
298
+ type=LinearLR,
299
+ start_factor=1e-5,
300
+ by_epoch=True,
301
+ begin=0,
302
+ end=warmup_ratio * max_epochs,
303
+ convert_to_iter_based=True),
304
+ dict(
305
+ type=CosineAnnealingLR,
306
+ eta_min=0.0,
307
+ by_epoch=True,
308
+ begin=warmup_ratio * max_epochs,
309
+ end=max_epochs,
310
+ convert_to_iter_based=True)
311
+ ]
312
+
313
+ # train, val, test setting
314
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
315
+
316
+ #######################################################################
317
+ # PART 5 Runtime #
318
+ #######################################################################
319
+ # Log the dialogue periodically during the training process, optional
320
+ custom_hooks = [
321
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
322
+ dict(
323
+ type=EvaluateChatHook_withSpecialTokens,
324
+ tokenizer=tokenizer,
325
+ image_processor=image_processor,
326
+ every_n_iters=evaluation_freq,
327
+ evaluation_inputs=evaluation_inputs,
328
+ evaluation_images=evaluation_images,
329
+ system=SYSTEM,
330
+ prompt_template=prompt_template)
331
+ ]
332
+
333
+ # configure default hooks
334
+ default_hooks = dict(
335
+ # record the time of every iteration.
336
+ timer=dict(type=IterTimerHook),
337
+ # print log every 10 iterations.
338
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
339
+ # enable the parameter scheduler.
340
+ param_scheduler=dict(type=ParamSchedulerHook),
341
+ # save checkpoint per `save_steps`.
342
+ checkpoint=dict(
343
+ type=CheckpointHook,
344
+ by_epoch=False,
345
+ interval=save_steps,
346
+ max_keep_ckpts=save_total_limit),
347
+ # set sampler seed in distributed environment.
348
+ sampler_seed=dict(type=DistSamplerSeedHook),
349
+ )
350
+
351
+ # configure environment
352
+ env_cfg = dict(
353
+ # whether to enable cudnn benchmark
354
+ cudnn_benchmark=False,
355
+ # set multi process parameters
356
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
357
+ # set distributed parameters
358
+ dist_cfg=dict(backend='nccl'),
359
+ )
360
+
361
+ # set visualizer
362
+ visualizer = None
363
+
364
+ # set log level
365
+ log_level = 'INFO'
366
+
367
+ # load from which checkpoint
368
+ load_from = None
369
+
370
+ # whether to resume training from the loaded checkpoint
371
+ resume = False
372
+
373
+ # Defaults to use random seed and disable `deterministic`
374
+ randomness = dict(seed=None, deterministic=False)
375
+
376
+ # set log processor
377
+ log_processor = dict(by_epoch=False)
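The `image_processor` dict in these configs is a plain transformers `CLIPImageProcessor` configured for 1024x1024 inputs. A standalone usage sketch to inspect the output shape (illustrative only; the configs instantiate it lazily from the dict, and the dummy image size here is arbitrary):

# Editorial sketch: build the processor eagerly and check the tensor it produces.
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor(
    do_resize=True, size=1024, resample=3,
    do_center_crop=True, crop_size=1024,
    do_rescale=True, do_normalize=True,
    image_mean=[0.4814, 0.4578, 0.4082],
    image_std=[0.2686, 0.2613, 0.2757],
    do_convert_rgb=True)

image = Image.new('RGB', (1333, 800))                       # dummy image
pixel_values = processor(images=image, return_tensors='pt')['pixel_values']
print(pixel_values.shape)                                   # torch.Size([1, 3, 1024, 1024])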
omg_llava/configs/pretrain/ablation_projector/ablation_projector_seperate_cross_rmProjloss.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = '/mnt/workspace/taozhang/chekpoints/internlm2-chat-7b' # Please change to your own path
34
+ omg_ov_class_embed_path='/mnt/workspace/taozhang/chekpoints/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = '/mnt/workspace/taozhang/chekpoints/omg_seg_convl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 4
47
+ dataloader_num_workers = 4
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+ omgseg_model = dict(
96
+ type=OMGSegVisualEncoder,
97
+ data_preprocessor=None,
98
+ pixel_shuffle_down_ratio=2,
99
+ backbone=dict(
100
+ type=OpenCLIPBackbone_omgseg,
101
+ model_name='convnext_large_d_320',
102
+ fix=True,
103
+ init_cfg=dict(
104
+ type='clip_pretrain',
105
+ checkpoint='laion2b_s29b_b131k_ft_soup'
106
+ )
107
+ ),
108
+ panoptic_head=dict(
109
+ type=Mask2FormerVideoSemSamHead,
110
+ sphere_cls=True,
111
+ ov_path=omg_ov_class_embed_path,
112
+ enable_box_query=False,
113
+ ov_classifier_name=class_embed,
114
+ logit=None,
115
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
116
+ strides=[4, 8, 16, 32],
117
+ feat_channels=256,
118
+ out_channels=256,
119
+ num_things_classes=num_things_classes,
120
+ num_stuff_classes=num_stuff_classes,
121
+ num_queries=300,
122
+ num_transformer_feat_level=3,
123
+ pixel_decoder=dict(
124
+ type=MSDeformAttnPixelDecoder,
125
+ num_outs=3,
126
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
127
+ act_cfg=dict(type=ReLU),
128
+ encoder=dict( # DeformableDetrTransformerEncoder
129
+ num_layers=6,
130
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
131
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
132
+ embed_dims=256,
133
+ num_heads=8,
134
+ num_levels=3,
135
+ num_points=4,
136
+ dropout=0.0,
137
+ batch_first=True),
138
+ ffn_cfg=dict(
139
+ embed_dims=256,
140
+ feedforward_channels=1024,
141
+ num_fcs=2,
142
+ ffn_drop=0.0,
143
+ act_cfg=dict(type=ReLU, inplace=True)))),
144
+ positional_encoding=dict(num_feats=128, normalize=True)),
145
+ enforce_decoder_input_project=False,
146
+ positional_encoding=dict(num_feats=128, normalize=True),
147
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
148
+ return_intermediate=True,
149
+ num_layers=9,
150
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
151
+ self_attn_cfg=dict( # MultiheadAttention
152
+ embed_dims=256,
153
+ num_heads=8,
154
+ dropout=0.0,
155
+ batch_first=True),
156
+ cross_attn_cfg=dict( # MultiheadAttention
157
+ embed_dims=256,
158
+ num_heads=8,
159
+ dropout=0.0,
160
+ batch_first=True),
161
+ ffn_cfg=dict(
162
+ embed_dims=256,
163
+ feedforward_channels=2048,
164
+ num_fcs=2,
165
+ ffn_drop=0.0,
166
+ act_cfg=dict(type='ReLU', inplace=True))),
167
+ init_cfg=None),
168
+ loss_cls=dict(
169
+ type=CrossEntropyLoss,
170
+ use_sigmoid=False,
171
+ loss_weight=2.0,
172
+ reduction='mean',
173
+ class_weight=[1.0] * 240 + [0.1]),
174
+ loss_mask=dict(
175
+ type=CrossEntropyLoss,
176
+ use_sigmoid=True,
177
+ reduction='mean',
178
+ loss_weight=5.0),
179
+ loss_dice=dict(
180
+ type=DiceLoss,
181
+ use_sigmoid=True,
182
+ activate=True,
183
+ reduction='mean',
184
+ naive_dice=True,
185
+ eps=1.0,
186
+ loss_weight=5.0),
187
+ loss_iou=dict(
188
+ type=FocalLoss,
189
+ use_sigmoid=True,
190
+ loss_weight=2.0,
191
+ reduction='mean')
192
+ ),
193
+ panoptic_fusion_head=dict(
194
+ type=MaskFormerFusionHead,
195
+ num_things_classes=num_things_classes,
196
+ num_stuff_classes=num_stuff_classes,
197
+ loss_panoptic=None,
198
+ init_cfg=None),
199
+ train_cfg=dict(
200
+ num_points=12544,
201
+ oversample_ratio=3.0,
202
+ importance_sample_ratio=0.75,
203
+ assigner=dict(
204
+ type=HungarianAssigner,
205
+ match_costs=[
206
+ # dict(type=FlexibleClassificationCost, weight=2.0),
207
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
208
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
209
+ ]),
210
+ sampler=dict(type=MaskPseudoSampler)),
211
+ test_cfg=dict(
212
+ panoptic_on=True,
213
+ # For now, the dataset does not support
214
+ # evaluating semantic segmentation metric.
215
+ semantic_on=False,
216
+ instance_on=True,
217
+ # max_per_image is for instance segmentation.
218
+ max_per_image=100,
219
+ iou_thr=0.8,
220
+ # In Mask2Former's panoptic postprocessing,
221
+ # it will filter mask area where score is less than 0.5 .
222
+ filter_low_score=True),
223
+ init_cfg=dict(
224
+ type='Pretrained',
225
+ checkpoint=omg_head_pretrain_pth_path,
226
+ )
227
+ )
228
+
229
+ model = dict(
230
+ type=OMG_LLaVA,
231
+ freeze_llm=True,
232
+ freeze_visual_encoder=True,
233
+ text2vision_projector=True,
234
+ keep_omg_decoder_frozen=True,
235
+ add_seg_pretrain=True,
236
+ pixel_shuffle_ratio=2,
237
+ visual_prompt_proj=False,
238
+ add_cross_attn_layer=True,
239
+ llm=dict(
240
+ type=AutoModelForCausalLM.from_pretrained,
241
+ pretrained_model_name_or_path=llm_name_or_path,
242
+ trust_remote_code=True,
243
+ torch_dtype=torch.float16,
244
+ quantization_config=dict(
245
+ type=BitsAndBytesConfig,
246
+ load_in_4bit=True,
247
+ load_in_8bit=False,
248
+ llm_int8_threshold=6.0,
249
+ llm_int8_has_fp16_weight=False,
250
+ bnb_4bit_compute_dtype=torch.float16,
251
+ bnb_4bit_use_double_quant=True,
252
+ bnb_4bit_quant_type='nf4')),
253
+ visual_encoder=omgseg_model,
254
+ tokenizer=tokenizer,
255
+ )
256
+
257
+ #######################################################################
258
+ # PART 3 Dataset & Dataloader #
259
+ #######################################################################
260
+ llava_dataset = dict(
261
+ type=LLaVADataset,
262
+ data_path=data_path,
263
+ image_folder=image_folder,
264
+ tokenizer=tokenizer,
265
+ image_processor=image_processor,
266
+ dataset_map_fn=llava_map_fn,
267
+ template_map_fn=dict(
268
+ type=template_map_fn_factory, template=prompt_template),
269
+ max_length=max_length,
270
+ pad_image_to_square=True,
271
+ debug=False,
272
+ )
273
+
274
+ train_dataloader = dict(
275
+ batch_size=batch_size,
276
+ num_workers=dataloader_num_workers,
277
+ dataset=llava_dataset,
278
+ sampler=dict(type=DefaultSampler, shuffle=True),
279
+ collate_fn=dict(type=omg_llava_collate_fn))
280
+
281
+ #######################################################################
282
+ # PART 4 Scheduler & Optimizer #
283
+ #######################################################################
284
+ # optimizer
285
+ optim_wrapper = dict(
286
+ type=AmpOptimWrapper,
287
+ optimizer=dict(
288
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
289
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
290
+ accumulative_counts=accumulative_counts,
291
+ loss_scale='dynamic',
292
+ dtype='float16')
293
+
294
+ # learning policy
295
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
296
+ param_scheduler = [
297
+ dict(
298
+ type=LinearLR,
299
+ start_factor=1e-5,
300
+ by_epoch=True,
301
+ begin=0,
302
+ end=warmup_ratio * max_epochs,
303
+ convert_to_iter_based=True),
304
+ dict(
305
+ type=CosineAnnealingLR,
306
+ eta_min=0.0,
307
+ by_epoch=True,
308
+ begin=warmup_ratio * max_epochs,
309
+ end=max_epochs,
310
+ convert_to_iter_based=True)
311
+ ]
312
+
313
+ # train, val, test setting
314
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
315
+
316
+ #######################################################################
317
+ # PART 5 Runtime #
318
+ #######################################################################
319
+ # Log the dialogue periodically during the training process, optional
320
+ custom_hooks = [
321
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
322
+ dict(
323
+ type=EvaluateChatHook_withSpecialTokens,
324
+ tokenizer=tokenizer,
325
+ image_processor=image_processor,
326
+ every_n_iters=evaluation_freq,
327
+ evaluation_inputs=evaluation_inputs,
328
+ evaluation_images=evaluation_images,
329
+ system=SYSTEM,
330
+ prompt_template=prompt_template)
331
+ ]
332
+
333
+ # configure default hooks
334
+ default_hooks = dict(
335
+ # record the time of every iteration.
336
+ timer=dict(type=IterTimerHook),
337
+ # print log every 10 iterations.
338
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
339
+ # enable the parameter scheduler.
340
+ param_scheduler=dict(type=ParamSchedulerHook),
341
+ # save checkpoint per `save_steps`.
342
+ checkpoint=dict(
343
+ type=CheckpointHook,
344
+ by_epoch=False,
345
+ interval=save_steps,
346
+ max_keep_ckpts=save_total_limit),
347
+ # set sampler seed in distributed environment.
348
+ sampler_seed=dict(type=DistSamplerSeedHook),
349
+ )
350
+
351
+ # configure environment
352
+ env_cfg = dict(
353
+ # whether to enable cudnn benchmark
354
+ cudnn_benchmark=False,
355
+ # set multi process parameters
356
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
357
+ # set distributed parameters
358
+ dist_cfg=dict(backend='nccl'),
359
+ )
360
+
361
+ # set visualizer
362
+ visualizer = None
363
+
364
+ # set log level
365
+ log_level = 'INFO'
366
+
367
+ # load from which checkpoint
368
+ load_from = None
369
+
370
+ # whether to resume training from the loaded checkpoint
371
+ resume = False
372
+
373
+ # Defaults to use random seed and disable `deterministic`
374
+ randomness = dict(seed=None, deterministic=False)
375
+
376
+ # set log processor
377
+ log_processor = dict(by_epoch=False)
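Note that this `*_rmProjloss` variant raises `accumulative_counts` from 2 to 4 while keeping the per-device batch size at 16. A quick sketch of the implied global batch size; using 8 processes is an assumption based on the `*_8gpus` names of the other pretrain configs:

# Editorial arithmetic; num_gpus=8 is assumed from the 8-GPU config filenames.
per_device_batch = 16
num_gpus = 8
accumulative_counts = 4                       # 2 in the other ablation_projector variants
effective_batch = per_device_batch * num_gpus * accumulative_counts
print(effective_batch)                        # 512 (256 with accumulative_counts=2)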
omg_llava/configs/pretrain/omg_llava_20b_pretrain_1024image_8gpus.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-20b' # Please change to your own path
34
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 2
47
+ dataloader_num_workers = 0
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+
96
+
97
+
98
+
99
+ omgseg_model = dict(
100
+ type=OMGSegVisualEncoder,
101
+ data_preprocessor=None,
102
+ pixel_shuffle_down_ratio=2,
103
+ backbone=dict(
104
+ type=OpenCLIPBackbone_omgseg,
105
+ model_name='convnext_large_d_320',
106
+ fix=True,
107
+ init_cfg=dict(
108
+ type='clip_pretrain',
109
+ checkpoint='laion2b_s29b_b131k_ft_soup'
110
+ )
111
+ ),
112
+ panoptic_head=dict(
113
+ type=Mask2FormerVideoSemSamHead,
114
+ sphere_cls=True,
115
+ ov_path=omg_ov_class_embed_path,
116
+ enable_box_query=False,
117
+ ov_classifier_name=class_embed,
118
+ logit=None,
119
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
120
+ strides=[4, 8, 16, 32],
121
+ feat_channels=256,
122
+ out_channels=256,
123
+ num_things_classes=num_things_classes,
124
+ num_stuff_classes=num_stuff_classes,
125
+ num_queries=300,
126
+ num_transformer_feat_level=3,
127
+ pixel_decoder=dict(
128
+ type=MSDeformAttnPixelDecoder,
129
+ num_outs=3,
130
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
131
+ act_cfg=dict(type=ReLU),
132
+ encoder=dict( # DeformableDetrTransformerEncoder
133
+ num_layers=6,
134
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
135
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
136
+ embed_dims=256,
137
+ num_heads=8,
138
+ num_levels=3,
139
+ num_points=4,
140
+ dropout=0.0,
141
+ batch_first=True),
142
+ ffn_cfg=dict(
143
+ embed_dims=256,
144
+ feedforward_channels=1024,
145
+ num_fcs=2,
146
+ ffn_drop=0.0,
147
+ act_cfg=dict(type=ReLU, inplace=True)))),
148
+ positional_encoding=dict(num_feats=128, normalize=True)),
149
+ enforce_decoder_input_project=False,
150
+ positional_encoding=dict(num_feats=128, normalize=True),
151
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
152
+ return_intermediate=True,
153
+ num_layers=9,
154
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
155
+ self_attn_cfg=dict( # MultiheadAttention
156
+ embed_dims=256,
157
+ num_heads=8,
158
+ dropout=0.0,
159
+ batch_first=True),
160
+ cross_attn_cfg=dict( # MultiheadAttention
161
+ embed_dims=256,
162
+ num_heads=8,
163
+ dropout=0.0,
164
+ batch_first=True),
165
+ ffn_cfg=dict(
166
+ embed_dims=256,
167
+ feedforward_channels=2048,
168
+ num_fcs=2,
169
+ ffn_drop=0.0,
170
+ act_cfg=dict(type='ReLU', inplace=True))),
171
+ init_cfg=None),
172
+ loss_cls=dict(
173
+ type=CrossEntropyLoss,
174
+ use_sigmoid=False,
175
+ loss_weight=2.0,
176
+ reduction='mean',
177
+ class_weight=[1.0] * 240 + [0.1]),
178
+ loss_mask=dict(
179
+ type=CrossEntropyLoss,
180
+ use_sigmoid=True,
181
+ reduction='mean',
182
+ loss_weight=5.0),
183
+ loss_dice=dict(
184
+ type=DiceLoss,
185
+ use_sigmoid=True,
186
+ activate=True,
187
+ reduction='mean',
188
+ naive_dice=True,
189
+ eps=1.0,
190
+ loss_weight=5.0),
191
+ loss_iou=dict(
192
+ type=FocalLoss,
193
+ use_sigmoid=True,
194
+ loss_weight=2.0,
195
+ reduction='mean')
196
+ ),
197
+ panoptic_fusion_head=dict(
198
+ type=MaskFormerFusionHead,
199
+ num_things_classes=num_things_classes,
200
+ num_stuff_classes=num_stuff_classes,
201
+ loss_panoptic=None,
202
+ init_cfg=None),
203
+ train_cfg=dict(
204
+ num_points=12544,
205
+ oversample_ratio=3.0,
206
+ importance_sample_ratio=0.75,
207
+ assigner=dict(
208
+ type=HungarianAssigner,
209
+ match_costs=[
210
+ # dict(type=FlexibleClassificationCost, weight=2.0),
211
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
212
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
213
+ ]),
214
+ sampler=dict(type=MaskPseudoSampler)),
215
+ test_cfg=dict(
216
+ panoptic_on=True,
217
+ # For now, the dataset does not support
218
+ # evaluating semantic segmentation metric.
219
+ semantic_on=False,
220
+ instance_on=True,
221
+ # max_per_image is for instance segmentation.
222
+ max_per_image=100,
223
+ iou_thr=0.8,
224
+ # In Mask2Former's panoptic postprocessing,
225
+ # it will filter mask area where score is less than 0.5 .
226
+ filter_low_score=True),
227
+ init_cfg=dict(
228
+ type='Pretrained',
229
+ checkpoint=omg_head_pretrain_pth_path,
230
+ )
231
+ )
232
+
233
+ model = dict(
234
+ type=OMG_LLaVA,
235
+ freeze_llm=True,
236
+ freeze_visual_encoder=True,
237
+ text2vision_projector=True,
238
+ keep_omg_decoder_frozen=True,
239
+ add_seg_pretrain=True,
240
+ pixel_shuffle_ratio=2,
241
+ llm=dict(
242
+ type=AutoModelForCausalLM.from_pretrained,
243
+ pretrained_model_name_or_path=llm_name_or_path,
244
+ trust_remote_code=True,
245
+ torch_dtype=torch.float16,
246
+ quantization_config=dict(
247
+ type=BitsAndBytesConfig,
248
+ load_in_4bit=True,
249
+ load_in_8bit=False,
250
+ llm_int8_threshold=6.0,
251
+ llm_int8_has_fp16_weight=False,
252
+ bnb_4bit_compute_dtype=torch.float16,
253
+ bnb_4bit_use_double_quant=True,
254
+ bnb_4bit_quant_type='nf4')),
255
+ visual_encoder=omgseg_model,
256
+ tokenizer=tokenizer,
257
+ )
258
+
259
+ #######################################################################
260
+ # PART 3 Dataset & Dataloader #
261
+ #######################################################################
262
+ llava_dataset = dict(
263
+ type=LLaVADataset,
264
+ data_path=data_path,
265
+ image_folder=image_folder,
266
+ tokenizer=tokenizer,
267
+ image_processor=image_processor,
268
+ dataset_map_fn=llava_map_fn,
269
+ template_map_fn=dict(
270
+ type=template_map_fn_factory, template=prompt_template),
271
+ max_length=max_length,
272
+ pad_image_to_square=True,
273
+ debug=False,
274
+ )
275
+
276
+ train_dataloader = dict(
277
+ batch_size=batch_size,
278
+ num_workers=dataloader_num_workers,
279
+ dataset=llava_dataset,
280
+ sampler=dict(type=DefaultSampler, shuffle=True),
281
+ collate_fn=dict(type=omg_llava_collate_fn))
282
+
283
+ #######################################################################
284
+ # PART 4 Scheduler & Optimizer #
285
+ #######################################################################
286
+ # optimizer
287
+ optim_wrapper = dict(
288
+ type=AmpOptimWrapper,
289
+ optimizer=dict(
290
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
291
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
292
+ accumulative_counts=accumulative_counts,
293
+ loss_scale='dynamic',
294
+ dtype='float16')
295
+
296
+ # learning policy
297
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
298
+ param_scheduler = [
299
+ dict(
300
+ type=LinearLR,
301
+ start_factor=1e-5,
302
+ by_epoch=True,
303
+ begin=0,
304
+ end=warmup_ratio * max_epochs,
305
+ convert_to_iter_based=True),
306
+ dict(
307
+ type=CosineAnnealingLR,
308
+ eta_min=0.0,
309
+ by_epoch=True,
310
+ begin=warmup_ratio * max_epochs,
311
+ end=max_epochs,
312
+ convert_to_iter_based=True)
313
+ ]
314
+
315
+ # train, val, test setting
316
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
317
+
318
+ #######################################################################
319
+ # PART 5 Runtime #
320
+ #######################################################################
321
+ # Log the dialogue periodically during the training process, optional
322
+ custom_hooks = [
323
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
324
+ dict(
325
+ type=EvaluateChatHook_withSpecialTokens,
326
+ tokenizer=tokenizer,
327
+ image_processor=image_processor,
328
+ every_n_iters=evaluation_freq,
329
+ evaluation_inputs=evaluation_inputs,
330
+ evaluation_images=evaluation_images,
331
+ system=SYSTEM,
332
+ prompt_template=prompt_template)
333
+ ]
334
+
335
+ # configure default hooks
336
+ default_hooks = dict(
337
+ # record the time of every iteration.
338
+ timer=dict(type=IterTimerHook),
339
+ # print log every 10 iterations.
340
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
341
+ # enable the parameter scheduler.
342
+ param_scheduler=dict(type=ParamSchedulerHook),
343
+ # save checkpoint per `save_steps`.
344
+ checkpoint=dict(
345
+ type=CheckpointHook,
346
+ by_epoch=False,
347
+ interval=save_steps,
348
+ max_keep_ckpts=save_total_limit),
349
+ # set sampler seed in distributed environment.
350
+ sampler_seed=dict(type=DistSamplerSeedHook),
351
+ )
352
+
353
+ # configure environment
354
+ env_cfg = dict(
355
+ # whether to enable cudnn benchmark
356
+ cudnn_benchmark=False,
357
+ # set multi process parameters
358
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
359
+ # set distributed parameters
360
+ dist_cfg=dict(backend='nccl'),
361
+ )
362
+
363
+ # set visualizer
364
+ visualizer = None
365
+
366
+ # set log level
367
+ log_level = 'INFO'
368
+
369
+ # load from which checkpoint
370
+ load_from = None
371
+
372
+ # whether to resume training from the loaded checkpoint
373
+ resume = False
374
+
375
+ # Defaults to use random seed and disable `deterministic`
376
+ randomness = dict(seed=None, deterministic=False)
377
+
378
+ # set log processor
379
+ log_processor = dict(by_epoch=False)
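The `quantization_config` dict inside the `llm` block above corresponds to a standard transformers `BitsAndBytesConfig` for 4-bit NF4 loading. A sketch of the equivalent eager construction (illustrative; the configs build it lazily through the dict rather than like this):

# Editorial sketch of the same 4-bit quantization settings built directly.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)
# It would then be passed as
# AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb_config).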
omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_8gpus.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
34
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_large_d_320_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 2
47
+ dataloader_num_workers = 4
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+ omgseg_model = dict(
96
+ type=OMGSegVisualEncoder,
97
+ data_preprocessor=None,
98
+ pixel_shuffle_down_ratio=2,
99
+ backbone=dict(
100
+ type=OpenCLIPBackbone_omgseg,
101
+ model_name='convnext_large_d_320',
102
+ fix=True,
103
+ init_cfg=dict(
104
+ type='clip_pretrain',
105
+ checkpoint='laion2b_s29b_b131k_ft_soup'
106
+ )
107
+ ),
108
+ panoptic_head=dict(
109
+ type=Mask2FormerVideoSemSamHead,
110
+ sphere_cls=True,
111
+ ov_path=omg_ov_class_embed_path,
112
+ enable_box_query=False,
113
+ ov_classifier_name=class_embed,
114
+ logit=None,
115
+ in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
116
+ strides=[4, 8, 16, 32],
117
+ feat_channels=256,
118
+ out_channels=256,
119
+ num_things_classes=num_things_classes,
120
+ num_stuff_classes=num_stuff_classes,
121
+ num_queries=300,
122
+ num_transformer_feat_level=3,
123
+ pixel_decoder=dict(
124
+ type=MSDeformAttnPixelDecoder,
125
+ num_outs=3,
126
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
127
+ act_cfg=dict(type=ReLU),
128
+ encoder=dict( # DeformableDetrTransformerEncoder
129
+ num_layers=6,
130
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
131
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
132
+ embed_dims=256,
133
+ num_heads=8,
134
+ num_levels=3,
135
+ num_points=4,
136
+ dropout=0.0,
137
+ batch_first=True),
138
+ ffn_cfg=dict(
139
+ embed_dims=256,
140
+ feedforward_channels=1024,
141
+ num_fcs=2,
142
+ ffn_drop=0.0,
143
+ act_cfg=dict(type=ReLU, inplace=True)))),
144
+ positional_encoding=dict(num_feats=128, normalize=True)),
145
+ enforce_decoder_input_project=False,
146
+ positional_encoding=dict(num_feats=128, normalize=True),
147
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
148
+ return_intermediate=True,
149
+ num_layers=9,
150
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
151
+ self_attn_cfg=dict( # MultiheadAttention
152
+ embed_dims=256,
153
+ num_heads=8,
154
+ dropout=0.0,
155
+ batch_first=True),
156
+ cross_attn_cfg=dict( # MultiheadAttention
157
+ embed_dims=256,
158
+ num_heads=8,
159
+ dropout=0.0,
160
+ batch_first=True),
161
+ ffn_cfg=dict(
162
+ embed_dims=256,
163
+ feedforward_channels=2048,
164
+ num_fcs=2,
165
+ ffn_drop=0.0,
166
+ act_cfg=dict(type='ReLU', inplace=True))),
167
+ init_cfg=None),
168
+ loss_cls=dict(
169
+ type=CrossEntropyLoss,
170
+ use_sigmoid=False,
171
+ loss_weight=2.0,
172
+ reduction='mean',
173
+ class_weight=[1.0] * 240 + [0.1]),
174
+ loss_mask=dict(
175
+ type=CrossEntropyLoss,
176
+ use_sigmoid=True,
177
+ reduction='mean',
178
+ loss_weight=5.0),
179
+ loss_dice=dict(
180
+ type=DiceLoss,
181
+ use_sigmoid=True,
182
+ activate=True,
183
+ reduction='mean',
184
+ naive_dice=True,
185
+ eps=1.0,
186
+ loss_weight=5.0),
187
+ loss_iou=dict(
188
+ type=FocalLoss,
189
+ use_sigmoid=True,
190
+ loss_weight=2.0,
191
+ reduction='mean')
192
+ ),
193
+ panoptic_fusion_head=dict(
194
+ type=MaskFormerFusionHead,
195
+ num_things_classes=num_things_classes,
196
+ num_stuff_classes=num_stuff_classes,
197
+ loss_panoptic=None,
198
+ init_cfg=None),
199
+ train_cfg=dict(
200
+ num_points=12544,
201
+ oversample_ratio=3.0,
202
+ importance_sample_ratio=0.75,
203
+ assigner=dict(
204
+ type=HungarianAssigner,
205
+ match_costs=[
206
+ # dict(type=FlexibleClassificationCost, weight=2.0),
207
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
208
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
209
+ ]),
210
+ sampler=dict(type=MaskPseudoSampler)),
211
+ test_cfg=dict(
212
+ panoptic_on=True,
213
+ # For now, the dataset does not support
214
+ # evaluating semantic segmentation metric.
215
+ semantic_on=False,
216
+ instance_on=True,
217
+ # max_per_image is for instance segmentation.
218
+ max_per_image=100,
219
+ iou_thr=0.8,
220
+ # In Mask2Former's panoptic postprocessing,
221
+ # it will filter mask area where score is less than 0.5 .
222
+ filter_low_score=True),
223
+ init_cfg=dict(
224
+ type='Pretrained',
225
+ checkpoint=omg_head_pretrain_pth_path,
226
+ )
227
+ )
228
+
229
+ model = dict(
230
+ type=OMG_LLaVA,
231
+ freeze_llm=True,
232
+ freeze_visual_encoder=True,
233
+ text2vision_projector=True,
234
+ keep_omg_decoder_frozen=True,
235
+ add_seg_pretrain=True,
236
+ pixel_shuffle_ratio=2,
237
+ llm=dict(
238
+ type=AutoModelForCausalLM.from_pretrained,
239
+ pretrained_model_name_or_path=llm_name_or_path,
240
+ trust_remote_code=True,
241
+ torch_dtype=torch.float16,
242
+ quantization_config=dict(
243
+ type=BitsAndBytesConfig,
244
+ load_in_4bit=True,
245
+ load_in_8bit=False,
246
+ llm_int8_threshold=6.0,
247
+ llm_int8_has_fp16_weight=False,
248
+ bnb_4bit_compute_dtype=torch.float16,
249
+ bnb_4bit_use_double_quant=True,
250
+ bnb_4bit_quant_type='nf4')),
251
+ visual_encoder=omgseg_model,
252
+ tokenizer=tokenizer,
253
+ )
254
+
255
+ #######################################################################
256
+ # PART 3 Dataset & Dataloader #
257
+ #######################################################################
258
+ llava_dataset = dict(
259
+ type=LLaVADataset,
260
+ data_path=data_path,
261
+ image_folder=image_folder,
262
+ tokenizer=tokenizer,
263
+ image_processor=image_processor,
264
+ dataset_map_fn=llava_map_fn,
265
+ template_map_fn=dict(
266
+ type=template_map_fn_factory, template=prompt_template),
267
+ max_length=max_length,
268
+ pad_image_to_square=True,
269
+ debug=False,
270
+ )
271
+
272
+ train_dataloader = dict(
273
+ batch_size=batch_size,
274
+ num_workers=dataloader_num_workers,
275
+ dataset=llava_dataset,
276
+ sampler=dict(type=DefaultSampler, shuffle=True),
277
+ collate_fn=dict(type=omg_llava_collate_fn))
278
+
279
+ #######################################################################
280
+ # PART 4 Scheduler & Optimizer #
281
+ #######################################################################
282
+ # optimizer
283
+ optim_wrapper = dict(
284
+ type=AmpOptimWrapper,
285
+ optimizer=dict(
286
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
287
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
288
+ accumulative_counts=accumulative_counts,
289
+ loss_scale='dynamic',
290
+ dtype='float16')
291
+
292
+ # learning policy
293
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
294
+ param_scheduler = [
295
+ dict(
296
+ type=LinearLR,
297
+ start_factor=1e-5,
298
+ by_epoch=True,
299
+ begin=0,
300
+ end=warmup_ratio * max_epochs,
301
+ convert_to_iter_based=True),
302
+ dict(
303
+ type=CosineAnnealingLR,
304
+ eta_min=0.0,
305
+ by_epoch=True,
306
+ begin=warmup_ratio * max_epochs,
307
+ end=max_epochs,
308
+ convert_to_iter_based=True)
309
+ ]
310
+
311
+ # train, val, test setting
312
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
313
+
314
+ #######################################################################
315
+ # PART 5 Runtime #
316
+ #######################################################################
317
+ # Log the dialogue periodically during the training process, optional
318
+ custom_hooks = [
319
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
320
+ dict(
321
+ type=EvaluateChatHook_withSpecialTokens,
322
+ tokenizer=tokenizer,
323
+ image_processor=image_processor,
324
+ every_n_iters=evaluation_freq,
325
+ evaluation_inputs=evaluation_inputs,
326
+ evaluation_images=evaluation_images,
327
+ system=SYSTEM,
328
+ prompt_template=prompt_template)
329
+ ]
330
+
331
+ # configure default hooks
332
+ default_hooks = dict(
333
+ # record the time of every iteration.
334
+ timer=dict(type=IterTimerHook),
335
+ # print log every 10 iterations.
336
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
337
+ # enable the parameter scheduler.
338
+ param_scheduler=dict(type=ParamSchedulerHook),
339
+ # save checkpoint per `save_steps`.
340
+ checkpoint=dict(
341
+ type=CheckpointHook,
342
+ by_epoch=False,
343
+ interval=save_steps,
344
+ max_keep_ckpts=save_total_limit),
345
+ # set sampler seed in distributed environment.
346
+ sampler_seed=dict(type=DistSamplerSeedHook),
347
+ )
348
+
349
+ # configure environment
350
+ env_cfg = dict(
351
+ # whether to enable cudnn benchmark
352
+ cudnn_benchmark=False,
353
+ # set multi process parameters
354
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
355
+ # set distributed parameters
356
+ dist_cfg=dict(backend='nccl'),
357
+ )
358
+
359
+ # set visualizer
360
+ visualizer = None
361
+
362
+ # set log level
363
+ log_level = 'INFO'
364
+
365
+ # load from which checkpoint
366
+ load_from = None
367
+
368
+ # whether to resume training from the loaded checkpoint
369
+ resume = False
370
+
371
+ # Defaults to use random seed and disable `deterministic`
372
+ randomness = dict(seed=None, deterministic=False)
373
+
374
+ # set log processor
375
+ log_processor = dict(by_epoch=False)
omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py ADDED
@@ -0,0 +1,376 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, CLIPImageProcessor,
10
+ CLIPVisionModel)
11
+
12
+ from omg_llava.dataset import LLaVADataset
13
+ from omg_llava.dataset.collect_fns import omg_llava_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from omg_llava.engine import DatasetInfoHook_withSpecoalTokens, EvaluateChatHook_withSpecialTokens
16
+ from xtuner.engine.runner import TrainLoop
17
+ from omg_llava.model import OMG_LLaVA
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+ from omg_llava.model.convnext_clip import OpenCLIPBackbone_omgseg
20
+ from omg_llava.model.omg_seg import OMGSegVisualEncoder, Mask2FormerVideoSemSamHead
21
+
22
+ from torch.nn import GroupNorm, ReLU
23
+
24
+ from mmdet.models import BatchFixedSizePad, MSDeformAttnPixelDecoder, CrossEntropyLoss, \
25
+ DiceLoss, MaskFormerFusionHead, FocalLoss
26
+ from mmdet.models.task_modules.assigners import HungarianAssigner, CrossEntropyLossCost, DiceCost
27
+ from mmdet.models.task_modules.samplers import MaskPseudoSampler
28
+
29
+ #######################################################################
30
+ # PART 1 Settings #
31
+ #######################################################################
32
+ # Model or model paths
33
+ llm_name_or_path = './pretrained/omg_llava/internlm2-chat-7b' # Please change to your own path
34
+ omg_ov_class_embed_path='./pretrained/omg_llava/convnext_xxlarge_CocoPanopticOVDataset.pth' # Please change to your own path
35
+ omg_head_pretrain_pth_path = './pretrained/omg_llava/omg_seg_convxxl.pth' # Please change to your own path
36
+
37
+ # Data paths
38
+ data_root = './data/llava_data/'
39
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
40
+ image_folder = data_root + 'LLaVA-Pretrain/images'
41
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
42
+ max_length = int(2048 - (1024 / 64)**2)
43
+
44
+ # Scheduler & Optimizer
45
+ batch_size = 16 # per_device
46
+ accumulative_counts = 4
47
+ dataloader_num_workers = 4
48
+ max_epochs = 1
49
+ optim_type = AdamW
50
+ lr = 1e-3
51
+ betas = (0.9, 0.999)
52
+ weight_decay = 0
53
+ max_norm = 1 # grad clip
54
+ warmup_ratio = 0.03
55
+
56
+ # Save
57
+ save_steps = 500
58
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
59
+
60
+ # Evaluate the generation performance during the training
61
+ evaluation_freq = 200
62
+ SYSTEM = ''
63
+ evaluation_images = './work_dirs/test.jpg'
64
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
65
+
66
+ #######################################################################
67
+ # PART 2 Model & Tokenizer & Image Processor #
68
+ #######################################################################
69
+ tokenizer = dict(
70
+ type=AutoTokenizer.from_pretrained,
71
+ pretrained_model_name_or_path=llm_name_or_path,
72
+ trust_remote_code=True,
73
+ padding_side='right')
74
+
75
+ image_processor = dict(
76
+ type=CLIPImageProcessor,
77
+ do_resize=True,
78
+ size=1024,
79
+ resample=3,
80
+ do_center_crop=True,
81
+ crop_size=1024,
82
+ do_rescale=True,
83
+ do_normalize=True,
84
+ image_mean=[0.4814, 0.4578, 0.4082],
85
+ image_std=[0.2686, 0.2613, 0.2757],
86
+ do_convert_rgb=True
87
+ )
88
+
89
+ # using coco class as the class classifier
90
+ class_embed = 'convnext_large_d_320_CocoPanopticOVDataset'
91
+ num_things_classes = 80
92
+ num_stuff_classes = 53
93
+ num_classes = num_things_classes + num_stuff_classes
94
+
95
+ omgseg_model = dict(
96
+ type=OMGSegVisualEncoder,
97
+ data_preprocessor=None,
98
+ pixel_shuffle_down_ratio=2,
99
+ backbone=dict(
100
+ type=OpenCLIPBackbone_omgseg,
101
+ model_name='convnext_xxlarge',
102
+ fix=True,
103
+ init_cfg=dict(
104
+ type='clip_pretrain',
105
+ checkpoint='laion2b_s34b_b82k_augreg_soup'
106
+ )
107
+ ),
108
+ panoptic_head=dict(
109
+ type=Mask2FormerVideoSemSamHead,
110
+ sphere_cls=True,
111
+ ov_path=omg_ov_class_embed_path,
112
+ enable_box_query=False,
113
+ ov_classifier_name=class_embed,
114
+ logit=None,
115
+ in_channels=[384, 768, 1536, 3072], # pass to pixel_decoder inside
116
+ strides=[4, 8, 16, 32],
117
+ feat_channels=256,
118
+ out_channels=256,
119
+ num_things_classes=num_things_classes,
120
+ num_stuff_classes=num_stuff_classes,
121
+ num_queries=300,
122
+ num_transformer_feat_level=3,
123
+ pixel_decoder=dict(
124
+ type=MSDeformAttnPixelDecoder,
125
+ num_outs=3,
126
+ norm_cfg=dict(type=GroupNorm, num_groups=32),
127
+ act_cfg=dict(type=ReLU),
128
+ encoder=dict( # DeformableDetrTransformerEncoder
129
+ num_layers=6,
130
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
131
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
132
+ embed_dims=256,
133
+ num_heads=8,
134
+ num_levels=3,
135
+ num_points=4,
136
+ dropout=0.0,
137
+ batch_first=True),
138
+ ffn_cfg=dict(
139
+ embed_dims=256,
140
+ feedforward_channels=1024,
141
+ num_fcs=2,
142
+ ffn_drop=0.0,
143
+ act_cfg=dict(type=ReLU, inplace=True)))),
144
+ positional_encoding=dict(num_feats=128, normalize=True)),
145
+ enforce_decoder_input_project=False,
146
+ positional_encoding=dict(num_feats=128, normalize=True),
147
+ transformer_decoder=dict( # Mask2FormerTransformerDecoder
148
+ return_intermediate=True,
149
+ num_layers=9,
150
+ layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
151
+ self_attn_cfg=dict( # MultiheadAttention
152
+ embed_dims=256,
153
+ num_heads=8,
154
+ dropout=0.0,
155
+ batch_first=True),
156
+ cross_attn_cfg=dict( # MultiheadAttention
157
+ embed_dims=256,
158
+ num_heads=8,
159
+ dropout=0.0,
160
+ batch_first=True),
161
+ ffn_cfg=dict(
162
+ embed_dims=256,
163
+ feedforward_channels=2048,
164
+ num_fcs=2,
165
+ ffn_drop=0.0,
166
+ act_cfg=dict(type='ReLU', inplace=True))),
167
+ init_cfg=None),
168
+ loss_cls=dict(
169
+ type=CrossEntropyLoss,
170
+ use_sigmoid=False,
171
+ loss_weight=2.0,
172
+ reduction='mean',
173
+ class_weight=[1.0] * 240 + [0.1]),
174
+ loss_mask=dict(
175
+ type=CrossEntropyLoss,
176
+ use_sigmoid=True,
177
+ reduction='mean',
178
+ loss_weight=5.0),
179
+ loss_dice=dict(
180
+ type=DiceLoss,
181
+ use_sigmoid=True,
182
+ activate=True,
183
+ reduction='mean',
184
+ naive_dice=True,
185
+ eps=1.0,
186
+ loss_weight=5.0),
187
+ loss_iou=dict(
188
+ type=FocalLoss,
189
+ use_sigmoid=True,
190
+ loss_weight=2.0,
191
+ reduction='mean')
192
+ ),
193
+ panoptic_fusion_head=dict(
194
+ type=MaskFormerFusionHead,
195
+ num_things_classes=num_things_classes,
196
+ num_stuff_classes=num_stuff_classes,
197
+ loss_panoptic=None,
198
+ init_cfg=None),
199
+ train_cfg=dict(
200
+ num_points=12544,
201
+ oversample_ratio=3.0,
202
+ importance_sample_ratio=0.75,
203
+ assigner=dict(
204
+ type=HungarianAssigner,
205
+ match_costs=[
206
+ # dict(type=FlexibleClassificationCost, weight=2.0),
207
+ dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
208
+ dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
209
+ ]),
210
+ sampler=dict(type=MaskPseudoSampler)),
211
+ test_cfg=dict(
212
+ panoptic_on=True,
213
+ # For now, the dataset does not support
214
+ # evaluating semantic segmentation metric.
215
+ semantic_on=False,
216
+ instance_on=True,
217
+ # max_per_image is for instance segmentation.
218
+ max_per_image=100,
219
+ iou_thr=0.8,
220
+ # In Mask2Former's panoptic postprocessing,
221
+ # it will filter mask area where score is less than 0.5 .
222
+ filter_low_score=True),
223
+ init_cfg=dict(
224
+ type='Pretrained',
225
+ checkpoint=omg_head_pretrain_pth_path,
226
+ )
227
+ )
228
+
229
+ model = dict(
230
+ type=OMG_LLaVA,
231
+ freeze_llm=True,
232
+ freeze_visual_encoder=True,
233
+ text2vision_projector=True,
234
+ keep_omg_decoder_frozen=True,
235
+ add_seg_pretrain=True,
236
+ pixel_shuffle_ratio=2,
237
+ clip_feat_channel=3072,
238
+ llm=dict(
239
+ type=AutoModelForCausalLM.from_pretrained,
240
+ pretrained_model_name_or_path=llm_name_or_path,
241
+ trust_remote_code=True,
242
+ torch_dtype=torch.float16,
243
+ quantization_config=dict(
244
+ type=BitsAndBytesConfig,
245
+ load_in_4bit=True,
246
+ load_in_8bit=False,
247
+ llm_int8_threshold=6.0,
248
+ llm_int8_has_fp16_weight=False,
249
+ bnb_4bit_compute_dtype=torch.float16,
250
+ bnb_4bit_use_double_quant=True,
251
+ bnb_4bit_quant_type='nf4')),
252
+ visual_encoder=omgseg_model,
253
+ tokenizer=tokenizer,
254
+ )
255
+
256
+ #######################################################################
257
+ # PART 3 Dataset & Dataloader #
258
+ #######################################################################
259
+ llava_dataset = dict(
260
+ type=LLaVADataset,
261
+ data_path=data_path,
262
+ image_folder=image_folder,
263
+ tokenizer=tokenizer,
264
+ image_processor=image_processor,
265
+ dataset_map_fn=llava_map_fn,
266
+ template_map_fn=dict(
267
+ type=template_map_fn_factory, template=prompt_template),
268
+ max_length=max_length,
269
+ pad_image_to_square=True,
270
+ debug=False,
271
+ )
272
+
273
+ train_dataloader = dict(
274
+ batch_size=batch_size,
275
+ num_workers=dataloader_num_workers,
276
+ dataset=llava_dataset,
277
+ sampler=dict(type=DefaultSampler, shuffle=True),
278
+ collate_fn=dict(type=omg_llava_collate_fn))
279
+
280
+ #######################################################################
281
+ # PART 4 Scheduler & Optimizer #
282
+ #######################################################################
283
+ # optimizer
284
+ optim_wrapper = dict(
285
+ type=AmpOptimWrapper,
286
+ optimizer=dict(
287
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
288
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
289
+ accumulative_counts=accumulative_counts,
290
+ loss_scale='dynamic',
291
+ dtype='float16')
292
+
293
+ # learning policy
294
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
295
+ param_scheduler = [
296
+ dict(
297
+ type=LinearLR,
298
+ start_factor=1e-5,
299
+ by_epoch=True,
300
+ begin=0,
301
+ end=warmup_ratio * max_epochs,
302
+ convert_to_iter_based=True),
303
+ dict(
304
+ type=CosineAnnealingLR,
305
+ eta_min=0.0,
306
+ by_epoch=True,
307
+ begin=warmup_ratio * max_epochs,
308
+ end=max_epochs,
309
+ convert_to_iter_based=True)
310
+ ]
311
+
312
+ # train, val, test setting
313
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
314
+
315
+ #######################################################################
316
+ # PART 5 Runtime #
317
+ #######################################################################
318
+ # Log the dialogue periodically during the training process, optional
319
+ custom_hooks = [
320
+ dict(type=DatasetInfoHook_withSpecoalTokens, tokenizer=tokenizer),
321
+ dict(
322
+ type=EvaluateChatHook_withSpecialTokens,
323
+ tokenizer=tokenizer,
324
+ image_processor=image_processor,
325
+ every_n_iters=evaluation_freq,
326
+ evaluation_inputs=evaluation_inputs,
327
+ evaluation_images=evaluation_images,
328
+ system=SYSTEM,
329
+ prompt_template=prompt_template)
330
+ ]
331
+
332
+ # configure default hooks
333
+ default_hooks = dict(
334
+ # record the time of every iteration.
335
+ timer=dict(type=IterTimerHook),
336
+ # print log every 10 iterations.
337
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
338
+ # enable the parameter scheduler.
339
+ param_scheduler=dict(type=ParamSchedulerHook),
340
+ # save checkpoint per `save_steps`.
341
+ checkpoint=dict(
342
+ type=CheckpointHook,
343
+ by_epoch=False,
344
+ interval=save_steps,
345
+ max_keep_ckpts=save_total_limit),
346
+ # set sampler seed in distributed evrionment.
347
+ sampler_seed=dict(type=DistSamplerSeedHook),
348
+ )
349
+
350
+ # configure environment
351
+ env_cfg = dict(
352
+ # whether to enable cudnn benchmark
353
+ cudnn_benchmark=False,
354
+ # set multi process parameters
355
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
356
+ # set distributed parameters
357
+ dist_cfg=dict(backend='nccl'),
358
+ )
359
+
360
+ # set visualizer
361
+ visualizer = None
362
+
363
+ # set log level
364
+ log_level = 'INFO'
365
+
366
+ # load from which checkpoint
367
+ load_from = None
368
+
369
+ # whether to resume training from the loaded checkpoint
370
+ resume = False
371
+
372
+ # Defaults to use random seed and disable `deterministic`
373
+ randomness = dict(seed=None, deterministic=False)
374
+
375
+ # set log processor
376
+ log_processor = dict(by_epoch=False)
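The two pretrain configs above are ordinary mmengine/xtuner lazy configs: every dict(type=...) entry is a constructor spec that the runner builds at start-up, and the convnextXXL variant differs from the ConvNeXt-L one mainly in the backbone name and checkpoint tag, the pixel-decoder in_channels, clip_feat_channel, and accumulative_counts. A minimal sketch of how such a config can be loaded and its components built outside the full trainer, assuming xtuner and mmengine are installed and the pretrained paths in PART 1 exist (the launch command in the trailing comment is illustrative, not taken from this repository):

# Minimal sketch, assuming xtuner and mmengine are installed and the paths in PART 1 exist.
from mmengine.config import Config
from xtuner.registry import BUILDER

cfg = Config.fromfile(
    'omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py')

# A dict whose 'type' is a class or callable can be built directly by the registry.
tokenizer = BUILDER.build(cfg.tokenizer)              # AutoTokenizer.from_pretrained(...)
image_processor = BUILDER.build(cfg.image_processor)  # CLIPImageProcessor(...)
print(type(tokenizer).__name__, type(image_processor).__name__)

# Full training is normally launched through the xtuner CLI, for example:
#   NPROC_PER_NODE=8 xtuner train omg_llava/configs/pretrain/omg_llava_7b_pretrain_1024image_convnextXXL_8gpus.py
# (exact launch flags may differ per setup).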
omg_llava/dataset/CombineDataset.py ADDED
@@ -0,0 +1,81 @@
+ from torch.utils.data import Dataset
+ import numpy as np
+
+ class CombineDataset(Dataset):
+     def __init__(self,
+                  datasets_cfgs,
+                  ):
+         super().__init__()
+
+         self.datasets = []
+         self.datasets_length = []
+
+         # build the shared tokenizer once from the first dataset config
+         self.tokenizer = datasets_cfgs[0].tokenizer
+         tokenizer_type = self.tokenizer['type']
+         del self.tokenizer['type']
+         self.tokenizer = tokenizer_type(**self.tokenizer)
+
+         self._add_special_tokens()
+
+         for i in range(len(datasets_cfgs)):
+             datasets_cfgs[i].tokenizer = self.tokenizer
+
+         for dataset_cfg in datasets_cfgs:
+             dataset = dataset_cfg['type']
+             del dataset_cfg['type']
+             dataset = dataset(**dataset_cfg)
+             self.datasets.append(dataset)
+             self.datasets_length.append(len(dataset))
+
+         # cumulative lengths: dataset_threthold[i] is the first global index past dataset i
+         self.dataset_threthold = []
+         for i, length in enumerate(self.datasets_length):
+             if i == 0:
+                 self.dataset_threthold.append(length)
+             else:
+                 self.dataset_threthold.append(length + self.dataset_threthold[i - 1])
+
+         # fixed shuffle of the concatenated index space
+         np.random.seed(42)
+         self.shuffled_index = np.arange(self.dataset_threthold[-1])
+         np.random.shuffle(self.shuffled_index)
+
+     @property
+     def modality_length(self):
+         length_list = []
+         for dataset in self.datasets:
+             for data_dict in dataset.text_data:
+                 cur_len = len(data_dict['input_ids'])
+                 if data_dict.get('image', None) is None:
+                     cur_len = -cur_len
+                 length_list.append(cur_len)
+         return length_list
+
+     def __len__(self):
+         return self.dataset_threthold[-1]
+
+     def __getitem__(self, index):
+         index = int(self.shuffled_index[index])
+         for i, thred in enumerate(self.dataset_threthold):
+             if index < thred:
+                 break
+
+         # i now indexes the dataset that owns this global index
+         if i == 0:
+             _index = index
+         else:
+             _index = index - self.dataset_threthold[i - 1]
+
+         return self.datasets[i][_index]
+
+     def _add_special_tokens(self):
+         assert hasattr(self, "tokenizer")
+         # Adding special tokens for pixel grounding
+         segmentation_tokens = ['[SEG]']
+         # Adding tokens for GCG
+         phrase_tokens = ['<p>', '</p>']
+         # add for visual prompt
+         region_tokens = ['<region>']
+         point_tokens = ['<mark>']
+         special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens
+
+         self.tokenizer.add_tokens(special_tokens, special_tokens=True)
+         return
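CombineDataset above concatenates several already-configured datasets behind a single index space: it builds the tokenizer once from the first config, injects it into every dataset config, records cumulative lengths, and serves samples through a fixed shuffled index. Below is a self-contained illustration of that routing arithmetic with stand-in lists instead of real datasets; numpy.searchsorted with side='right' is equivalent to the threshold loop in __getitem__:

import numpy as np

datasets = [list('abc'), list('defgh')]               # stand-ins with lengths 3 and 5
thresholds = np.cumsum([len(d) for d in datasets])    # [3, 8], like dataset_threthold

np.random.seed(42)                                    # same fixed shuffle as the class
shuffled = np.arange(thresholds[-1])
np.random.shuffle(shuffled)

def get(index):
    index = int(shuffled[index])
    owner = int(np.searchsorted(thresholds, index, side='right'))
    local = index if owner == 0 else index - int(thresholds[owner - 1])
    return datasets[owner][local]

print([get(i) for i in range(int(thresholds[-1]))])   # every element appears exactly once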
omg_llava/dataset/DecoupledGCGDataset.py ADDED
@@ -0,0 +1,381 @@
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict, load_from_disk
8
+ from mmengine import print_log
9
+ from mmengine.config import Config, ConfigDict
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from pycocotools import mask
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ import copy
16
+
17
+ from xtuner.registry import BUILDER
18
+ from omg_llava.dataset.utils import expand2square, expand2square_mask
19
+ from xtuner.dataset.huggingface import process_hf_dataset
20
+
21
+ class DecoupledGCGDataset(Dataset):
22
+
23
+ def __init__(self,
24
+ image_folder,
25
+ image_processor,
26
+ data_path=None,
27
+ tokenizer=None,
28
+ offline_processed_text_folder=None,
29
+ max_dataset_length=None,
30
+ dataset_map_fn=None,
31
+ template_map_fn=None,
32
+ max_length=2048,
33
+ pad_image_to_square=False,
34
+ num_proc=32,
35
+ debug=False,
36
+ repeats=1,
37
+ mode='given_description'):
38
+ super().__init__()
39
+
40
+ assert offline_processed_text_folder or (data_path and tokenizer)
41
+ self.debug = debug
42
+ if offline_processed_text_folder and data_path:
43
+ print_log(
44
+ 'Both `offline_processed_text_folder` and '
45
+ '`data_path` are set, and we load dataset from'
46
+ '`offline_processed_text_folder` '
47
+ f'({offline_processed_text_folder})',
48
+ logger='current',
49
+ level=logging.WARNING)
50
+
51
+ if offline_processed_text_folder is not None:
52
+ raise NotImplementedError
53
+ else:
54
+ json_data = self.json_file_preprocess(data_path)
55
+ json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
56
+ self.text_data = process_hf_dataset(
57
+ dataset=json_data,
58
+ tokenizer=tokenizer,
59
+ max_length=max_length,
60
+ dataset_map_fn=dataset_map_fn,
61
+ template_map_fn=template_map_fn,
62
+ split='train',
63
+ max_dataset_length=max_dataset_length,
64
+ remove_unused_columns=False,
65
+ pack_to_max_length=False,
66
+ with_image_token=True,
67
+ map_num_proc=num_proc, # keep the number of map workers modest to limit memory use
68
+ )
69
+
70
+ self.image_folder = image_folder
71
+ size = image_processor.crop_size
72
+ if isinstance(size, int):
73
+ self.image_h, self.image_w = size, size
74
+ else:
75
+ self.image_w, self.image_h = size
76
+
77
+ if isinstance(image_processor, dict) or isinstance(
78
+ image_processor, Config) or isinstance(image_processor,
79
+ ConfigDict):
80
+ self.image_processor = BUILDER.build(image_processor)
81
+ else:
82
+ self.image_processor = image_processor
83
+ self.pad_image_to_square = pad_image_to_square
84
+ self.down_ratio = 1
85
+ self.repeats = repeats
86
+ self.mode = mode
87
+
88
+ def json_file_preprocess(self, data_path):
89
+ with open(data_path, 'r') as f:
90
+ json_data = json.load(f)
91
+
92
+ # for quickly debug with mini split
93
+ if self.debug:
94
+ json_data = json_data[:100]
95
+ return json_data
96
+
97
+ @property
98
+ def modality_length(self):
99
+ length_list = []
100
+ for data_dict in self.text_data:
101
+ cur_len = len(data_dict['input_ids'])
102
+ if data_dict.get('image', None) is None:
103
+ cur_len = -cur_len
104
+ length_list.append(cur_len)
105
+ length_list = length_list * self.repeats
106
+ return length_list
107
+
108
+ def __len__(self):
109
+ return len(self.text_data) * self.repeats
110
+
111
+ def real_len(self):
112
+ return len(self.text_data)
113
+
114
+ def decode_mask(self, object_masks, ori_height, ori_width):
115
+ binary_masks = []
116
+ for object_mask in object_masks:
117
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
118
+ for seg in object_mask:
119
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
120
+ m = mask.decode(rles)
121
+ m = m.astype(np.uint8)
122
+ binary_mask += m.squeeze()
123
+
124
+ binary_masks.append(binary_mask)
125
+ if len(binary_masks) == 0:
126
+ return None
127
+ masks = np.stack(binary_masks, axis=0)
128
+ if self.pad_image_to_square:
129
+ masks = expand2square_mask(masks)
130
+ masks = torch.from_numpy(masks)
131
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
132
+ return masks
133
+
134
+ def __getitem__(self, index):
135
+ index = index % self.real_len()
136
+ data_dict = copy.deepcopy(self.text_data[index])
137
+
138
+ if data_dict.get('image', None) is not None:
139
+ image_file = data_dict['image']
140
+ image = Image.open(os.path.join(self.image_folder,
141
+ image_file)).convert('RGB')
142
+ ori_width, ori_height = image.size
143
+ if self.pad_image_to_square:
144
+ image = expand2square(
145
+ image,
146
+ tuple(
147
+ int(x * 255) for x in self.image_processor.image_mean))
148
+ image = self.image_processor.preprocess(
149
+ image, return_tensors='pt')['pixel_values'][0]
150
+ data_dict['pixel_values'] = image
151
+
152
+ # process and get masks
153
+ data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
154
+
155
+ assert self.mode in ['given_objects', 'given_description']
156
+ if self.mode == 'given_objects':
157
+ data_dict['regions'] = copy.deepcopy(data_dict['masks'])
158
+
159
+ # if data_dict['masks'] is None:
160
+ # return self.__getitem__(0)
161
+ else:
162
+ if hasattr(self.image_processor, 'crop_size'):
163
+ crop_size = self.image_processor.crop_size
164
+ else:
165
+ crop_size = self.image_processor.size
166
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
167
+ crop_size['width'])
168
+ data_dict['masks'] = None
169
+ return data_dict
170
+
171
+ class DecoupledRefCOCOgGCGDataset(DecoupledGCGDataset):
172
+ def __init__(self,
173
+ image_folder,
174
+ image_processor,
175
+ data_path=None,
176
+ tokenizer=None,
177
+ offline_processed_text_folder=None,
178
+ max_dataset_length=None,
179
+ dataset_map_fn=None,
180
+ template_map_fn=None,
181
+ max_length=2048,
182
+ pad_image_to_square=False,
183
+ debug=False,
184
+ repeats=1,
185
+ mode='given_description',
186
+ ):
187
+ super().__init__(
188
+ image_folder=image_folder,
189
+ image_processor=image_processor,
190
+ data_path=data_path,
191
+ tokenizer=tokenizer,
192
+ offline_processed_text_folder=offline_processed_text_folder,
193
+ max_dataset_length=max_dataset_length,
194
+ dataset_map_fn=dataset_map_fn,
195
+ template_map_fn=template_map_fn,
196
+ max_length=max_length,
197
+ pad_image_to_square=pad_image_to_square,
198
+ debug=debug,
199
+ repeats=repeats,
200
+ mode=mode,
201
+ )
202
+
203
+ def json_file_preprocess(self, data_path):
204
+ json_data = json.load(open(data_path))
205
+ if self.debug:
206
+ json_data = json_data[:100]
207
+
208
+ # convert {id: dict} to dict(..., id=xx)
209
+ for idx in range(len(json_data)):
210
+ id = list(json_data[idx].keys())[0]
211
+ json_data[idx] = json_data[idx][id]
212
+ json_data[idx].update({'id': id})
213
+ return json_data
214
+
215
+ class DecoupledGranDfGCGDataset(DecoupledGCGDataset):
216
+ def __init__(self,
217
+ image_folder,
218
+ image_processor,
219
+ data_path=None,
220
+ tokenizer=None,
221
+ offline_processed_text_folder=None,
222
+ max_dataset_length=None,
223
+ dataset_map_fn=None,
224
+ template_map_fn=None,
225
+ max_length=2048,
226
+ pad_image_to_square=False,
227
+ num_proc=4,
228
+ debug=False,
229
+ repeats=1,
230
+ mode='given_description'):
231
+ super().__init__(
232
+ image_folder=image_folder,
233
+ image_processor=image_processor,
234
+ data_path=data_path,
235
+ tokenizer=tokenizer,
236
+ offline_processed_text_folder=offline_processed_text_folder,
237
+ max_dataset_length=max_dataset_length,
238
+ dataset_map_fn=dataset_map_fn,
239
+ template_map_fn=template_map_fn,
240
+ max_length=max_length,
241
+ pad_image_to_square=pad_image_to_square,
242
+ num_proc=num_proc,
243
+ debug=debug,
244
+ repeats=repeats,
245
+ mode=mode
246
+ )
247
+
248
+ def decode_mask(self, object_masks, ori_height, ori_width):
249
+ binary_masks = []
250
+ for object_mask in object_masks:
251
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
252
+
253
+ for rle in object_mask:
254
+ m = mask.decode(rle).astype(np.uint8)
255
+ binary_mask += m.squeeze()
256
+
257
+ binary_masks.append(binary_mask)
258
+ if len(binary_masks) == 0:
259
+ return None
260
+ masks = np.stack(binary_masks, axis=0)
261
+ if self.pad_image_to_square:
262
+ masks = expand2square_mask(masks)
263
+ masks = torch.from_numpy(masks)
264
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
265
+ return masks
266
+
267
+ class DecoupledOpenPsgGCGDataset(DecoupledGranDfGCGDataset):
268
+ def __init__(self,
269
+ image_folder,
270
+ image_processor,
271
+ data_path=None,
272
+ tokenizer=None,
273
+ offline_processed_text_folder=None,
274
+ max_dataset_length=None,
275
+ dataset_map_fn=None,
276
+ template_map_fn=None,
277
+ max_length=2048,
278
+ pad_image_to_square=False,
279
+ num_proc=4,
280
+ debug=False,
281
+ repeats=1,
282
+ mode='given_description'):
283
+ super().__init__(
284
+ image_folder=image_folder,
285
+ image_processor=image_processor,
286
+ data_path=data_path,
287
+ tokenizer=tokenizer,
288
+ offline_processed_text_folder=offline_processed_text_folder,
289
+ max_dataset_length=max_dataset_length,
290
+ dataset_map_fn=dataset_map_fn,
291
+ template_map_fn=template_map_fn,
292
+ max_length=max_length,
293
+ pad_image_to_square=pad_image_to_square,
294
+ num_proc=num_proc,
295
+ debug=debug,
296
+ repeats=repeats,
297
+ mode=mode
298
+ )
299
+
300
+ class DecoupledFlickrGCGDataset(DecoupledGCGDataset):
301
+ def __init__(self,
302
+ image_folder,
303
+ image_processor,
304
+ data_path=None,
305
+ tokenizer=None,
306
+ offline_processed_text_folder=None,
307
+ max_dataset_length=None,
308
+ dataset_map_fn=None,
309
+ template_map_fn=None,
310
+ max_length=2048,
311
+ pad_image_to_square=False,
312
+ num_proc=4,
313
+ debug=False,
314
+ repeats=1,
315
+ mode='given_description'
316
+ ):
317
+ super().__init__(
318
+ image_folder=image_folder,
319
+ image_processor=image_processor,
320
+ data_path=data_path,
321
+ tokenizer=tokenizer,
322
+ offline_processed_text_folder=offline_processed_text_folder,
323
+ max_dataset_length=max_dataset_length,
324
+ dataset_map_fn=dataset_map_fn,
325
+ template_map_fn=template_map_fn,
326
+ max_length=max_length,
327
+ pad_image_to_square=pad_image_to_square,
328
+ num_proc=num_proc,
329
+ debug=debug,
330
+ repeats=repeats,
331
+ mode=mode
332
+ )
333
+
334
+ def json_file_preprocess(self, data_path):
335
+ def filter_images(data_infos, min_size):
336
+ return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
337
+
338
+ # convert {id: dict} to dict(..., id=xx)
339
+ from pycocotools.coco import COCO
340
+ self.coco = COCO(data_path)
341
+ self.image_ids = self.coco.getImgIds()
342
+ data_infos = []
343
+ total_ann_ids = []
344
+ removed_img_count = 0
345
+ for img_id in self.image_ids:
346
+ info = self.coco.loadImgs([img_id])[0]
347
+ if len(info['caption'].split(' ')) < 3:
348
+ removed_img_count += 1
349
+ continue
350
+ info['filename'] = info['file_name'].split('_')[-1]
351
+ info['height'] = int(info['height'])
352
+ info['width'] = int(info['width'])
353
+ data_infos.append(info)
354
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
355
+ total_ann_ids.extend(ann_ids)
356
+ assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
357
+ print(f'Removed {removed_img_count} images.')
358
+ data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
359
+
360
+ # obtain_annotations
361
+ for data_info in data_infos:
362
+ ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
363
+ ann_info = self.coco.loadAnns(ann_ids)
364
+ data_info.update({'ann_info': ann_info})
365
+ if self.debug:
366
+ data_infos = data_infos[:32]
367
+ return data_infos
368
+
369
+ def decode_mask(self, object_masks, ori_height, ori_width):
370
+ binary_masks = []
371
+ for object_mask in object_masks:
372
+ binary_mask = mask.decode(object_mask).astype(np.uint8)
373
+ binary_masks.append(binary_mask)
374
+ if len(binary_masks) == 0:
375
+ return None
376
+ masks = np.stack(binary_masks, axis=0)
377
+ if self.pad_image_to_square:
378
+ masks = expand2square_mask(masks)
379
+ masks = torch.from_numpy(masks)
380
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
381
+ return masks
omg_llava/dataset/GCGDataset.py ADDED
@@ -0,0 +1,364 @@
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict, load_from_disk
8
+ from mmengine import print_log
9
+ from mmengine.config import Config, ConfigDict
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from pycocotools import mask
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ import copy
16
+
17
+ from xtuner.registry import BUILDER
18
+ from omg_llava.dataset.utils import expand2square, expand2square_mask
19
+ from xtuner.dataset.huggingface import process_hf_dataset
20
+
21
+ class GCGDataset(Dataset):
22
+
23
+ def __init__(self,
24
+ image_folder,
25
+ image_processor,
26
+ data_path=None,
27
+ tokenizer=None,
28
+ offline_processed_text_folder=None,
29
+ max_dataset_length=None,
30
+ dataset_map_fn=None,
31
+ template_map_fn=None,
32
+ max_length=2048,
33
+ pad_image_to_square=False,
34
+ num_proc=32,
35
+ debug=False,
36
+ repeats=1):
37
+ super().__init__()
38
+
39
+ assert offline_processed_text_folder or (data_path and tokenizer)
40
+ self.debug = debug
41
+ if offline_processed_text_folder and data_path:
42
+ print_log(
43
+ 'Both `offline_processed_text_folder` and '
44
+ '`data_path` are set, and we load dataset from'
45
+ '`offline_processed_text_folder` '
46
+ f'({offline_processed_text_folder})',
47
+ logger='current',
48
+ level=logging.WARNING)
49
+
50
+ if offline_processed_text_folder is not None:
51
+ raise NotImplementedError
52
+ else:
53
+ json_data = self.json_file_preprocess(data_path)
54
+ json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
55
+ self.text_data = process_hf_dataset(
56
+ dataset=json_data,
57
+ tokenizer=tokenizer,
58
+ max_length=max_length,
59
+ dataset_map_fn=dataset_map_fn,
60
+ template_map_fn=template_map_fn,
61
+ split='train',
62
+ max_dataset_length=max_dataset_length,
63
+ remove_unused_columns=False,
64
+ pack_to_max_length=False,
65
+ with_image_token=True,
66
+ map_num_proc=num_proc, # keep the number of map workers modest to limit memory use
67
+ )
68
+
69
+ self.image_folder = image_folder
70
+ size = image_processor.crop_size
71
+ if isinstance(size, int):
72
+ self.image_h, self.image_w = size, size
73
+ else:
74
+ self.image_w, self.image_h = size
75
+
76
+ if isinstance(image_processor, dict) or isinstance(
77
+ image_processor, Config) or isinstance(image_processor,
78
+ ConfigDict):
79
+ self.image_processor = BUILDER.build(image_processor)
80
+ else:
81
+ self.image_processor = image_processor
82
+ self.pad_image_to_square = pad_image_to_square
83
+ self.down_ratio = 1
84
+ self.repeats = repeats
85
+
86
+ def json_file_preprocess(self, data_path):
87
+ with open(data_path, 'r') as f:
88
+ json_data = json.load(f)
89
+
90
+ # for quickly debug with mini split
91
+ if self.debug:
92
+ json_data = json_data[:100]
93
+ return json_data
94
+
95
+ @property
96
+ def modality_length(self):
97
+ length_list = []
98
+ for data_dict in self.text_data:
99
+ cur_len = len(data_dict['input_ids'])
100
+ if data_dict.get('image', None) is None:
101
+ cur_len = -cur_len
102
+ length_list.append(cur_len)
103
+ length_list = length_list * self.repeats
104
+ return length_list
105
+
106
+ def __len__(self):
107
+ return len(self.text_data) * self.repeats
108
+
109
+ def real_len(self):
110
+ return len(self.text_data)
111
+
112
+ def decode_mask(self, object_masks, ori_height, ori_width):
113
+ binary_masks = []
114
+ for object_mask in object_masks:
115
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
116
+ for seg in object_mask:
117
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
118
+ m = mask.decode(rles)
119
+ m = m.astype(np.uint8)
120
+ binary_mask += m.squeeze()
121
+
122
+ binary_masks.append(binary_mask)
123
+ if len(binary_masks) == 0:
124
+ return None
125
+ masks = np.stack(binary_masks, axis=0)
126
+ if self.pad_image_to_square:
127
+ masks = expand2square_mask(masks)
128
+ masks = torch.from_numpy(masks)
129
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
130
+ return masks
131
+
132
+ def __getitem__(self, index):
133
+ index = index % self.real_len()
134
+ data_dict = copy.deepcopy(self.text_data[index])
135
+
136
+ if data_dict.get('image', None) is not None:
137
+ image_file = data_dict['image']
138
+ image = Image.open(os.path.join(self.image_folder,
139
+ image_file)).convert('RGB')
140
+ ori_width, ori_height = image.size
141
+ if self.pad_image_to_square:
142
+ image = expand2square(
143
+ image,
144
+ tuple(
145
+ int(x * 255) for x in self.image_processor.image_mean))
146
+ image = self.image_processor.preprocess(
147
+ image, return_tensors='pt')['pixel_values'][0]
148
+ data_dict['pixel_values'] = image
149
+
150
+ # process and get masks
151
+ data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
152
+ if data_dict['masks'] is None:
153
+ return self.__getitem__(0)
154
+ else:
155
+ if hasattr(self.image_processor, 'crop_size'):
156
+ crop_size = self.image_processor.crop_size
157
+ else:
158
+ crop_size = self.image_processor.size
159
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
160
+ crop_size['width'])
161
+ data_dict['masks'] = None
162
+ return data_dict
163
+
164
+ class RefCOCOgGCGDataset(GCGDataset):
165
+ def __init__(self,
166
+ image_folder,
167
+ image_processor,
168
+ data_path=None,
169
+ tokenizer=None,
170
+ offline_processed_text_folder=None,
171
+ max_dataset_length=None,
172
+ dataset_map_fn=None,
173
+ template_map_fn=None,
174
+ max_length=2048,
175
+ pad_image_to_square=False,
176
+ debug=False,
177
+ repeats=1,):
178
+ super().__init__(
179
+ image_folder=image_folder,
180
+ image_processor=image_processor,
181
+ data_path=data_path,
182
+ tokenizer=tokenizer,
183
+ offline_processed_text_folder=offline_processed_text_folder,
184
+ max_dataset_length=max_dataset_length,
185
+ dataset_map_fn=dataset_map_fn,
186
+ template_map_fn=template_map_fn,
187
+ max_length=max_length,
188
+ pad_image_to_square=pad_image_to_square,
189
+ debug=debug,
190
+ repeats=repeats,
191
+ )
192
+
193
+ def json_file_preprocess(self, data_path):
194
+ json_data = json.load(open(data_path))
195
+ if self.debug:
196
+ json_data = json_data[:100]
197
+
198
+ # convert {id: dict} to dict(..., id=xx)
199
+ for idx in range(len(json_data)):
200
+ id = list(json_data[idx].keys())[0]
201
+ json_data[idx] = json_data[idx][id]
202
+ json_data[idx].update({'id': id})
203
+ return json_data
204
+
205
+ class GranDfGCGDataset(GCGDataset):
206
+ def __init__(self,
207
+ image_folder,
208
+ image_processor,
209
+ data_path=None,
210
+ tokenizer=None,
211
+ offline_processed_text_folder=None,
212
+ max_dataset_length=None,
213
+ dataset_map_fn=None,
214
+ template_map_fn=None,
215
+ max_length=2048,
216
+ pad_image_to_square=False,
217
+ num_proc=4,
218
+ debug=False,
219
+ repeats=1):
220
+ super().__init__(
221
+ image_folder=image_folder,
222
+ image_processor=image_processor,
223
+ data_path=data_path,
224
+ tokenizer=tokenizer,
225
+ offline_processed_text_folder=offline_processed_text_folder,
226
+ max_dataset_length=max_dataset_length,
227
+ dataset_map_fn=dataset_map_fn,
228
+ template_map_fn=template_map_fn,
229
+ max_length=max_length,
230
+ pad_image_to_square=pad_image_to_square,
231
+ num_proc=num_proc,
232
+ debug=debug,
233
+ repeats=repeats,
234
+ )
235
+
236
+ def decode_mask(self, object_masks, ori_height, ori_width):
237
+ binary_masks = []
238
+ for object_mask in object_masks:
239
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
240
+
241
+ for rle in object_mask:
242
+ m = mask.decode(rle).astype(np.uint8)
243
+ binary_mask += m.squeeze()
244
+
245
+ binary_masks.append(binary_mask)
246
+ if len(binary_masks) == 0:
247
+ return None
248
+ masks = np.stack(binary_masks, axis=0)
249
+ if self.pad_image_to_square:
250
+ masks = expand2square_mask(masks)
251
+ masks = torch.from_numpy(masks)
252
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
253
+ return masks
254
+
255
+ class OpenPsgGCGDataset(GranDfGCGDataset):
256
+ def __init__(self,
257
+ image_folder,
258
+ image_processor,
259
+ data_path=None,
260
+ tokenizer=None,
261
+ offline_processed_text_folder=None,
262
+ max_dataset_length=None,
263
+ dataset_map_fn=None,
264
+ template_map_fn=None,
265
+ max_length=2048,
266
+ pad_image_to_square=False,
267
+ num_proc=4,
268
+ debug=False,
269
+ repeats=1):
270
+ super().__init__(
271
+ image_folder=image_folder,
272
+ image_processor=image_processor,
273
+ data_path=data_path,
274
+ tokenizer=tokenizer,
275
+ offline_processed_text_folder=offline_processed_text_folder,
276
+ max_dataset_length=max_dataset_length,
277
+ dataset_map_fn=dataset_map_fn,
278
+ template_map_fn=template_map_fn,
279
+ max_length=max_length,
280
+ pad_image_to_square=pad_image_to_square,
281
+ num_proc=num_proc,
282
+ debug=debug,
283
+ repeats=repeats,
284
+ )
285
+
286
+ class FlickrGCGDataset(GCGDataset):
287
+ def __init__(self,
288
+ image_folder,
289
+ image_processor,
290
+ data_path=None,
291
+ tokenizer=None,
292
+ offline_processed_text_folder=None,
293
+ max_dataset_length=None,
294
+ dataset_map_fn=None,
295
+ template_map_fn=None,
296
+ max_length=2048,
297
+ pad_image_to_square=False,
298
+ num_proc=4,
299
+ debug=False,
300
+ repeats=1,):
301
+ super().__init__(
302
+ image_folder=image_folder,
303
+ image_processor=image_processor,
304
+ data_path=data_path,
305
+ tokenizer=tokenizer,
306
+ offline_processed_text_folder=offline_processed_text_folder,
307
+ max_dataset_length=max_dataset_length,
308
+ dataset_map_fn=dataset_map_fn,
309
+ template_map_fn=template_map_fn,
310
+ max_length=max_length,
311
+ pad_image_to_square=pad_image_to_square,
312
+ num_proc=num_proc,
313
+ debug=debug,
314
+ repeats=repeats,
315
+ )
316
+
317
+ def json_file_preprocess(self, data_path):
318
+ def filter_images(data_infos, min_size):
319
+ return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
320
+
321
+ # load COCO-format annotations; skip images whose captions have fewer than 3 words
322
+ from pycocotools.coco import COCO
323
+ self.coco = COCO(data_path)
324
+ self.image_ids = self.coco.getImgIds()
325
+ data_infos = []
326
+ total_ann_ids = []
327
+ removed_img_count = 0
328
+ for img_id in self.image_ids:
329
+ info = self.coco.loadImgs([img_id])[0]
330
+ if len(info['caption'].split(' ')) < 3:
331
+ removed_img_count += 1
332
+ continue
333
+ info['filename'] = info['file_name'].split('_')[-1]
334
+ info['height'] = int(info['height'])
335
+ info['width'] = int(info['width'])
336
+ data_infos.append(info)
337
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
338
+ total_ann_ids.extend(ann_ids)
339
+ assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
340
+ print(f'Removed {removed_img_count} images.')
341
+ data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
342
+
343
+ # obtain_annotations
344
+ for data_info in data_infos:
345
+ ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
346
+ ann_info = self.coco.loadAnns(ann_ids)
347
+ data_info.update({'ann_info': ann_info})
348
+ if self.debug:
349
+ data_infos = data_infos[:32]
350
+ return data_infos
351
+
352
+ def decode_mask(self, object_masks, ori_height, ori_width):
353
+ binary_masks = []
354
+ for object_mask in object_masks:
355
+ binary_mask = mask.decode(object_mask).astype(np.uint8)
356
+ binary_masks.append(binary_mask)
357
+ if len(binary_masks) == 0:
358
+ return None
359
+ masks = np.stack(binary_masks, axis=0)
360
+ if self.pad_image_to_square:
361
+ masks = expand2square_mask(masks)
362
+ masks = torch.from_numpy(masks)
363
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
364
+ return masks
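decode_mask above turns per-object COCO polygon segmentations into binary masks, merges multi-part objects by summation, optionally pads them to a square, and resizes them to the processed image resolution. A small self-contained sketch of the pycocotools calls it relies on (the triangle polygon is made up for illustration):

import numpy as np
from pycocotools import mask as mask_utils

ori_height, ori_width = 480, 640
# One object described by a single COCO-style polygon (x0, y0, x1, y1, ...).
polygon = [100.0, 100.0, 300.0, 100.0, 200.0, 250.0]

rles = mask_utils.frPyObjects([polygon], ori_height, ori_width)  # polygon -> RLE(s)
binary = mask_utils.decode(rles).astype(np.uint8)                # (H, W, 1) uint8
binary = binary.squeeze()                                        # (H, W)

print(binary.shape, binary.sum())  # foreground pixel count inside the triangle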
omg_llava/dataset/LlavaDataset.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import logging
4
+ import os
5
+
6
+ import torch
7
+ from datasets import Dataset as HFDataset
8
+ from datasets import DatasetDict, load_from_disk
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+
14
+ from xtuner.registry import BUILDER
15
+ from xtuner.dataset.huggingface import process_hf_dataset
16
+ from .utils import expand2square
17
+ import copy
18
+
19
+ class LLaVADataset(Dataset):
20
+
21
+ def __init__(self,
22
+ image_folder,
23
+ image_processor,
24
+ data_path=None,
25
+ tokenizer=None,
26
+ offline_processed_text_folder=None,
27
+ max_dataset_length=None,
28
+ dataset_map_fn=None,
29
+ template_map_fn=None,
30
+ max_length=2048,
31
+ pad_image_to_square=False,
32
+ debug=False):
33
+ super().__init__()
34
+
35
+ assert offline_processed_text_folder or (data_path and tokenizer)
36
+
37
+ self.tokenizer = tokenizer
38
+ if isinstance(tokenizer, dict) or isinstance(
39
+ tokenizer, Config) or isinstance(tokenizer, ConfigDict):
40
+ tokenizer_type = self.tokenizer['type']
41
+ del self.tokenizer['type']
42
+ self.tokenizer = tokenizer_type(**self.tokenizer)
43
+ self._add_special_tokens()
44
+
45
+ if offline_processed_text_folder and data_path:
46
+ print_log(
47
+ 'Both `offline_processed_text_folder` and '
48
+ '`data_path` are set, and we load dataset from'
49
+ '`offline_processed_text_folder` '
50
+ f'({offline_processed_text_folder})',
51
+ logger='current',
52
+ level=logging.WARNING)
53
+
54
+ if offline_processed_text_folder is not None:
55
+ self.text_data = load_from_disk(offline_processed_text_folder)
56
+ else:
57
+ json_data = json.load(open(data_path))
58
+ if debug:
59
+ json_data = json_data[:10000]
60
+ for idx in range(len(json_data)):
61
+ if isinstance(json_data[idx]['id'], int):
62
+ json_data[idx]['id'] = str(json_data[idx]['id'])
63
+ json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
64
+ self.text_data = process_hf_dataset(
65
+ dataset=json_data,
66
+ tokenizer=self.tokenizer,
67
+ max_length=max_length,
68
+ dataset_map_fn=dataset_map_fn,
69
+ template_map_fn=template_map_fn,
70
+ split='train',
71
+ max_dataset_length=max_dataset_length,
72
+ remove_unused_columns=False,
73
+ pack_to_max_length=False,
74
+ with_image_token=True,
75
+ map_num_proc=32, # because limited mem
76
+ )
77
+
78
+ self.image_folder = image_folder
79
+ if isinstance(image_processor, dict) or isinstance(
80
+ image_processor, Config) or isinstance(image_processor,
81
+ ConfigDict):
82
+ self.image_processor = BUILDER.build(image_processor)
83
+ else:
84
+ self.image_processor = image_processor
85
+ self.pad_image_to_square = pad_image_to_square
86
+
87
+ @property
88
+ def modality_length(self):
89
+ length_list = []
90
+ for data_dict in self.text_data:
91
+ cur_len = len(data_dict['input_ids'])
92
+ if data_dict.get('image', None) is None:
93
+ cur_len = -cur_len
94
+ length_list.append(cur_len)
95
+ return length_list
96
+
97
+ def __len__(self):
98
+ return len(self.text_data)
99
+
100
+ def __getitem__(self, index):
101
+ data_dict = copy.deepcopy(self.text_data[index])
102
+ if data_dict.get('image', None) is not None:
103
+ image_file = data_dict['image']
104
+ image = Image.open(os.path.join(self.image_folder,
105
+ image_file)).convert('RGB')
106
+ if self.pad_image_to_square:
107
+ image = expand2square(
108
+ image,
109
+ tuple(
110
+ int(x * 255) for x in self.image_processor.image_mean))
111
+ image = self.image_processor.preprocess(
112
+ image, return_tensors='pt')['pixel_values'][0]
113
+ data_dict['pixel_values'] = image
114
+ else:
115
+ if hasattr(self.image_processor, 'crop_size'):
116
+ crop_size = self.image_processor.crop_size
117
+ else:
118
+ crop_size = self.image_processor.size
119
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
120
+ crop_size['width'])
121
+ return data_dict
122
+
123
+ def _add_special_tokens(self):
124
+ assert hasattr(self, "tokenizer")
125
+ # Adding special tokens for pixel grounding
126
+ segmentation_tokens = ['[SEG]']
127
+ # Adding tokens for GCG
128
+ phrase_tokens = ['<p>', '</p>']
129
+ # add for visual prompt
130
+ region_tokens = ['<region>']
131
+ point_tokens = ['<mark>']
132
+ special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens
133
+ self.tokenizer.add_tokens(special_tokens, special_tokens=True)
134
+ return
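A minimal standalone sketch of how the same special tokens could be registered on the tokenizer side and made usable by the language model; the base checkpoint name and the resize_token_embeddings call are illustrative assumptions, not part of this commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'internlm/internlm2-chat-7b',  # assumed base checkpoint
    trust_remote_code=True)
special_tokens = ['[SEG]', '<p>', '</p>', '<region>', '<mark>']
tokenizer.add_tokens(special_tokens, special_tokens=True)
# The LLM embedding table must grow to cover the new ids, e.g.:
# model.resize_token_embeddings(len(tokenizer))
seg_token_id = tokenizer.convert_tokens_to_ids('[SEG]')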
omg_llava/dataset/MDPVPointsDataset.py ADDED
@@ -0,0 +1,220 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import torch
5
+ from datasets import Dataset as HFDataset
6
+ from datasets import DatasetDict, load_from_disk
7
+ from mmengine import print_log
8
+ from mmengine.config import Config, ConfigDict
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ from pycocotools import mask
12
+ import numpy as np
13
+ import torch.nn.functional as F
14
+
15
+ from xtuner.registry import BUILDER
16
+ from omg_llava.dataset.utils import expand2square, expand2square_mask, expand2square_points
17
+ from xtuner.dataset.huggingface import process_hf_dataset
18
+ import copy
19
+
20
+ class MDPVPointDetailedCaptionDataset(Dataset):
21
+ def __init__(self,
22
+ image_folder,
23
+ image_processor,
24
+ data_path=None,
25
+ tokenizer=None,
26
+ offline_processed_text_folder=None,
27
+ max_dataset_length=None,
28
+ dataset_map_fn=None,
29
+ template_map_fn=None,
30
+ max_length=2048,
31
+ pad_image_to_square=False,
32
+ num_proc=32,
33
+ debug=False,
34
+ repeats=1):
35
+ super().__init__()
36
+
37
+ assert offline_processed_text_folder or (data_path and tokenizer)
38
+ self.debug = debug
39
+ if offline_processed_text_folder and data_path:
40
+ print_log(
41
+ 'Both `offline_processed_text_folder` and '
42
+ '`data_path` are set, and we load dataset from '
43
+ '`offline_processed_text_folder` '
44
+ f'({offline_processed_text_folder})',
45
+ logger='current',
46
+ level=logging.WARNING)
47
+
48
+ if offline_processed_text_folder is not None:
49
+ raise NotImplementedError
50
+ else:
51
+ json_data = self.json_file_preprocess(data_path)
52
+ self.json_data = json_data
53
+ hf_json_data = self.filter_hf_require_infos(json_data)
54
+ hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)})
55
+ self.text_data = process_hf_dataset(
56
+ dataset=hf_json_data,
57
+ tokenizer=tokenizer,
58
+ max_length=max_length,
59
+ dataset_map_fn=dataset_map_fn,
60
+ template_map_fn=template_map_fn,
61
+ split='train',
62
+ max_dataset_length=max_dataset_length,
63
+ remove_unused_columns=False,
64
+ pack_to_max_length=False,
65
+ with_image_token=True,
66
+ map_num_proc=num_proc, # cap worker processes due to limited memory
67
+ )
68
+
69
+ self.image_folder = image_folder
70
+ size = image_processor.crop_size
71
+ if isinstance(size, int):
72
+ self.image_h, self.image_w = size, size
73
+ else:
74
+ self.image_w, self.image_h = size
75
+
76
+ if isinstance(image_processor, dict) or isinstance(
77
+ image_processor, Config) or isinstance(image_processor,
78
+ ConfigDict):
79
+ self.image_processor = BUILDER.build(image_processor)
80
+ else:
81
+ self.image_processor = image_processor
82
+ self.pad_image_to_square = pad_image_to_square
83
+ self.down_ratio = 1
84
+ self.repeats = repeats
85
+
86
+ def filter_hf_require_infos(self, dataset_infos):
87
+ ret = []
88
+ for dataset_info in dataset_infos:
89
+ conversations = dataset_info["conversations"]
90
+ image = dataset_info['image'].split('/')[-1]
91
+ num_marks = len(dataset_info['points'])
92
+ required_info = {'image': image,
93
+ 'conversations': conversations,
94
+ 'num_marks': num_marks}
95
+ ret.append(required_info)
96
+ return ret
97
+
98
+ def json_file_preprocess(self, data_path):
99
+ with open(data_path, 'r') as f:
100
+ json_file = json.load(f)
101
+ if self.debug:
102
+ json_file = json_file[:10000]
103
+ return json_file
104
+
105
+ @property
106
+ def modality_length(self):
107
+ length_list = []
108
+ for data_dict in self.text_data:
109
+ cur_len = len(data_dict['input_ids'])
110
+ if data_dict.get('image', None) is None:
111
+ cur_len = -cur_len
112
+ length_list.append(cur_len)
113
+ length_list = length_list * self.repeats
114
+ return length_list
115
+
116
+ def __len__(self):
117
+ return len(self.text_data) * self.repeats
118
+
119
+ def real_len(self):
120
+ return len(self.text_data)
121
+
122
+ def decode_mask(self, object_masks, ori_height, ori_width):
123
+ binary_masks = []
124
+ for object_mask in object_masks:
125
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
126
+ for seg in object_mask:
127
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
128
+ m = mask.decode(rles)
129
+ m = m.astype(np.uint8)
130
+ binary_mask += m.squeeze()
131
+
132
+ binary_masks.append(binary_mask)
133
+ if len(binary_masks) == 0:
134
+ return None
135
+ masks = np.stack(binary_masks, axis=0)
136
+ if self.pad_image_to_square:
137
+ masks = expand2square_mask(masks)
138
+ masks = torch.from_numpy(masks)
139
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
140
+ return masks
141
+
142
+ def __getitem__(self, index):
143
+ index = index % self.real_len()
144
+ data_dict = copy.deepcopy(self.json_data[index])
145
+ data_dict.update(self.text_data[index])
146
+
147
+ if data_dict.get('image', None) is not None:
148
+ image_file = data_dict['image']
149
+ image_path = os.path.join(self.image_folder, image_file)
150
+ if not os.path.exists(image_path) and "VG" in self.image_folder:
151
+ image_path = os.path.join(self.image_folder + "_2", image_file)
152
+ image = Image.open(image_path).convert('RGB')
153
+ ori_width, ori_height = image.size
154
+ if self.pad_image_to_square:
155
+ image = expand2square(
156
+ image,
157
+ tuple(
158
+ int(x * 255) for x in self.image_processor.image_mean))
159
+
160
+ image = self.image_processor.preprocess(
161
+ image, return_tensors='pt')['pixel_values'][0]
162
+ data_dict['pixel_values'] = image
163
+
164
+ # process and get masks
165
+ points = data_dict["points"]
166
+ points = np.array(points)
167
+ if self.pad_image_to_square:
168
+ points = expand2square_points(points, height=ori_height, width=ori_width)
169
+ points[:, 0] = points[:, 0] / max(ori_height, ori_width) * self.image_w
170
+ points[:, 1] = points[:, 1] / max(ori_height, ori_width) * self.image_h
171
+ else:
172
+ points[:, 0] = points[:, 0] / ori_width * self.image_w
173
+ points[:, 1] = points[:, 1] / ori_height * self.image_h
174
+ data_dict['points'] = torch.from_numpy(points)
175
+ if data_dict['points'] is None:
176
+ return self.__getitem__(0)
177
+ data_dict['masks'] = None
178
+ data_dict['regions'] = None
179
+ else:
180
+ if hasattr(self.image_processor, 'crop_size'):
181
+ crop_size = self.image_processor.crop_size
182
+ else:
183
+ crop_size = self.image_processor.size
184
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
185
+ crop_size['width'])
186
+ data_dict['masks'] = None
187
+ data_dict['regions'] = None
188
+ data_dict['points'] = None
189
+ return data_dict
190
+
191
+ class MDPVPointBriefCaptionDataset(MDPVPointDetailedCaptionDataset):
192
+ def __init__(self,
193
+ image_folder,
194
+ image_processor,
195
+ data_path=None,
196
+ tokenizer=None,
197
+ offline_processed_text_folder=None,
198
+ max_dataset_length=None,
199
+ dataset_map_fn=None,
200
+ template_map_fn=None,
201
+ max_length=2048,
202
+ pad_image_to_square=False,
203
+ num_proc=32,
204
+ debug=False,
205
+ repeats=1):
206
+ super().__init__(
207
+ image_folder=image_folder,
208
+ image_processor=image_processor,
209
+ data_path=data_path,
210
+ tokenizer=tokenizer,
211
+ offline_processed_text_folder=offline_processed_text_folder,
212
+ max_dataset_length=max_dataset_length,
213
+ dataset_map_fn=dataset_map_fn,
214
+ template_map_fn=template_map_fn,
215
+ max_length=max_length,
216
+ pad_image_to_square=pad_image_to_square,
217
+ num_proc=num_proc,
218
+ debug=debug,
219
+ repeats=repeats
220
+ )
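A minimal sketch of the point rescaling performed in __getitem__ above, assuming expand2square_points has already shifted the points into the padded square; the function and variable names here are illustrative:

import numpy as np

def rescale_points(points, ori_width, ori_height, image_w, image_h, pad_to_square):
    points = np.asarray(points, dtype=np.float32).copy()
    if pad_to_square:
        side = max(ori_width, ori_height)  # edge length of the padded square
        points[:, 0] = points[:, 0] / side * image_w
        points[:, 1] = points[:, 1] / side * image_h
    else:
        points[:, 0] = points[:, 0] / ori_width * image_w
        points[:, 1] = points[:, 1] / ori_height * image_h
    return points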
omg_llava/dataset/ReferringSegDataset.py ADDED
@@ -0,0 +1,380 @@
1
+ import logging
2
+ import os
3
+ import torch
4
+ from datasets import Dataset as HFDataset
5
+ from datasets import DatasetDict, load_from_disk
6
+ from mmengine import print_log
7
+ from mmengine.config import Config, ConfigDict
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from pycocotools import mask
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+
14
+ from xtuner.registry import BUILDER
15
+ from omg_llava.dataset.utils import expand2square, expand2square_mask
16
+ from xtuner.dataset.huggingface import process_hf_dataset
17
+ from omg_llava.dataset.utils.refcoco_refer import REFER
18
+ import copy
19
+
20
+ class RefcocoReferringSegDataset(Dataset):
21
+ def __init__(self,
22
+ image_folder,
23
+ image_processor,
24
+ data_path=None,
25
+ tokenizer=None,
26
+ offline_processed_text_folder=None,
27
+ max_dataset_length=None,
28
+ dataset_map_fn=None,
29
+ template_map_fn=None,
30
+ max_length=2048,
31
+ pad_image_to_square=False,
32
+ num_proc=8,
33
+ debug=False,
34
+ repeats=1,):
35
+ self._set_attribute()
36
+ self.tokenizer = tokenizer
37
+ assert offline_processed_text_folder or (data_path and tokenizer)
38
+ self.debug = debug
39
+ if offline_processed_text_folder and data_path:
40
+ print_log(
41
+ 'Both `offline_processed_text_folder` and '
42
+ '`data_path` are set, and we load dataset from '
43
+ '`offline_processed_text_folder` '
44
+ f'({offline_processed_text_folder})',
45
+ logger='current',
46
+ level=logging.WARNING)
47
+
48
+ if offline_processed_text_folder is not None:
49
+ raise NotImplementedError
50
+ else:
51
+ json_datas = self.json_file_preprocess(data_path)
52
+ self.json_datas = json_datas
53
+ json_datas = self.only_get_hf_map_infos()
54
+ json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
55
+ self.text_data = process_hf_dataset(
56
+ dataset=json_data,
57
+ tokenizer=tokenizer,
58
+ max_length=max_length,
59
+ dataset_map_fn=dataset_map_fn,
60
+ template_map_fn=template_map_fn,
61
+ split='train',
62
+ max_dataset_length=max_dataset_length,
63
+ remove_unused_columns=False,
64
+ pack_to_max_length=False,
65
+ with_image_token=True,
66
+ map_num_proc=num_proc, # cap worker processes due to limited memory
67
+ )
68
+
69
+ self.image_folder = image_folder
70
+ size = image_processor.crop_size
71
+ if isinstance(size, int):
72
+ self.image_h, self.image_w = size, size
73
+ else:
74
+ self.image_w, self.image_h = size
75
+
76
+ if isinstance(image_processor, dict) or isinstance(
77
+ image_processor, Config) or isinstance(image_processor,
78
+ ConfigDict):
79
+ self.image_processor = BUILDER.build(image_processor)
80
+ else:
81
+ self.image_processor = image_processor
82
+ self.pad_image_to_square = pad_image_to_square
83
+ self.down_ratio = 1
84
+ self.repeats = repeats
85
+
86
+ def _set_attribute(self):
87
+ self.splitBy = "unc"
88
+ self.dataset_name = 'refcoco'
89
+
90
+ def only_get_hf_map_infos(self):
91
+ ret = []
92
+ for json_data in self.json_datas:
93
+ ret.append({'sampled_sents': json_data['selected_labels']})
94
+ return ret
95
+
96
+ def __len__(self):
97
+ return len(self.text_data) * self.repeats
98
+
99
+ @property
100
+ def modality_length(self):
101
+ length_list = []
102
+ for data_dict in self.text_data:
103
+ cur_len = len(data_dict['input_ids'])
104
+ if data_dict.get('image', None) is None:
105
+ cur_len = -cur_len
106
+ length_list.append(cur_len)
107
+ length_list = length_list * self.repeats
108
+ return length_list
109
+
110
+ def real_len(self):
111
+ return len(self.text_data)
112
+
113
+ def json_file_preprocess(self, data_path):
114
+ splitBy = self.splitBy
115
+ dataset_name = self.dataset_name
116
+ refer_api = REFER(data_path, dataset_name, splitBy)
117
+ ref_ids_train = refer_api.getRefIds(split='train')
118
+ images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
119
+ refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
120
+ self.img2refs = self.create_img_to_refs_mapping(refs_train)
121
+
122
+ image_infos = []
123
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
124
+ for item in loaded_images:
125
+ item = item.copy()
126
+ image_infos.append(item)
127
+
128
+ self.annotations = refer_api.Anns
129
+
130
+ refs = [self.img2refs[image_info['id']] for image_info in image_infos]
131
+
132
+ ret = []
133
+ for image_info, ref in zip(image_infos, refs):
134
+ if len(ref) == 0:
135
+ continue
136
+
137
+ sents = []
138
+ ann_ids = []
139
+ for _ref in ref:
140
+ for sent in _ref["sentences"]:
141
+ text = sent["sent"]
142
+ sents.append(text)
143
+ ann_ids.append(_ref["ann_id"])
144
+ if len(sents) >= 3:
145
+ sampled_inds = np.random.choice(
146
+ list(range(len(sents))), size=3, replace=False
147
+ )
148
+ else:
149
+ sampled_inds = list(range(len(sents)))
150
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
151
+ sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
152
+ selected_labels = sampled_sents
153
+ ret.append(
154
+ {'image_info': image_info,
155
+ 'sampled_ann_id': sampled_ann_ids,
156
+ 'selected_labels': selected_labels,
157
+ 'image': image_info['file_name']
158
+ }
159
+ )
160
+ if self.debug:
161
+ return ret[:1000]
162
+ return ret
163
+
164
+ def create_img_to_refs_mapping(self, refs_train):
165
+ img2refs = {}
166
+ for ref in refs_train:
167
+ img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
168
+ return img2refs
169
+
170
+ def decode_mask(self, annotations_ids, image_info):
171
+ flag = False
172
+ masks = []
173
+
174
+ for ann_id in annotations_ids:
175
+ if isinstance(ann_id, list):
176
+ flag = True
177
+ if -1 in ann_id:
178
+ assert len(ann_id) == 1
179
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
180
+ np.uint8
181
+ )
182
+ else:
183
+ m_final = np.zeros(
184
+ (image_info["height"], image_info["width"])
185
+ ).astype(np.uint8)
186
+ for ann_id_i in ann_id:
187
+ ann = self.annotations[ann_id_i]
188
+
189
+ if len(ann["segmentation"]) == 0:
190
+ m = np.zeros(
191
+ (image_info["height"], image_info["width"])
192
+ ).astype(np.uint8)
193
+ else:
194
+ if type(ann["segmentation"][0]) == list: # polygon
195
+ rle = mask.frPyObjects(
196
+ ann["segmentation"], image_info["height"], image_info["width"], )
197
+ else:
198
+ rle = ann["segmentation"]
199
+ for i in range(len(rle)):
200
+ if not isinstance(rle[i]["counts"], bytes):
201
+ rle[i]["counts"] = rle[i]["counts"].encode()
202
+ m = mask.decode(rle)
203
+ m = np.sum(
204
+ m, axis=2
205
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
206
+ m = m.astype(np.uint8) # convert to np.uint8
207
+ m_final = m_final | m
208
+ m = m_final
209
+ masks.append(m)
210
+ continue
211
+
212
+ ann = self.annotations[ann_id]
213
+
214
+ if len(ann["segmentation"]) == 0:
215
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
216
+ np.uint8
217
+ )
218
+ masks.append(m)
219
+ continue
220
+
221
+ if type(ann["segmentation"][0]) == list: # polygon
222
+ rle = mask.frPyObjects(
223
+ ann["segmentation"], image_info["height"], image_info["width"]
224
+ )
225
+ else:
226
+ rle = ann["segmentation"]
227
+ for i in range(len(rle)):
228
+ if not isinstance(rle[i]["counts"], bytes):
229
+ rle[i]["counts"] = rle[i]["counts"].encode()
230
+ m = mask.decode(rle)
231
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
232
+ m = m.astype(np.uint8) # convert to np.uint8
233
+ masks.append(m)
234
+ masks = np.stack(masks, axis=0)
235
+
236
+ if self.pad_image_to_square:
237
+ masks = expand2square_mask(masks)
238
+ masks = torch.from_numpy(masks)
239
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
240
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
241
+ return masks
242
+
243
+ def __getitem__(self, index):
244
+ index = index % self.real_len()
245
+ data_dict = copy.deepcopy(self.text_data[index])
246
+ data_dict.update(self.json_datas[index])
247
+
248
+ if data_dict.get('image', None) is not None:
249
+ image_file = data_dict['image']
250
+ image_file = os.path.join(self.image_folder, image_file)
251
+ image = Image.open(image_file).convert('RGB')
252
+ ori_width, ori_height = image.size
253
+ if self.pad_image_to_square:
254
+ image = expand2square(
255
+ image,
256
+ tuple(
257
+ int(x * 255) for x in self.image_processor.image_mean))
258
+ image = self.image_processor.preprocess(
259
+ image, return_tensors='pt')['pixel_values'][0]
260
+ data_dict['pixel_values'] = image
261
+
262
+ # process and get masks
263
+ masks = self.decode_mask(data_dict['sampled_ann_id'], data_dict['image_info'])
264
+ data_dict['masks'] = masks
265
+ else:
266
+ if hasattr(self.image_processor, 'crop_size'):
267
+ crop_size = self.image_processor.crop_size
268
+ else:
269
+ crop_size = self.image_processor.size
270
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
271
+ crop_size['width'])
272
+ data_dict['masks'] = None
273
+ return data_dict
274
+
275
+ class Refcoco_plus_ReferringSegDataset(RefcocoReferringSegDataset):
276
+ def __init__(self,
277
+ image_folder,
278
+ image_processor,
279
+ data_path=None,
280
+ tokenizer=None,
281
+ offline_processed_text_folder=None,
282
+ max_dataset_length=None,
283
+ dataset_map_fn=None,
284
+ template_map_fn=None,
285
+ max_length=2048,
286
+ pad_image_to_square=False,
287
+ num_proc=8,
288
+ debug=False,
289
+ repeats=1,):
290
+
291
+ super().__init__(
292
+ image_folder=image_folder,
293
+ image_processor=image_processor,
294
+ data_path=data_path,
295
+ tokenizer=tokenizer,
296
+ offline_processed_text_folder=offline_processed_text_folder,
297
+ max_dataset_length=max_dataset_length,
298
+ dataset_map_fn=dataset_map_fn,
299
+ template_map_fn=template_map_fn,
300
+ max_length=max_length,
301
+ pad_image_to_square=pad_image_to_square,
302
+ num_proc=num_proc,
303
+ debug=debug,
304
+ repeats=repeats,)
305
+
306
+ def _set_attribute(self):
307
+ self.splitBy = "unc"
308
+ self.dataset_name = 'refcoco+'
309
+
310
+ class Refcocog_ReferringSegDataset(RefcocoReferringSegDataset):
311
+ def __init__(self,
312
+ image_folder,
313
+ image_processor,
314
+ data_path=None,
315
+ tokenizer=None,
316
+ offline_processed_text_folder=None,
317
+ max_dataset_length=None,
318
+ dataset_map_fn=None,
319
+ template_map_fn=None,
320
+ max_length=2048,
321
+ pad_image_to_square=False,
322
+ num_proc=8,
323
+ debug=False,
324
+ repeats=1,):
325
+
326
+ super().__init__(
327
+ image_folder=image_folder,
328
+ image_processor=image_processor,
329
+ data_path=data_path,
330
+ tokenizer=tokenizer,
331
+ offline_processed_text_folder=offline_processed_text_folder,
332
+ max_dataset_length=max_dataset_length,
333
+ dataset_map_fn=dataset_map_fn,
334
+ template_map_fn=template_map_fn,
335
+ max_length=max_length,
336
+ pad_image_to_square=pad_image_to_square,
337
+ num_proc=num_proc,
338
+ debug=debug,
339
+ repeats=repeats,
340
+ )
341
+
342
+ def _set_attribute(self):
343
+ self.splitBy = "umd"
344
+ self.dataset_name = 'refcocog'
345
+
346
+ class Refclef_ReferringSegDataset(RefcocoReferringSegDataset):
347
+ def __init__(self,
348
+ image_folder,
349
+ image_processor,
350
+ data_path=None,
351
+ tokenizer=None,
352
+ offline_processed_text_folder=None,
353
+ max_dataset_length=None,
354
+ dataset_map_fn=None,
355
+ template_map_fn=None,
356
+ max_length=2048,
357
+ pad_image_to_square=False,
358
+ num_proc=8,
359
+ debug=False,
360
+ repeats=1,):
361
+
362
+ super().__init__(
363
+ image_folder=image_folder,
364
+ image_processor=image_processor,
365
+ data_path=data_path,
366
+ tokenizer=tokenizer,
367
+ offline_processed_text_folder=offline_processed_text_folder,
368
+ max_dataset_length=max_dataset_length,
369
+ dataset_map_fn=dataset_map_fn,
370
+ template_map_fn=template_map_fn,
371
+ max_length=max_length,
372
+ pad_image_to_square=pad_image_to_square,
373
+ num_proc=num_proc,
374
+ debug=debug,
375
+ repeats=repeats,
376
+ )
377
+
378
+ def _set_attribute(self):
379
+ self.splitBy = "unc"
380
+ self.dataset_name = 'refclef'
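A minimal standalone sketch of the polygon-to-binary-mask decoding that decode_mask relies on (pycocotools API; the image size in the usage line is illustrative):

import numpy as np
from pycocotools import mask as mask_utils

def polygons_to_mask(polygons, height, width):
    # polygons: list of flat [x1, y1, x2, y2, ...] lists describing one object
    rles = mask_utils.frPyObjects(polygons, height, width)
    rle = mask_utils.merge(rles)
    return mask_utils.decode(rle).astype(np.uint8)  # (height, width), values in {0, 1}

binary = polygons_to_mask([[10, 10, 60, 10, 60, 60, 10, 60]], height=128, width=128)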
omg_llava/dataset/RegionCaptionDataset.py ADDED
@@ -0,0 +1,356 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import copy
5
+
6
+ import torch
7
+ from datasets import Dataset as HFDataset
8
+ from datasets import DatasetDict, load_from_disk
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image, ImageDraw
12
+ from torch.utils.data import Dataset
13
+ from pycocotools import mask
14
+ import numpy as np
15
+ import torch.nn.functional as F
16
+
17
+ from xtuner.registry import BUILDER
18
+ from omg_llava.dataset.utils import expand2square, expand2square_mask
19
+ from xtuner.dataset.huggingface import process_hf_dataset
20
+
21
+ class OspreyRegionCaptionDataset(Dataset):
22
+ def __init__(self,
23
+ image_folder,
24
+ image_processor,
25
+ data_path=None,
26
+ tokenizer=None,
27
+ offline_processed_text_folder=None,
28
+ max_dataset_length=None,
29
+ dataset_map_fn=None,
30
+ template_map_fn=None,
31
+ max_length=2048,
32
+ pad_image_to_square=False,
33
+ num_proc=32,
34
+ debug=False,
35
+ repeats=1):
36
+ super().__init__()
37
+
38
+ assert offline_processed_text_folder or (data_path and tokenizer)
39
+ self.debug = debug
40
+ if offline_processed_text_folder and data_path:
41
+ print_log(
42
+ 'Both `offline_processed_text_folder` and '
43
+ '`data_path` are set, and we load dataset from '
44
+ '`offline_processed_text_folder` '
45
+ f'({offline_processed_text_folder})',
46
+ logger='current',
47
+ level=logging.WARNING)
48
+
49
+ if offline_processed_text_folder is not None:
50
+ raise NotImplementedError
51
+ else:
52
+ json_data = self.json_file_preprocess(data_path)
53
+ self.json_data = json_data
54
+ hf_json_data = self.filter_hf_require_infos(json_data)
55
+ hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)})
56
+ self.text_data = process_hf_dataset(
57
+ dataset=hf_json_data,
58
+ tokenizer=tokenizer,
59
+ max_length=max_length,
60
+ dataset_map_fn=dataset_map_fn,
61
+ template_map_fn=template_map_fn,
62
+ split='train',
63
+ max_dataset_length=max_dataset_length,
64
+ remove_unused_columns=False,
65
+ pack_to_max_length=False,
66
+ with_image_token=True,
67
+ map_num_proc=num_proc, # cap worker processes due to limited memory
68
+ )
69
+
70
+ self.image_folder = image_folder
71
+ size = image_processor.crop_size
72
+ if isinstance(size, int):
73
+ self.image_h, self.image_w = size, size
74
+ else:
75
+ self.image_w, self.image_h = size
76
+
77
+ if isinstance(image_processor, dict) or isinstance(
78
+ image_processor, Config) or isinstance(image_processor,
79
+ ConfigDict):
80
+ self.image_processor = BUILDER.build(image_processor)
81
+ else:
82
+ self.image_processor = image_processor
83
+ self.pad_image_to_square = pad_image_to_square
84
+ self.down_ratio = 1
85
+ self.repeats = repeats
86
+
87
+ def filter_hf_require_infos(self, dataset_infos):
88
+ ret = []
89
+ for dataset_info in dataset_infos:
90
+ description = dataset_info["description"]
91
+ image = dataset_info['file_name']
92
+ required_info = {'image': image, 'description': description}
93
+ ret.append(required_info)
94
+ return ret
95
+
96
+ def json_file_preprocess(self, data_path):
97
+ with open(data_path, 'r') as f:
98
+ json_file = json.load(f)
99
+
100
+ ret = []
101
+ for item in json_file:
102
+ if len(item["description"]) != len(item["annotation"]):
103
+ print("The number of descriptions does not match the number of annotations; skipping this item.")
104
+ else:
105
+ ret.append(item)
106
+
107
+ if self.debug:
108
+ ret = ret[:10000]
109
+ return ret
110
+
111
+ @property
112
+ def modality_length(self):
113
+ length_list = []
114
+ for data_dict in self.text_data:
115
+ cur_len = len(data_dict['input_ids'])
116
+ if data_dict.get('image', None) is None:
117
+ cur_len = -cur_len
118
+ length_list.append(cur_len)
119
+ length_list = length_list * self.repeats
120
+ return length_list
121
+
122
+ def __len__(self):
123
+ return len(self.text_data) * self.repeats
124
+
125
+ def real_len(self):
126
+ return len(self.text_data)
127
+
128
+ def decode_mask(self, object_masks, ori_height, ori_width):
129
+ binary_masks = []
130
+ for object_mask in object_masks:
131
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
132
+ for seg in object_mask:
133
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
134
+ m = mask.decode(rles)
135
+ m = m.astype(np.uint8)
136
+ binary_mask += m.squeeze()
137
+
138
+ binary_masks.append(binary_mask)
139
+ if len(binary_masks) == 0:
140
+ return None
141
+ masks = np.stack(binary_masks, axis=0)
142
+ if self.pad_image_to_square:
143
+ masks = expand2square_mask(masks)
144
+ masks = torch.from_numpy(masks)
145
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
146
+ return masks
147
+
148
+ def __getitem__(self, index):
149
+ index = index % self.real_len()
150
+ data_dict = copy.deepcopy(self.json_data[index])
151
+ data_dict.update(self.text_data[index])
152
+
153
+ if data_dict.get('image', None) is not None:
154
+ image_file = data_dict['image']
155
+ image = Image.open(os.path.join(self.image_folder,
156
+ image_file)).convert('RGB')
157
+ ori_width, ori_height = image.size
158
+ if self.pad_image_to_square:
159
+ image = expand2square(
160
+ image,
161
+ tuple(
162
+ int(x * 255) for x in self.image_processor.image_mean))
163
+ image = self.image_processor.preprocess(
164
+ image, return_tensors='pt')['pixel_values'][0]
165
+ data_dict['pixel_values'] = image
166
+
167
+ # process and get masks
168
+ annotations = data_dict['annotation']
169
+ sampled_inds = data_dict['sampled_inds']
170
+ annotations = [annotations[idx]['segmentation'] for idx in sampled_inds]
171
+ data_dict['regions'] = self.decode_mask(annotations, ori_height=ori_height, ori_width=ori_width)
172
+
173
+ if data_dict['regions'] is None or len(data_dict['regions']) != len(sampled_inds):
174
+ print("Bad data item; falling back to index 0.")
175
+ return self.__getitem__(0)
176
+ seg_region_idx = data_dict['seg_region_idx']
177
+ if len(seg_region_idx) == 0:
178
+ data_dict['masks'] = None
179
+ else:
180
+ data_dict['masks'] = data_dict['regions'][seg_region_idx]
181
+ else:
182
+ if hasattr(self.image_processor, 'crop_size'):
183
+ crop_size = self.image_processor.crop_size
184
+ else:
185
+ crop_size = self.image_processor.size
186
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
187
+ crop_size['width'])
188
+ data_dict['masks'] = None
189
+ data_dict['regions'] = None
190
+ return data_dict
191
+
192
+ class OspreyRegionConversationDataset(Dataset):
193
+ def __init__(self,
194
+ image_folder,
195
+ image_processor,
196
+ data_path=None,
197
+ tokenizer=None,
198
+ offline_processed_text_folder=None,
199
+ max_dataset_length=None,
200
+ dataset_map_fn=None,
201
+ template_map_fn=None,
202
+ max_length=2048,
203
+ pad_image_to_square=False,
204
+ num_proc=32,
205
+ debug=False,
206
+ repeats=1):
207
+ super().__init__()
208
+
209
+ assert offline_processed_text_folder or (data_path and tokenizer)
210
+ self.debug = debug
211
+ if offline_processed_text_folder and data_path:
212
+ print_log(
213
+ 'Both `offline_processed_text_folder` and '
214
+ '`data_path` are set, and we load dataset from '
215
+ '`offline_processed_text_folder` '
216
+ f'({offline_processed_text_folder})',
217
+ logger='current',
218
+ level=logging.WARNING)
219
+
220
+ if offline_processed_text_folder is not None:
221
+ raise NotImplementedError
222
+ else:
223
+ json_data = self.json_file_preprocess(data_path)
224
+ self.json_data = json_data
225
+ hf_json_data = self.filter_hf_require_infos(json_data)
226
+ hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)})
227
+ self.text_data = process_hf_dataset(
228
+ dataset=hf_json_data,
229
+ tokenizer=tokenizer,
230
+ max_length=max_length,
231
+ dataset_map_fn=dataset_map_fn,
232
+ template_map_fn=template_map_fn,
233
+ split='train',
234
+ max_dataset_length=max_dataset_length,
235
+ remove_unused_columns=False,
236
+ pack_to_max_length=False,
237
+ with_image_token=True,
238
+ map_num_proc=num_proc, # cap worker processes due to limited memory
239
+ )
240
+
241
+ self.image_folder = image_folder
242
+ size = image_processor.crop_size
243
+ if isinstance(size, int):
244
+ self.image_h, self.image_w = size, size
245
+ else:
246
+ self.image_w, self.image_h = size
247
+
248
+ if isinstance(image_processor, dict) or isinstance(
249
+ image_processor, Config) or isinstance(image_processor,
250
+ ConfigDict):
251
+ self.image_processor = BUILDER.build(image_processor)
252
+ else:
253
+ self.image_processor = image_processor
254
+ self.pad_image_to_square = pad_image_to_square
255
+ self.down_ratio = 1
256
+ self.repeats = repeats
257
+
258
+ def filter_hf_require_infos(self, dataset_infos):
259
+ ret = []
260
+ for dataset_info in dataset_infos:
261
+ conversations = dataset_info["conversations"]
262
+ image = dataset_info['file_name']
263
+ num_regions = len(dataset_info['annotation'])
264
+ required_info = {'image': image, 'conversations': conversations,
265
+ 'num_regions': num_regions}
266
+ ret.append(required_info)
267
+ return ret
268
+
269
+ def json_file_preprocess(self, data_path):
270
+ with open(data_path, 'r') as f:
271
+ json_file = json.load(f)
272
+
273
+ # filter
274
+ ret = []
275
+ for dataset_info in json_file:
276
+ if 'annotation' not in dataset_info or len(dataset_info['annotation']) == 0:
277
+ print("Invalid annotation; filtering out.")
278
+ continue
279
+ ret.append(dataset_info)
280
+
281
+ if self.debug:
282
+ ret = ret[:10000]
283
+ return ret
284
+
285
+ @property
286
+ def modality_length(self):
287
+ length_list = []
288
+ for data_dict in self.text_data:
289
+ cur_len = len(data_dict['input_ids'])
290
+ if data_dict.get('image', None) is None:
291
+ cur_len = -cur_len
292
+ length_list.append(cur_len)
293
+ length_list = length_list * self.repeats
294
+ return length_list
295
+
296
+ def __len__(self):
297
+ return len(self.text_data) * self.repeats
298
+
299
+ def real_len(self):
300
+ return len(self.text_data)
301
+
302
+ def decode_mask(self, object_masks, ori_height, ori_width):
303
+ binary_masks = []
304
+ for object_mask in object_masks:
305
+ binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
306
+ for seg in object_mask:
307
+ rles = mask.frPyObjects([seg], ori_height, ori_width)
308
+ m = mask.decode(rles)
309
+ m = m.astype(np.uint8)
310
+ binary_mask += m.squeeze()
311
+ binary_masks.append(binary_mask)
312
+ if len(binary_masks) == 0:
313
+ return None
314
+ masks = np.stack(binary_masks, axis=0)
315
+ if self.pad_image_to_square:
316
+ masks = expand2square_mask(masks)
317
+ masks = torch.from_numpy(masks)
318
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio, self.image_w // self.down_ratio), mode='nearest').squeeze(0)
319
+ return masks
320
+
321
+ def __getitem__(self, index):
322
+ index = index % self.real_len()
323
+ data_dict = copy.deepcopy(self.json_data[index])
324
+ data_dict.update(self.text_data[index])
325
+
326
+ if data_dict.get('image', None) is not None:
327
+ image_file = data_dict['image']
328
+ image = Image.open(os.path.join(self.image_folder,
329
+ image_file)).convert('RGB')
330
+ ori_width, ori_height = image.size
331
+ if self.pad_image_to_square:
332
+ image = expand2square(
333
+ image,
334
+ tuple(
335
+ int(x * 255) for x in self.image_processor.image_mean))
336
+ image = self.image_processor.preprocess(
337
+ image, return_tensors='pt')['pixel_values'][0]
338
+ data_dict['pixel_values'] = image
339
+
340
+ # process and get masks
341
+ annotations = data_dict['annotation']
342
+ annotations = [annotations[idx]['segmentation'] for idx in range(len(annotations))]
343
+ data_dict['regions'] = self.decode_mask(annotations, ori_height=ori_height, ori_width=ori_width)
344
+ if data_dict['regions'] is None:
345
+ return self.__getitem__(0)
346
+ data_dict['masks'] = None
347
+ else:
348
+ if hasattr(self.image_processor, 'crop_size'):
349
+ crop_size = self.image_processor.crop_size
350
+ else:
351
+ crop_size = self.image_processor.size
352
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
353
+ crop_size['width'])
354
+ data_dict['masks'] = None
355
+ data_dict['regions'] = None
356
+ return data_dict
omg_llava/dataset/SemanticSegDataset.py ADDED
@@ -0,0 +1,725 @@
1
+ import random
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import torch
7
+ from datasets import Dataset as HFDataset
8
+ from datasets import DatasetDict, load_from_disk
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ import numpy as np
14
+ import torch.nn.functional as F
15
+ from pycocotools.coco import COCO
16
+
17
+ from xtuner.registry import BUILDER
18
+ from omg_llava.dataset.utils import expand2square, expand2square_mask
19
+ from xtuner.dataset.huggingface import process_hf_dataset
20
+ from omg_llava.dataset.process_functions.semantic_seg_process import semantic_seg_conversations, semantic_seg_gcg_format_conversations
21
+ import copy
22
+
23
+ class SemanticSegDataset(Dataset):
24
+ def __init__(self,
25
+ image_folder,
26
+ image_processor,
27
+ data_path=None,
28
+ tokenizer=None,
29
+ offline_processed_text_folder=None,
30
+ max_dataset_length=None,
31
+ dataset_map_fn=None,
32
+ template_map_fn=None,
33
+ max_length=2048,
34
+ pad_image_to_square=False,
35
+ num_proc=8,
36
+ debug=False,
37
+ repeats=1,
38
+ gcg_format=False):
39
+ super().__init__()
40
+ self.tokenizer = tokenizer
41
+ assert offline_processed_text_folder or (data_path and tokenizer)
42
+ self.debug = debug
43
+ if offline_processed_text_folder and data_path:
44
+ print_log(
45
+ 'Both `offline_processed_text_folder` and '
46
+ '`data_path` are set, and we load dataset from '
47
+ '`offline_processed_text_folder` '
48
+ f'({offline_processed_text_folder})',
49
+ logger='current',
50
+ level=logging.WARNING)
51
+
52
+ if offline_processed_text_folder is not None:
53
+ raise NotImplementedError
54
+ else:
55
+ self.image_label_datas = self.json_file_preprocess(data_path, image_folder)
56
+ if gcg_format:
57
+ conversations_datas = semantic_seg_gcg_format_conversations(self.classes)
58
+ else:
59
+ conversations_datas = semantic_seg_conversations(self.classes)
60
+ json_data = DatasetDict({'train': HFDataset.from_list(conversations_datas)})
61
+ self.text_data = process_hf_dataset(
62
+ dataset=json_data,
63
+ tokenizer=tokenizer,
64
+ max_length=max_length,
65
+ dataset_map_fn=dataset_map_fn,
66
+ template_map_fn=template_map_fn,
67
+ split='train',
68
+ max_dataset_length=max_dataset_length,
69
+ remove_unused_columns=False,
70
+ pack_to_max_length=False,
71
+ with_image_token=True,
72
+ map_num_proc=num_proc, # cap worker processes due to limited memory
73
+ )
74
+
75
+ self.clsid2convs = self.construct_cls2convs_dict()
76
+ self.image_folder = image_folder
77
+ size = image_processor.crop_size
78
+ if isinstance(size, int):
79
+ self.image_h, self.image_w = size, size
80
+ else:
81
+ self.image_w, self.image_h = size
82
+
83
+ if isinstance(image_processor, dict) or isinstance(
84
+ image_processor, Config) or isinstance(image_processor,
85
+ ConfigDict):
86
+ self.image_processor = BUILDER.build(image_processor)
87
+ else:
88
+ self.image_processor = image_processor
89
+ self.pad_image_to_square = pad_image_to_square
90
+ self.down_ratio = 1
91
+ self.repeats = repeats
92
+
93
+ def construct_cls2convs_dict(self):
94
+ ret = {}
95
+ for conv_item in self.text_data:
96
+ cls_id = conv_item['class_id']
97
+ if cls_id in ret.keys():
98
+ ret[cls_id].append(conv_item)
99
+ else:
100
+ ret[cls_id] = [conv_item]
101
+ return ret
102
+
103
+ def json_file_preprocess(self, data_path, image_folder):
104
+ # ade20k
105
+ with open(data_path, 'r') as file:
106
+ ade20k_classes = json.load(file)
107
+ ade20k_image_dir = image_folder
108
+ ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if
109
+ img.endswith('.jpg')]
110
+ ade20k_labels = [img.replace(".jpg", ".png").replace("images", "annotations") for img in ade20k_images]
111
+ self.classes = np.array(ade20k_classes)
112
+
113
+ ret = []
114
+ for image, label in zip(ade20k_images, ade20k_labels):
115
+ ret.append({"image": image, "label": label})
116
+ if self.debug:
117
+ return ret[:1000]
118
+ return ret
119
+
120
+ def __len__(self):
121
+ return len(self.image_label_datas) * self.repeats
122
+
123
+ @property
124
+ def modality_length(self):
125
+ length_list = []
126
+ for data_dict in self.image_label_datas:
127
+ length_list.append(-100)
128
+ length_list = length_list * self.repeats
129
+ return length_list
130
+
131
+ def real_len(self):
132
+ return len(self.image_label_datas)
133
+
134
+ def decode_mask(self, label_path):
135
+ label = np.array(Image.open(label_path))
136
+
137
+ # ade 20k
138
+ label = np.where(label == 0, 255, label - 1)
139
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
140
+ if not unique_labels:
141
+ return None, None
142
+
143
+ # only choose 1
144
+ selected_labels = np.random.choice(
145
+ unique_labels, 1, replace=False
146
+ )
147
+ label = torch.from_numpy(label).long()
148
+ masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
149
+
150
+ masks = masks.numpy()
151
+ if self.pad_image_to_square:
152
+ masks = expand2square_mask(masks)
153
+
154
+ masks = torch.from_numpy(masks).to(torch.float32)
155
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
156
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
157
+ return masks, selected_labels[0]
158
+
159
+ def __getitem__(self, index):
160
+ index = index % self.real_len()
161
+ data_dict = copy.deepcopy(self.image_label_datas[index])
162
+
163
+ if data_dict.get('image', None) is not None:
164
+ image_file = data_dict['image']
165
+ image = Image.open(image_file).convert('RGB')
166
+ ori_width, ori_height = image.size
167
+ if self.pad_image_to_square:
168
+ image = expand2square(
169
+ image,
170
+ tuple(
171
+ int(x * 255) for x in self.image_processor.image_mean))
172
+ image = self.image_processor.preprocess(
173
+ image, return_tensors='pt')['pixel_values'][0]
174
+ data_dict['pixel_values'] = image
175
+
176
+ # process and get masks
177
+ data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
178
+ if class_id is None:
179
+ return self.__getitem__(0)
180
+ conv_datas = self.clsid2convs[class_id]
181
+ selected_idx = np.random.randint(0, len(conv_datas))
182
+ data_dict.update(conv_datas[selected_idx])
183
+ else:
184
+ if hasattr(self.image_processor, 'crop_size'):
185
+ crop_size = self.image_processor.crop_size
186
+ else:
187
+ crop_size = self.image_processor.size
188
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
189
+ crop_size['width'])
190
+ data_dict['masks'] = None
191
+ return data_dict
192
+
193
+ class ADE20kSemanticSegDataset(SemanticSegDataset):
194
+ def __init__(self,
195
+ image_folder,
196
+ image_processor,
197
+ data_path=None,
198
+ tokenizer=None,
199
+ offline_processed_text_folder=None,
200
+ max_dataset_length=None,
201
+ dataset_map_fn=None,
202
+ template_map_fn=None,
203
+ max_length=2048,
204
+ pad_image_to_square=False,
205
+ num_proc=8,
206
+ debug=False,
207
+ repeats=1,
208
+ gcg_format=False):
209
+ super().__init__(
210
+ image_folder=image_folder,
211
+ image_processor=image_processor,
212
+ data_path=data_path,
213
+ tokenizer=tokenizer,
214
+ offline_processed_text_folder=offline_processed_text_folder,
215
+ max_dataset_length=max_dataset_length,
216
+ dataset_map_fn=dataset_map_fn,
217
+ template_map_fn=template_map_fn,
218
+ max_length=max_length,
219
+ pad_image_to_square=pad_image_to_square,
220
+ num_proc=num_proc,
221
+ debug=debug,
222
+ repeats=repeats,
223
+ gcg_format=gcg_format,
224
+ )
225
+
226
+ class COCOStuffSemanticSegDataset(SemanticSegDataset):
227
+ def __init__(self,
228
+ image_folder,
229
+ image_processor,
230
+ data_path=None,
231
+ tokenizer=None,
232
+ offline_processed_text_folder=None,
233
+ max_dataset_length=None,
234
+ dataset_map_fn=None,
235
+ template_map_fn=None,
236
+ max_length=2048,
237
+ pad_image_to_square=False,
238
+ num_proc=8,
239
+ debug=False,
240
+ repeats=1,
241
+ label_path=None,
242
+ gcg_format=False,):
243
+ self.label_path = label_path
244
+ super().__init__(
245
+ image_folder=image_folder,
246
+ image_processor=image_processor,
247
+ data_path=data_path,
248
+ tokenizer=tokenizer,
249
+ offline_processed_text_folder=offline_processed_text_folder,
250
+ max_dataset_length=max_dataset_length,
251
+ dataset_map_fn=dataset_map_fn,
252
+ template_map_fn=template_map_fn,
253
+ max_length=max_length,
254
+ pad_image_to_square=pad_image_to_square,
255
+ num_proc=num_proc,
256
+ debug=debug,
257
+ repeats=repeats,
258
+ gcg_format=gcg_format,
259
+ )
260
+ self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)}
261
+
262
+ def json_file_preprocess(self, data_path, image_folder):
263
+ # coco stuff
264
+ assert self.label_path is not None
265
+ with open(data_path, 'r') as file:
266
+ cocostuff_classes = [line.strip().split(": ")[-1] for line in file.readlines()[1:]]
267
+ coco_stuff_image_dir = image_folder
268
+ coco_stuff_label_dir = self.label_path
269
+ coco_stuff_labels = glob.glob(os.path.join(coco_stuff_label_dir, "*.png"))
270
+
271
+ coco_stuff_images = [label.replace(".png", ".jpg").replace(coco_stuff_label_dir, coco_stuff_image_dir)
272
+ for label in coco_stuff_labels]
273
+
274
+ self.classes = np.array(cocostuff_classes)
275
+
276
+ ret = []
277
+ for image, label in zip(coco_stuff_images, coco_stuff_labels):
278
+ ret.append({"image": image, "label": label})
279
+ if self.debug:
280
+ return ret[:1000]
281
+ return ret
282
+
283
+ def decode_mask(self, label_path):
284
+ label = np.array(Image.open(label_path))
285
+
286
+ # coco stuff
287
+ ignored_classes = [index for class_name, index in self.cocostuff_class2index.items() if
288
+ "-" in class_name]
289
+ label = np.where(np.isin(label, ignored_classes), 255, label)
290
+
291
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
292
+ if not unique_labels:
293
+ print("No valid label in this segmentation map; falling back.")
294
+ return None, None
295
+
296
+ # only choose 1
297
+ selected_labels = np.random.choice(
298
+ unique_labels, 1, replace=False
299
+ )
300
+ label = torch.from_numpy(label).long()
301
+ masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
302
+
303
+ masks = masks.numpy()
304
+ if self.pad_image_to_square:
305
+ masks = expand2square_mask(masks)
306
+
307
+ masks = torch.from_numpy(masks).to(torch.float32)
308
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
309
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
310
+ return masks, selected_labels[0]
311
+
312
+ class MapillarySemanticSegDataset(SemanticSegDataset):
313
+ def __init__(self,
314
+ image_folder,
315
+ image_processor,
316
+ data_path=None,
317
+ tokenizer=None,
318
+ offline_processed_text_folder=None,
319
+ max_dataset_length=None,
320
+ dataset_map_fn=None,
321
+ template_map_fn=None,
322
+ max_length=2048,
323
+ pad_image_to_square=False,
324
+ num_proc=8,
325
+ debug=False,
326
+ repeats=1,
327
+ label_path=None,
328
+ gcg_format=False,):
329
+ self.label_path = label_path
330
+ super().__init__(
331
+ image_folder=image_folder,
332
+ image_processor=image_processor,
333
+ data_path=data_path,
334
+ tokenizer=tokenizer,
335
+ offline_processed_text_folder=offline_processed_text_folder,
336
+ max_dataset_length=max_dataset_length,
337
+ dataset_map_fn=dataset_map_fn,
338
+ template_map_fn=template_map_fn,
339
+ max_length=max_length,
340
+ pad_image_to_square=pad_image_to_square,
341
+ num_proc=num_proc,
342
+ debug=debug,
343
+ repeats=repeats,
344
+ gcg_format=gcg_format,
345
+ )
346
+
347
+ def json_file_preprocess(self, data_path, image_folder):
348
+ assert self.label_path is not None
349
+ # mapillary
350
+ with open(data_path, 'r') as file:
351
+ mapillary_classes = json.load(file)["labels"]
352
+ mapillary_classes = [cls["readable"].lower() for cls in mapillary_classes]
353
+
354
+ mapillary_labels = sorted(
355
+ glob.glob(os.path.join(self.label_path, "*.png")))
356
+ mapillary_images = [
357
+ label.replace(".png", ".jpg").replace(self.label_path, image_folder)
358
+ for label in mapillary_labels]
359
+
360
+ self.classes = np.array(mapillary_classes)
361
+
362
+ ret = []
363
+ for image, label in zip(mapillary_images, mapillary_labels):
364
+ ret.append({"image": image, "label": label})
365
+ if self.debug:
366
+ return ret[:1000]
367
+ return ret
368
+
369
+ def decode_mask(self, label_path):
370
+ label = np.array(Image.open(label_path))
371
+
372
+ ignored_classes = [index for index, class_name in enumerate(self.classes) if
373
+ "-" in class_name or '(' in class_name or
374
+ 'unlabeled' in class_name]
375
+ label = np.where(np.isin(label, ignored_classes), 255, label)
376
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
377
+ if not unique_labels:
378
+ print("No valid label in this segmentation map; falling back.")
379
+ return None, None
380
+ # only choose 1
381
+ selected_labels = np.random.choice(
382
+ unique_labels, 1, replace=False
383
+ )
384
+ label = torch.from_numpy(label).long()
385
+ masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
386
+
387
+ masks = masks.numpy()
388
+ if self.pad_image_to_square:
389
+ masks = expand2square_mask(masks)
390
+
391
+ masks = torch.from_numpy(masks).to(torch.float32)
392
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
393
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
394
+ return masks, selected_labels[0]
395
+
396
+ class PascalPartSemanticSegDataset(Dataset):
397
+ def __init__(self,
398
+ image_folder,
399
+ image_processor,
400
+ data_path=None,
401
+ tokenizer=None,
402
+ offline_processed_text_folder=None,
403
+ max_dataset_length=None,
404
+ dataset_map_fn=None,
405
+ template_map_fn=None,
406
+ max_length=2048,
407
+ pad_image_to_square=False,
408
+ num_proc=8,
409
+ debug=False,
410
+ repeats=1):
411
+ super().__init__()
412
+ self.tokenizer = tokenizer
413
+ assert offline_processed_text_folder or (data_path and tokenizer)
414
+ self.debug = debug
415
+ if offline_processed_text_folder and data_path:
416
+ print_log(
417
+ 'Both `offline_processed_text_folder` and '
418
+ '`data_path` are set, and we load dataset from '
419
+ '`offline_processed_text_folder` '
420
+ f'({offline_processed_text_folder})',
421
+ logger='current',
422
+ level=logging.WARNING)
423
+
424
+ if offline_processed_text_folder is not None:
425
+ raise NotImplementedError
426
+ else:
427
+ json_datas = self.json_file_preprocess(data_path)
428
+ json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
429
+ self.text_data = process_hf_dataset(
430
+ dataset=json_data,
431
+ tokenizer=tokenizer,
432
+ max_length=max_length,
433
+ dataset_map_fn=dataset_map_fn,
434
+ template_map_fn=template_map_fn,
435
+ split='train',
436
+ max_dataset_length=max_dataset_length,
437
+ remove_unused_columns=False,
438
+ pack_to_max_length=False,
439
+ with_image_token=True,
440
+ map_num_proc=num_proc, # cap worker processes due to limited memory
441
+ )
442
+
443
+ self.image_folder = image_folder
444
+ size = image_processor.crop_size
445
+ if isinstance(size, int):
446
+ self.image_h, self.image_w = size, size
447
+ else:
448
+ self.image_w, self.image_h = size
449
+
450
+ if isinstance(image_processor, dict) or isinstance(
451
+ image_processor, Config) or isinstance(image_processor,
452
+ ConfigDict):
453
+ self.image_processor = BUILDER.build(image_processor)
454
+ else:
455
+ self.image_processor = image_processor
456
+ self.pad_image_to_square = pad_image_to_square
457
+ self.down_ratio = 1
458
+ self.repeats = repeats
459
+
460
+ def json_file_preprocess(self, data_path):
461
+ pascal_part_api = COCO(data_path)
462
+ all_classes = pascal_part_api.loadCats(pascal_part_api.getCatIds())
463
+ class_map_pascal_part = {}
464
+ for cat in all_classes:
465
+ cat_main, cat_part = cat["name"].strip().split(":")
466
+ name = (cat_main, cat_part)
467
+ class_map_pascal_part[cat["id"]] = name
468
+ img_ids = pascal_part_api.getImgIds()
469
+ self.classes = class_map_pascal_part
470
+ self.coco_api = pascal_part_api
471
+
472
+ img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids]
473
+ valid_img_infos = []
474
+ for img_info in img_infos:
475
+ annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
476
+ annotations = self.coco_api.loadAnns(annotation_ids)
477
+ if not annotations:
478
+ continue
479
+
480
+ # sampled to max number as 5
481
+ sampled_anns = np.random.choice(annotations, 5, replace=False) if len(
482
+ annotations
483
+ ) >= 5 else annotations
484
+
485
+ selected_labels = []
486
+ for ann in sampled_anns:
487
+ category_id = ann["category_id"]
488
+ sampled_cls = self.classes[category_id]
489
+ if isinstance(sampled_cls, tuple):
490
+ obj, part = sampled_cls
491
+ name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
492
+ else:
493
+ name = sampled_cls
494
+ selected_labels.append(name)
495
+
496
+ img_info.update({"annotations": sampled_anns,
497
+ "selected_labels": selected_labels})
498
+ valid_img_infos.append(img_info)
499
+
500
+ if self.debug:
501
+ return valid_img_infos[:1000]
502
+ return valid_img_infos
503
+
504
+ def __len__(self):
505
+ return len(self.text_data) * self.repeats
506
+
507
+ @property
508
+ def modality_length(self):
509
+ length_list = []
510
+ for data_dict in self.text_data:
511
+ cur_len = len(data_dict['input_ids'])
512
+ if data_dict.get('image', None) is None:
513
+ cur_len = -cur_len
514
+ length_list.append(cur_len)
515
+ length_list = length_list * self.repeats
516
+ return length_list
517
+
518
+ def real_len(self):
519
+ return len(self.text_data)
520
+
521
+ def decode_mask(self, annotations):
522
+
523
+ try:
524
+ masks = [self.coco_api.annToMask(ann) for ann in annotations]
525
+ except Exception as e:
526
+ print(f"Error generating mask: {e}")
527
+ return None
528
+
529
+ masks = np.stack(masks, axis=0)
530
+ if self.pad_image_to_square:
531
+ masks = expand2square_mask(masks)
532
+ masks = torch.from_numpy(masks)
533
+ masks = F.interpolate(masks.unsqueeze(0), size=(self.image_h // self.down_ratio,
534
+ self.image_w // self.down_ratio), mode='nearest').squeeze(0)
535
+ return masks
536
+
537
+ def __getitem__(self, index):
538
+ index = index % self.real_len()
539
+ data_dict = copy.deepcopy(self.text_data[index])
540
+
541
+ if data_dict.get('image', None) is not None:
542
+ image_file = data_dict['image']
543
+ image_file = os.path.join(self.image_folder, image_file)
544
+ image = Image.open(image_file).convert('RGB')
545
+ ori_width, ori_height = image.size
546
+ if self.pad_image_to_square:
547
+ image = expand2square(
548
+ image,
549
+ tuple(
550
+ int(x * 255) for x in self.image_processor.image_mean))
551
+ image = self.image_processor.preprocess(
552
+ image, return_tensors='pt')['pixel_values'][0]
553
+ data_dict['pixel_values'] = image
554
+
555
+ # process and get masks
556
+ data_dict['masks'] = self.decode_mask(data_dict['annotations'])
557
+ if data_dict['masks'] is None:
558
+ return self.__getitem__(0)
559
+ else:
560
+ if hasattr(self.image_processor, 'crop_size'):
561
+ crop_size = self.image_processor.crop_size
562
+ else:
563
+ crop_size = self.image_processor.size
564
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
565
+ crop_size['width'])
566
+ data_dict['masks'] = None
567
+ return data_dict
568
+
569
+ class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
570
+ def __init__(self,
571
+ image_folder,
572
+ image_processor,
573
+ data_path=None,
574
+ tokenizer=None,
575
+ offline_processed_text_folder=None,
576
+ max_dataset_length=None,
577
+ dataset_map_fn=None,
578
+ template_map_fn=None,
579
+ max_length=2048,
580
+ pad_image_to_square=False,
581
+ num_proc=8,
582
+ debug=False,
583
+ repeats=1,):
584
+ self.tokenizer = tokenizer
585
+ assert offline_processed_text_folder or (data_path and tokenizer)
586
+ self.debug = debug
587
+ if offline_processed_text_folder and data_path:
588
+ print_log(
589
+ 'Both `offline_processed_text_folder` and '
590
+ '`data_path` are set, and we load dataset from '
591
+ '`offline_processed_text_folder` '
592
+ f'({offline_processed_text_folder})',
593
+ logger='current',
594
+ level=logging.WARNING)
595
+
596
+ if offline_processed_text_folder is not None:
597
+ raise NotImplementedError
598
+ else:
599
+ json_datas = self.json_file_preprocess(data_path)
600
+ self.json_datas = json_datas
601
+ json_datas = self.only_get_hf_map_infos()
602
+ json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
603
+ self.text_data = process_hf_dataset(
604
+ dataset=json_data,
605
+ tokenizer=tokenizer,
606
+ max_length=max_length,
607
+ dataset_map_fn=dataset_map_fn,
608
+ template_map_fn=template_map_fn,
609
+ split='train',
610
+ max_dataset_length=max_dataset_length,
611
+ remove_unused_columns=False,
612
+ pack_to_max_length=False,
613
+ with_image_token=True,
614
+ map_num_proc=num_proc, # cap worker processes due to limited memory
615
+ )
616
+
617
+ self.image_folder = image_folder
618
+ size = image_processor.crop_size
619
+ if isinstance(size, int):
620
+ self.image_h, self.image_w = size, size
621
+ else:
622
+ self.image_w, self.image_h = size
623
+
624
+ if isinstance(image_processor, dict) or isinstance(
625
+ image_processor, Config) or isinstance(image_processor,
626
+ ConfigDict):
627
+ self.image_processor = BUILDER.build(image_processor)
628
+ else:
629
+ self.image_processor = image_processor
630
+ self.pad_image_to_square = pad_image_to_square
631
+ self.down_ratio = 1
632
+ self.repeats = repeats
633
+
634
+ def only_get_hf_map_infos(self):
635
+ ret = []
636
+ for json_data in self.json_datas:
637
+ ret.append({'file_name': json_data['file_name'],
638
+ 'selected_labels': json_data['selected_labels']})
639
+ return ret
640
+
641
+ def json_file_preprocess(self, data_path):
642
+ paco_api = COCO(data_path)
643
+ all_classes = paco_api.loadCats(paco_api.getCatIds())
644
+ class_map_paco = {}
645
+ for cat in all_classes:
646
+ cat_split = cat["name"].strip().split(":")
647
+ if len(cat_split) == 1:
648
+ name = cat_split[0].split("_(")[0]
649
+ else:
650
+ assert len(cat_split) == 2
651
+ obj, part = cat_split
652
+ obj = obj.split("_(")[0]
653
+ part = part.split("_(")[0]
654
+ name = (obj, part)
655
+ class_map_paco[cat["id"]] = name
656
+
657
+ img_ids = paco_api.getImgIds()
658
+ self.classes = class_map_paco
659
+ self.coco_api = paco_api
660
+
661
+ img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids]
662
+ valid_img_infos = []
663
+ for img_info in img_infos:
664
+ annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
665
+ annotations = self.coco_api.loadAnns(annotation_ids)
666
+ if not annotations:
667
+ continue
668
+
669
+ # sampled to max number as 5
670
+ sampled_anns = np.random.choice(annotations, 5, replace=False) if len(
671
+ annotations
672
+ ) >= 5 else annotations
673
+
674
+ selected_labels = []
675
+ for ann in sampled_anns:
676
+ category_id = ann["category_id"]
677
+ sampled_cls = self.classes[category_id]
678
+ if isinstance(sampled_cls, tuple):
679
+ obj, part = sampled_cls
680
+ name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
681
+ else:
682
+ name = sampled_cls
683
+ selected_labels.append(name)
684
+
685
+ img_info.update({"annotations": sampled_anns,
686
+ "selected_labels": selected_labels})
687
+ valid_img_infos.append(img_info)
688
+
689
+ if self.debug:
690
+ return valid_img_infos[:1000]
691
+ return valid_img_infos
692
+
693
+ def __getitem__(self, index):
694
+ index = index % self.real_len()
695
+ data_dict = copy.deepcopy(self.text_data[index])
696
+ data_dict.update(self.json_datas[index])
697
+
698
+ if data_dict.get('image', None) is not None:
699
+ image_file = data_dict['image']
700
+ image_file = os.path.join(self.image_folder, image_file)
701
+ image = Image.open(image_file).convert('RGB')
702
+ ori_width, ori_height = image.size
703
+ if self.pad_image_to_square:
704
+ image = expand2square(
705
+ image,
706
+ tuple(
707
+ int(x * 255) for x in self.image_processor.image_mean))
708
+ image = self.image_processor.preprocess(
709
+ image, return_tensors='pt')['pixel_values'][0]
710
+ data_dict['pixel_values'] = image
711
+
712
+ # process and get masks
713
+ data_dict['masks'] = self.decode_mask(data_dict['annotations'])
714
+ if data_dict['masks'] is None:
715
+ return self.__getitem__(0)
716
+ else:
717
+ if hasattr(self.image_processor, 'crop_size'):
718
+ crop_size = self.image_processor.crop_size
719
+ else:
720
+ crop_size = self.image_processor.size
721
+ data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
722
+ crop_size['width'])
723
+ data_dict['masks'] = None
724
+ return data_dict
725
+
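For orientation, here is a minimal, hypothetical config-style sketch of how the new `PacoSemanticSegDataset` could be declared; it is not part of this commit. The annotation and image paths, the tokenizer and image-processor checkpoints, and the choice of `pascal_part_map_fn` are all assumptions, and the finetune configs under `omg_llava/configs/finetune/` remain the authoritative wiring.

```python
# Hypothetical xtuner-style config snippet (not from this commit).
# Paths, checkpoint names, and the dataset_map_fn are placeholders/assumptions.
from transformers import AutoTokenizer, CLIPImageProcessor
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.utils import PROMPT_TEMPLATE
from omg_llava.dataset import PacoSemanticSegDataset, pascal_part_map_fn

# Tokenizer / image processor declared as lazy config dicts, as the repo's configs do.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path='internlm/internlm2-chat-7b',  # placeholder checkpoint
    trust_remote_code=True)
image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')  # placeholder checkpoint

paco_semantic_seg_dataset = dict(
    type=PacoSemanticSegDataset,
    data_path='data/paco/annotations/paco_lvis_v1_train.json',  # placeholder PACO-LVIS annotations
    image_folder='data/coco/',                                   # placeholder image root
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=pascal_part_map_fn,   # assumed; check the shipped finetune configs
    template_map_fn=dict(
        type=template_map_fn_factory, template=PROMPT_TEMPLATE.internlm2_chat),
    max_length=2048,
    pad_image_to_square=True)
```

When built by the config system, each sample from this dataset carries `pixel_values`, the tokenized conversation, and the per-part `masks` that `decode_mask` resizes to `(image_h // down_ratio, image_w // down_ratio)`.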
omg_llava/dataset/__init__.py ADDED
@@ -0,0 +1,29 @@
+from .CombineDataset import CombineDataset
+from .GCGDataset import RefCOCOgGCGDataset, OpenPsgGCGDataset, GranDfGCGDataset, FlickrGCGDataset
+from .SemanticSegDataset import SemanticSegDataset, ADE20kSemanticSegDataset,\
+    COCOStuffSemanticSegDataset, MapillarySemanticSegDataset, PascalPartSemanticSegDataset,\
+    PacoSemanticSegDataset
+from .MDPVPointsDataset import MDPVPointDetailedCaptionDataset, MDPVPointBriefCaptionDataset
+from .ReferringSegDataset import RefcocoReferringSegDataset, Refcoco_plus_ReferringSegDataset,\
+    Refcocog_ReferringSegDataset, Refclef_ReferringSegDataset
+from .RegionCaptionDataset import OspreyRegionCaptionDataset, OspreyRegionConversationDataset
+from .LlavaDataset import LLaVADataset
+from .DecoupledGCGDataset import DecoupledRefCOCOgGCGDataset, DecoupledOpenPsgGCGDataset,\
+    DecoupledGranDfGCGDataset, DecoupledFlickrGCGDataset
+
+
+from .process_functions import glamm_openpsg_map_fn, glamm_refcocog_map_fn,\
+    glamm_granf_map_fn, glamm_flickr_map_fn,\
+    semantic_seg_map_fn, pascal_part_map_fn,\
+    semantic_seg_gcg_format_map_fn, pascal_part_gcg_format_map_fn,\
+    referring_seg_map_fn, referring_seg_gcg_format_map_fn,\
+    osprey_region_caption_map_fn, osprey_region_caption_gcg_format_map_fn,\
+    osprey_region_conversation_map_fn,\
+    mdpv_points_map_fn
+
+from .process_functions import glamm_refcocog_decoupled_given_objects_map_fn, glamm_refcocog_decoupled_given_description_map_fn,\
+    glamm_granf_decoupled_given_description_map_fn, glamm_granf_decoupled_given_objects_map_fn,\
+    glamm_flickr_decoupled_given_description_map_fn, glamm_flickr_decoupled_given_objects_map_fn,\
+    glamm_openpsg_decoupled_given_objects_map_fn, glamm_openpsg_decoupled_given_description_map_fn
+
+from .collect_fns import omg_llava_collate_fn
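This `__init__.py` is what exposes the dataset classes, map functions, and the collate function under a single `omg_llava.dataset` namespace. A hypothetical smoke test (not part of the diff; it only assumes the package from this commit is importable) could look like this:

```python
# Hypothetical smoke test for the public import surface declared above.
from omg_llava.dataset import (
    PacoSemanticSegDataset,        # defined in SemanticSegDataset.py (see diff above)
    PascalPartSemanticSegDataset,
    DecoupledFlickrGCGDataset,
    mdpv_points_map_fn,
    omg_llava_collate_fn,
)

# PACO reuses PascalPart's mask decoding via inheritance.
assert issubclass(PacoSemanticSegDataset, PascalPartSemanticSegDataset)
# The collate fn exported here is presumably what the training configs pass to their dataloaders.
assert callable(omg_llava_collate_fn)
print('omg_llava.dataset imports OK')
```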
omg_llava/dataset/__pycache__/CombineDataset.cpython-310.pyc ADDED
Binary file (2.28 kB).
omg_llava/dataset/__pycache__/DecoupledGCGDataset.cpython-310.pyc ADDED
Binary file (9.73 kB).
omg_llava/dataset/__pycache__/GCGDataset.cpython-310.pyc ADDED
Binary file (9.29 kB).
omg_llava/dataset/__pycache__/LlavaDataset.cpython-310.pyc ADDED
Binary file (3.84 kB).
omg_llava/dataset/__pycache__/MDPVPointsDataset.cpython-310.pyc ADDED
Binary file (6.39 kB).
omg_llava/dataset/__pycache__/ReferringSegDataset.cpython-310.pyc ADDED
Binary file (9.05 kB).
omg_llava/dataset/__pycache__/RegionCaptionDataset.cpython-310.pyc ADDED
Binary file (8.45 kB).
omg_llava/dataset/__pycache__/SemanticSegDataset.cpython-310.pyc ADDED
Binary file (19.1 kB).