Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,704 Bytes
4f09ecf f1165ca 4f09ecf c9f0c04 4f09ecf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
############################################################################################
import os
import sys
# Store the run-mode flag under the 'TRAINING' key of sys.modules.
# NOTE(review): presumably other modules read sys.modules['TRAINING'] to tell
# training apart from inference — confirm against the importing code.
sys.modules['TRAINING'] = False # False = Inference Testing
###############################################################
OPT_VARS = {}  # Resolved value of every option read through getenv(), keyed by variable name


def getenv(var_name, default=None, cast_type=str):
    """Read an option from the environment, falling back to ``default``.

    Args:
        var_name: Name of the environment variable to look up.
        default: Value returned when the variable is unset or when
            ``cast_type`` raises ValueError/TypeError on the raw string.
        cast_type: Callable used to convert the raw string. ``bool`` is
            special-cased: "true", "1" and "yes" (case-insensitive,
            surrounding whitespace ignored) give True, anything else False.

    Returns:
        The converted value, or ``default``. The result is also recorded in
        ``OPT_VARS[var_name]``.
    """
    raw = os.environ.get(var_name)
    if raw is None:
        result = default
    elif cast_type is bool:
        # Strip first so values such as "true " or "1" followed by a carriage
        # return (common in shell scripts / .env files) are not silently
        # read as False.
        result = raw.strip().lower() in ("true", "1", "yes")
    else:
        # Only the conversion itself is expected to fail on bad input.
        try:
            result = cast_type(raw)
        except (ValueError, TypeError):
            result = default
    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################
# --- Policy and evaluation budget ---
POLICY = getenv("POLICY", "RL", cast_type=str)
# TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
# Overridden if TAXABIND_TTA is True and performing search ds val.
NUM_TEST = 800
NUM_RUN = 1

# --- Output toggles ---
SAVE_GIFS = getenv("SAVE_GIFS", True, cast_type=bool)  # save GIFs?
SAVE_TRAJECTORY = False  # save per-step metrics?
SAVE_LENGTH = False  # save per-episode metrics?
VIZ_GRAPH_EDGES = False  # visualize the graph edges?

# --- Checkpoint (earlier candidates kept for reference) ---
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"

# --- Episode control ---
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", 384, cast_type=int)
# Terminate the episode when all targets are found.
TERMINATE_ON_TGTS_FOUND = True
# Force csv logging when all targets are found.
FORCE_LOGGING_DONE_TGTS_FOUND = True
# Fix the starting position of the robots (middle index).
FIX_START_POSITION = getenv("FIX_START_POSITION", True, cast_type=bool)
## Source of the initial score mask
# True: keep the CLIP predictions. False: use custom masks from OVERRIDE_MASK_DIR.
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", True, cast_type=bool)
OVERRIDE_MASK_DIR = getenv(
    "OVERRIDE_MASK_DIR",
    "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in",
    cast_type=str,
)
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
# Used to calculate the info_gain metric (defaults to an empty string).
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", "", cast_type=str)
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
#######################################################################
# iNAT TTA
#######################################################################
# --- Query params ---
# Empty string = test all tax.
QUERY_TAX = getenv("QUERY_TAX", "", cast_type=str)
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out

# --- TTA params ---
# Whether to execute TTA mask updates.
EXECUTE_TTA = getenv("EXECUTE_TTA", True, cast_type=bool)
STEPS_PER_TTA = 20  # no. of steps before each TTA series
NUM_TTA_STEPS = 1  # no. of TTA steps during each series
# Modalities: "image", "text" or "combined".
INITIAL_MODALITY = getenv("INITIAL_MODALITY", "image", cast_type=str)
MODALITY = getenv("MODALITY", "image", cast_type=str)
# Boolean flag (not a modality name).
# NOTE(review): its meaning is not documented in this file — confirm with the TTA code.
QUERY_VARIETY = getenv("QUERY_VARIETY", False, cast_type=bool)
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5  # 1e-5
GAMMA_EXPONENT = 2  # 2
# --- Paths related to taxabind (TRAIN w/ TARGETS) ---
# Whether to init TTA classes - FOR NOW: Always True
TAXABIND_TTA = True
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# Other directories listed: sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px'
# no filter needed
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json'
# no filter needed
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# Can be overridden through the environment variable of the same name.
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv(
    "TAXABIND_SAT_TO_IMG_IDS_JSON_PATH",
    "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json",
    cast_type=str,
)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
TAXABIND_PATCH_SIZE = 14
# Alternative checkpoint (kept for reference):
# "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
TAXABIND_SAT_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt"
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5, 5)
# DEBUG (Starting point) 5, 6, 8
TAXABIND_SAMPLE_INDEX = 8

# --- Sound ---
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
# # Paths related to taxabind (TRAIN w/ TARGETS)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141
# # Paths related to taxabind (VAL)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 45 # TEMP
#######################################################################
# Pretraining
#######################################################################
# TODO: Get rid of the LISA stuff...
# Active set: LISA trained clss
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
# Variants listed: original_LISA, finetuned_LISA_v3_original_losses
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3"
# If empty, then targets assumed to be on MASK_SET_DIR
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss"
# CSV files listed: flair-ds-paths-filtered-with-scores-train.csv,
# flair-ds-paths-filtered-with-scores-val-trained-clss.csv,
# flair-ds-paths-filtered-with-scores-val-out-clss.csv
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv"
# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
#######################################################################
# --- Robots and map grid ---
NUM_ROBOTS = 1
NUM_COORDS_WIDTH = 24  # How many node coords across width?
NUM_COORDS_HEIGHT = 24  # How many node coords across height?
# Ratio of rewards for moving to uncertain area (high info vs low info)
HIGH_INFO_REWARD_RATIO = 0.75

# --- Sensor ---
SENSOR_RANGE = 80  # Only applicable to 'circle' sensor model
# "rectangular" or "circle" (NOTE: no collision check for rectangular)
SENSOR_MODEL = "rectangular"

# --- Network dimensions ---
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8  # 8

# --- Hardware ---
USE_GPU = False  # use GPUs?
# The number of GPUs
NUM_GPU = getenv("NUM_GPU", 2, cast_type=int)
# The number of processes
NUM_META_AGENT = getenv("NUM_META_AGENT", 4, cast_type=int)
# Root folder; every path below is derived from it.
FOLDER_NAME = 'inference'

# Model directory
model_path = f'{FOLDER_NAME}/model'

# Test result directories
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
length_path = f'{FOLDER_NAME}/test_results/length'
log_path = f'{FOLDER_NAME}/test_results/log'
# NOTE(review): presumably the experiment name used for CSV logging — confirm against the logger.
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", "data", cast_type=str)
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'
# --- ANSI escape sequences for coloured printing ---
RED = '\033[1;31m'
GREEN = '\033[1;32m'
YELLOW = '\033[1;93m'
# Bold, No Color
NC_BOLD = '\033[1m'
# No Color
NC = '\033[0m'
|