search-tta-demo / test_parameter.py
derektan
Terminate upon all tgts found
f1165ca
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
############################################################################################
import os
import sys
# Global train/inference switch.
# NOTE(review): this stores a plain bool (not a module object) under the
# 'TRAINING' key of sys.modules; presumably other project modules read it back
# via sys.modules['TRAINING'] -- confirm against the importers before changing it.
sys.modules['TRAINING'] = False # False = Inference Testing
###############################################################
# Every value resolved through getenv() is recorded here, keyed by variable name.
OPT_VARS = {}

# Accepted spellings for boolean environment variables (compared after
# stripping surrounding whitespace and lower-casing).
_TRUE_TOKENS = frozenset({"true", "1", "yes", "y", "on"})
_FALSE_TOKENS = frozenset({"false", "0", "no", "n", "off"})


def getenv(var_name, default=None, cast_type=str):
    """Read a setting from the environment, cast it, and record it in OPT_VARS.

    Args:
        var_name: Name of the environment variable to look up.
        default: Value used when the variable is unset or cannot be parsed.
        cast_type: Callable used to convert the raw string. ``bool`` is
            handled specially: "true"/"1"/"yes"/"y"/"on" give True and
            "false"/"0"/"no"/"n"/"off" give False (case-insensitive,
            surrounding whitespace ignored).

    Returns:
        The parsed value, or ``default`` when the variable is unset or its
        value is not valid for ``cast_type``.
    """
    raw = os.environ.get(var_name)
    if raw is None:
        result = default
    elif cast_type is bool:
        token = raw.strip().lower()
        if token in _TRUE_TOKENS:
            result = True
        elif token in _FALSE_TOKENS:
            result = False
        else:
            # An unrecognised value (typo, empty string) must not silently
            # flip the flag to False; fall back to the default, exactly as
            # the non-bool branch does on a failed cast.
            result = default
    else:
        try:
            result = cast_type(raw)
        except (ValueError, TypeError):
            result = default
    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################
# Policy selector; falls back to "RL" when the POLICY env var is unset.
POLICY = getenv(
    "POLICY",
    default="RL",
    cast_type=str,
)
# TAX_HIERARCHY_TO_CONDENSE = 3  # Remove N layers of the taxonomy hierarchy from the back

# Overridden if TAXABIND_TTA is True and performing search ds val.
NUM_TEST = 800
NUM_RUN = 1

# --- Output toggles ---
# Save GIFs? (env-overridable, on by default)
SAVE_GIFS = getenv(
    "SAVE_GIFS",
    default=True,
    cast_type=bool,
)
# Save per-step metrics?
SAVE_TRAJECTORY = False
# Save per-episode metrics?
SAVE_LENGTH = False
# Visualize the graph edges?
VIZ_GRAPH_EDGES = False
# Checkpoint file name. Earlier checkpoints are kept below for reference:
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth"
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth"
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth"
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth"
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"

# Env-overridable integer, default 384 (presumably the per-episode step
# budget -- TODO confirm against the caller).
NUM_EPS_STEPS = getenv(
    "NUM_EPS_STEPS",
    default=384,
    cast_type=int,
)
# Terminate the episode when all targets are found.
TERMINATE_ON_TGTS_FOUND = True
# Force csv logging when all targets are found.
FORCE_LOGGING_DONE_TGTS_FOUND = True
# Fix the starting position of the robots (middle index).
FIX_START_POSITION = getenv(
    "FIX_START_POSITION",
    default=True,
    cast_type=bool,
)
# --- Initial score mask source ---
# True: use the CLIP predictions. False: use custom masks from OVERRIDE_MASK_DIR.
USE_CLIP_PREDS = getenv(
    "USE_CLIP_PREDS",
    default=True,
    cast_type=bool,
)
OVERRIDE_MASK_DIR = getenv(
    "OVERRIDE_MASK_DIR",
    default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in",
    cast_type=str,
)
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
# Used to calculate the info_gain metric; empty string unless set via the environment.
OVERRIDE_GT_MASK_DIR = getenv(
    "OVERRIDE_GT_MASK_DIR",
    default="",
    cast_type=str,
)
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
#######################################################################
# iNAT TTA
#######################################################################
# --- Query params ---
# Taxonomy to query; the empty-string default means "test all tax".
QUERY_TAX = getenv(
    "QUERY_TAX",
    default="",
    cast_type=str,
)
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out
# --- TTA params ---
# Whether to execute TTA mask updates.
EXECUTE_TTA = getenv(
    "EXECUTE_TTA",
    default=True,
    cast_type=bool,
)
# Number of steps before each TTA series.
STEPS_PER_TTA = 20
# Number of TTA steps during each series.
NUM_TTA_STEPS = 1
# Modality settings: "image", "text" or "combined".
INITIAL_MODALITY = getenv(
    "INITIAL_MODALITY",
    default="image",
    cast_type=str,
)
MODALITY = getenv(
    "MODALITY",
    default="image",
    cast_type=str,
)
# Boolean flag, off by default.
# NOTE(review): its meaning is not documented in this file -- confirm at the use site.
QUERY_VARIETY = getenv(
    "QUERY_VARIETY",
    default=False,
    cast_type=bool,
)
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2
# --- Paths related to taxabind (TRAIN w/ TARGETS) ---
# Whether to init TTA classes (for now: always True).
TAXABIND_TTA = True
TAXABIND_IMG_DIR = "/mnt/hdd/inat2021_ds/inat21"
# Other satellite image sets: sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_IMO_DIR = "/mnt/hdd/inat2021_ds/sat_train_jpg_512px"
# Original note on both JSON paths: no filter needed.
TAXABIND_INAT_JSON_PATH = "/mnt/hdd/inat2021_ds/inat21_train.json"
TAXABIND_FILTERED_INAT_JSON_PATH = (
    "/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json"
)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# JSON path, env-overridable; defaults to the search_val_in split.
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv(
    "TAXABIND_SAT_TO_IMG_IDS_JSON_PATH",
    default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json",
    cast_type=str,
)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
TAXABIND_PATCH_SIZE = 14
# Previous checkpoint, kept for reference:
# "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
TAXABIND_SAT_CHECKPOINT_PATH = (
    "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt"
)
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5, 5)
# DEBUG (starting point): 5, 6, 8
TAXABIND_SAMPLE_INDEX = 8

