############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
############################################################################################

import os
import sys

# Flag stored in sys.modules under the key 'TRAINING' (not a real module).
sys.modules['TRAINING'] = False     # False = Inference Testing

###############################################################
# Every value resolved through getenv() is recorded here, keyed by variable name.
OPT_VARS = {}


def getenv(var_name, default=None, cast_type=str):
    """Read a setting from the environment, cast it, and record the result.

    Args:
        var_name: Name of the environment variable to look up.
        default: Value used when the variable is unset or cannot be cast.
        cast_type: Callable applied to the raw string value. ``bool`` is
            special-cased: the value is true only if, after stripping
            surrounding whitespace and lower-casing, it is one of
            "true", "1" or "yes".

    Returns:
        The cast value, or ``default`` when the variable is unset or the
        cast raises ``ValueError`` / ``TypeError``.

    Side effects:
        Stores the returned value in ``OPT_VARS[var_name]``.
    """
    try:
        raw = os.environ.get(var_name, None)
        if raw is None:
            result = default
        elif cast_type == bool:
            # Strip so that values such as " True " or "1\n" are not
            # silently treated as False.
            result = raw.strip().lower() in ("true", "1", "yes")
        else:
            result = cast_type(raw)
    except (ValueError, TypeError):
        # Deliberate best-effort: an unparsable value falls back to the default.
        result = default

    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################


POLICY = getenv("POLICY", default="RL", cast_type=str)
# TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
NUM_TEST = 800  # Overridden if TAXABIND_TTA is True and performing search ds val
NUM_RUN = 1
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)  # do you want to save GIFs
SAVE_TRAJECTORY = False  # do you want to save per-step metrics
SAVE_LENGTH = False  # do you want to save per-episode metrics
VIZ_GRAPH_EDGES = False  # do you want to visualize the graph edges
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
TERMINATE_ON_TGTS_FOUND = True  # Whether to terminate episode when all targets found
FORCE_LOGGING_DONE_TGTS_FOUND = True  # Whether to force csv logging when all targets found
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)  # Whether to fix the starting position of the robots (middle index)

## Whether to override initial score mask from CLIP
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)  # If false, use custom masks from OVERRIDE_MASK_DIR
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in", cast_type=str)
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
# Used to calculate info_gain metric
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"

#######################################################################
# iNAT TTA
#######################################################################

# Query Params
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)  # "" = Test all tax
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out

# TTA PARAMS
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)  # Whether to execute TTA mask updates
STEPS_PER_TTA = 20  # no. steps before each TTA series
NUM_TTA_STEPS = 1  # no. of TTA steps during each series
INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str)  # "image", "text", "combined"
MODALITY = getenv("MODALITY", default="image", cast_type=str)  # "image", "text", "combined"
# Boolean flag. NOTE(review): the comment on this line used to list modality
# names ("image", "text", "combined"), which looks like a copy-paste - confirm intended meaning.
QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool)
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5  # 1e-5
GAMMA_EXPONENT = 2  # 2

# Paths related to taxabind (TRAIN w/ TARGETS)
TAXABIND_TTA = True  # Whether to init TTA classes - FOR NOW: Always True
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px'  # sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json'  # no filter needed
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json'  # no filter needed
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv("TAXABIND_SAT_TO_IMG_IDS_JSON_PATH", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json", cast_type=str)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
TAXABIND_PATCH_SIZE=14
TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt"  # "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5,5)
TAXABIND_SAMPLE_INDEX = 8  # DEBUG (Starting point) 5, 6, 8

# Sound
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"

# # Paths related to taxabind (TRAIN w/ TARGETS)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141

# # Paths related to taxabind (VAL)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 45 # TEMP

#######################################################################
# Pretraining
#######################################################################
# TODO: Get rid of the LISA stuff...

# If LISA trained clss
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3"  # original_LISA, finetuned_LISA_v3_original_losses
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss"  # If empty, then targets assumed to be on MASK_SET_DIR
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv"  # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv

# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv

#######################################################################

NUM_ROBOTS = 1
NUM_COORDS_WIDTH=24  # How many node coords across width?
NUM_COORDS_HEIGHT=24  # How many node coords across height?
HIGH_INFO_REWARD_RATIO = 0.75  # Ratio of rewards for moving to uncertain area (high info vs low info)
SENSOR_RANGE=80  # Only applicable to 'circle' sensor model
SENSOR_MODEL="rectangular"  # "rectangular", "circle" (NOTE: no collision check for rectangular)

INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8  # 8

USE_GPU = False  # do you want to use GPUS?
NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int)  # the number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int)  # the number of processes
FOLDER_NAME = 'inference'
model_path = f'{FOLDER_NAME}/model'
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
length_path = f'{FOLDER_NAME}/test_results/length'
log_path = f'{FOLDER_NAME}/test_results/log'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'

# COLORS (for printing)
RED='\033[1;31m'
GREEN='\033[1;32m'
YELLOW='\033[1;93m'
NC_BOLD='\033[1m'  # Bold, No Color
NC='\033[0m'  # No Color