File size: 12,704 Bytes
4f09ecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1165ca
 
4f09ecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9f0c04
4f09ecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide: 
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False 
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False 
############################################################################################

import os
import sys
# Publish a process-wide TRAINING flag through sys.modules (False = inference testing).
# Presumably other modules read it via sys.modules['TRAINING'] — confirm at the consumers.
# NOTE(review): sys.modules normally holds module objects; tooling that iterates it
# may not expect a bool entry here.
sys.modules.update(TRAINING=False)

###############################################################
OPT_VARS = {}  # Every option read through getenv(), mapped to its resolved value (for logging).


def getenv(var_name, default=None, cast_type=str):
    """Read an environment variable, cast it, and record the result in OPT_VARS.

    Args:
        var_name: Name of the environment variable to read.
        default: Value used when the variable is unset or cannot be cast.
        cast_type: Callable that converts the raw string. ``bool`` is special-cased:
            the whitespace-stripped, case-insensitive strings "true", "1" and "yes"
            map to True; any other string maps to False.

    Returns:
        The cast value, or ``default`` when the variable is unset or the cast
        raises ValueError/TypeError.
    """
    raw = os.environ.get(var_name)
    if raw is None:
        result = default
    elif cast_type is bool:
        # bool("false") would be True, so parse truthy spellings explicitly.
        # Strip so values like "true\r" (CRLF env files) or "True " are not read as False,
        # matching int()/float(), which already tolerate surrounding whitespace.
        result = raw.strip().lower() in ("true", "1", "yes")
    else:
        try:
            result = cast_type(raw)
        except (ValueError, TypeError):
            # Best effort: fall back to the default on a malformed value.
            result = default

    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################

# --- Policy and run counts ------------------------------------------------------
POLICY = getenv("POLICY", default="RL", cast_type=str)
# TAX_HIERARCHY_TO_CONDENSE = 3  # Remove N layers of the taxonomy hierarchy from the back

# NUM_TEST is overridden if TAXABIND_TTA is True and performing search ds val.
NUM_TEST = 800
NUM_RUN = 1

# --- Output toggles ---------------------------------------------------------------
# GIFs / per-step metrics (trajectory) / per-episode metrics (length) / graph edges.
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)
SAVE_TRAJECTORY = False
SAVE_LENGTH = False
VIZ_GRAPH_EDGES = False

# --- Checkpoint file name -----------------------------------------------------------
# Previously used checkpoints:
#   pure_coverage_no_pose_obs_230325_stage1.pth
#   STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth
#   vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth
#   vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"

# --- Episode control ----------------------------------------------------------------
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
# Terminate the episode once all targets are found.
TERMINATE_ON_TGTS_FOUND = True
# Force CSV logging once all targets are found.
FORCE_LOGGING_DONE_TGTS_FOUND = True
# Fix the starting position of the robots (middle index).
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)

## Initial score mask source:
##   USE_CLIP_PREDS = True  -> use the CLIP predictions
##   USE_CLIP_PREDS = False -> use custom masks from OVERRIDE_MASK_DIR
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)
OVERRIDE_MASK_DIR = getenv(
    "OVERRIDE_MASK_DIR",
    default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in",
    cast_type=str,
)
# Other mask directories used previously:
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out

## Ground-truth mask directory, used to calculate the info_gain metric.
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
# Previously used:
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map

#######################################################################
# iNAT TTA
#######################################################################

# --- Query --------------------------------------------------------------------------
# Taxonomy to query; "" tests every taxonomy.
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)
# Example queries:
#   "Animalia Chordata Mammalia Rodentia"                (search_val_in)
#   "Animalia Chordata Mammalia Artiodactyla"            (search_val_in)
#   "Animalia Arthropoda Arachnida Araneae"              (search_val_out)
#   "Plantae Tracheophyta Magnoliopsida Caryophyllales"  (search_val_out)

# --- TTA schedule -------------------------------------------------------------------
# Whether to execute TTA mask updates.
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)
# Every STEPS_PER_TTA steps, run a TTA series consisting of NUM_TTA_STEPS steps.
STEPS_PER_TTA = 20
NUM_TTA_STEPS = 1

# Query modalities; each is one of "image", "text", "combined".
INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str)
MODALITY = getenv("MODALITY", default="image", cast_type=str)
# Boolean toggle (cast_type=bool), not a modality name.
QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool)
RESET_WEIGHTS = True

# MIN_LR / MAX_LR and GAMMA_EXPONENT — presumably the TTA learning-rate schedule
# parameters; confirm in the TTA code.
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2

# --- Paths related to taxabind (TRAIN w/ TARGETS) --------------------------------
# Whether to init TTA classes - FOR NOW: Always True
TAXABIND_TTA = True
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# Alternatives: sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px'
# iNat JSONs (no filter needed).
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json'
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json'
# Satellite-to-image-ids mapping. Previously used:
#   /mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json
#   /mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv(
    "TAXABIND_SAT_TO_IMG_IDS_JSON_PATH",
    default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json",
    cast_type=str,
)
TAXABIND_PATCH_SIZE = 14
# Previous checkpoint:
#   /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt
TAXABIND_SAT_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt"
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5, 5)
# DEBUG (Starting point) 5, 6, 8
TAXABIND_SAMPLE_INDEX = 8

# --- Sound ---------------------------------------------------------------------------
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"

# # Paths related to taxabind (TRAIN w/ TARGETS)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21' 
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px'    # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json' 
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 99   # (Starting point) 99,141

# # Paths related to taxabind (VAL) 
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21' 
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px'    # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json' 
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 45   # TEMP

#######################################################################
# Pretraining
#######################################################################

# TODO: Get rid of the LISA stuff...
# Configuration: LISA trained clss.
#   MASK_SET_DIR variants:      original_LISA, finetuned_LISA_v3_original_losses
#   TARGETS_SET_DIR:            if empty, targets are assumed to be on MASK_SET_DIR
#   RAW_IMG_PATH_DICT variants: flair-ds-paths-filtered-with-scores-train.csv,
#                               flair-ds-paths-filtered-with-scores-val-trained-clss.csv,
#                               flair-ds-paths-filtered-with-scores-val-out-clss.csv
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3"
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss"
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv"

# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3"    # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss"  # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv"  # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv

#######################################################################

NUM_ROBOTS = 1
# Node-coordinate grid: how many node coords across the width / height.
NUM_COORDS_WIDTH = 24
NUM_COORDS_HEIGHT = 24
# Ratio of rewards for moving to an uncertain area (high info vs low info).
HIGH_INFO_REWARD_RATIO = 0.75

# Sensor model: "rectangular" or "circle".
# SENSOR_RANGE is only applicable to the "circle" model.
# NOTE: there is no collision check for the "rectangular" model.
SENSOR_RANGE = 80
SENSOR_MODEL = "rectangular"

# Input / embedding sizes and K_SIZE.
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8

# --- Compute resources --------------------------------------------------------------
USE_GPU = False  # do you want to use GPUs?
NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int)  # number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int)  # number of processes

# --- Output locations (all rooted at FOLDER_NAME) ---------------------------------
FOLDER_NAME = 'inference'
model_path = f'{FOLDER_NAME}/model'
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
length_path = f'{FOLDER_NAME}/test_results/length'
log_path = f'{FOLDER_NAME}/test_results/log'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
# Older locations:
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'

# ANSI escape sequences for coloured console output.
RED = '\033[1;31m'      # bold red
GREEN = '\033[1;32m'    # bold green
YELLOW = '\033[1;93m'   # bold bright yellow
NC_BOLD = '\033[1m'     # bold, no colour
NC = '\033[0m'          # reset: no colour