# --- Sound ---
TAXABIND_SOUND_DATA_PATH = "/mnt/hdd/inat2021_ds/2_OTHERS/sound_test"
TAXABIND_SOUND_CHECKPOINT_PATH = (
    "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
)
# # Paths related to taxabind (TRAIN w/ TARGETS)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141
# # Paths related to taxabind (VAL)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 45 # TEMP
#######################################################################
# Pretraining
#######################################################################
# TODO: Get rid of the LISA stuff...
# --- LISA, trained classes ---
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
# Variants: original_LISA, finetuned_LISA_v3_original_losses
MASK_SET_DIR = (
    "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3"
)
# If empty, targets are assumed to be on MASK_SET_DIR.
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss"
# Available CSVs: flair-ds-paths-filtered-with-scores-train.csv,
# flair-ds-paths-filtered-with-scores-val-trained-clss.csv,
# flair-ds-paths-filtered-with-scores-val-out-clss.csv
RAW_IMG_PATH_DICT = (
    "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv"
)
# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
#######################################################################
NUM_ROBOTS = 1

# How many node coords across the width / height.
NUM_COORDS_WIDTH = 24
NUM_COORDS_HEIGHT = 24

# Ratio of rewards for moving to uncertain area (high info vs low info).
HIGH_INFO_REWARD_RATIO = 0.75

# Only applicable to the "circle" sensor model.
SENSOR_RANGE = 80
# "rectangular" or "circle" (NOTE: no collision check for rectangular).
SENSOR_MODEL = "rectangular"

INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8

# Use GPUs?
USE_GPU = False
# The number of GPUs (env-overridable, default 2).
NUM_GPU = getenv(
    "NUM_GPU",
    default=2,
    cast_type=int,
)
# The number of processes (env-overridable, default 4).
NUM_META_AGENT = getenv(
    "NUM_META_AGENT",
    default=4,
    cast_type=int,
)
# Root folder name; every path below is built relative to it.
FOLDER_NAME = "inference"
model_path = f"{FOLDER_NAME}/model"
# Test outputs all sit under <FOLDER_NAME>/test_results/.
gifs_path = f"{FOLDER_NAME}/test_results/gifs"
trajectory_path = f"{FOLDER_NAME}/test_results/trajectory"
length_path = f"{FOLDER_NAME}/test_results/length"
log_path = f"{FOLDER_NAME}/test_results/log"
# Experiment name for CSV output (env-overridable, default "data").
# NOTE(review): exactly how this name is used is not visible here -- confirm at the use site.
CSV_EXPT_NAME = getenv(
    "CSV_EXPT_NAME",
    default="data",
    cast_type=str,
)
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'
# --- ANSI escape sequences for coloured terminal printing ---
RED = "\033[1;31m"
GREEN = "\033[1;32m"
YELLOW = "\033[1;93m"
# Bold, no colour.
NC_BOLD = "\033[1m"
# No colour.
NC = "\033[0m"