# Spaces: Running on Zero  (web-page banner text captured with the file; not part of the code)
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
############################################################################################
import os
import sys

# Mode flag stored under the 'TRAINING' key of sys.modules (False = Inference Testing).
# NOTE(review): presumably other project modules read sys.modules['TRAINING'] to
# distinguish training from inference - confirm against the callers.
sys.modules['TRAINING'] = False
############################################################### | |
# Every option resolved through getenv(), keyed by environment variable name.
OPT_VARS = {}


def getenv(var_name, default=None, cast_type=str):
    """Resolve a hyper-parameter from the environment, falling back to ``default``.

    Args:
        var_name: Name of the environment variable to look up.
        default: Value used when the variable is unset, or when casting raises
            ``ValueError`` / ``TypeError``.
        cast_type: Callable applied to the raw string value. ``bool`` is
            special-cased: the result is True only for "true", "1" or "yes"
            (case-insensitive, surrounding whitespace ignored); any other
            string yields False.

    Returns:
        The resolved value. It is also recorded in ``OPT_VARS[var_name]``.
    """
    try:
        raw = os.environ.get(var_name, None)
        if raw is None:
            result = default
        elif cast_type is bool:
            # bool("false") would be True, so the text is parsed explicitly.
            # strip() keeps values such as "True " or "1\n" from being read as False.
            result = raw.strip().lower() in ("true", "1", "yes")
        else:
            result = cast_type(raw)
    except (ValueError, TypeError):
        result = default
    OPT_VARS[var_name] = result  # Log the result
    return result
############################################################### | |
# General test-run parameters. Values read through getenv() can be overridden
# by an environment variable of the same name.
POLICY = getenv("POLICY", default="RL", cast_type=str)
# TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
NUM_TEST = 800  # Overridden if TAXABIND_TTA is True and performing search ds val
NUM_RUN = 1
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)  # do you want to save GIFs
SAVE_TRAJECTORY = False  # do you want to save per-step metrics
SAVE_LENGTH = False  # do you want to save per-episode metrics
VIZ_GRAPH_EDGES = False  # do you want to visualize the graph edges
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
TERMINATE_ON_TGTS_FOUND = True  # Whether to terminate episode when all targets found
FORCE_LOGGING_DONE_TGTS_FOUND = True  # Whether to force csv logging when all targets found
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)  # Whether to fix the starting position of the robots (middle index)
## Whether to override initial score mask from CLIP
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)  # If false, use custom masks from OVERRIDE_MASK_DIR
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in", cast_type=str)
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
# Used to calculate info_gain metric
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
#######################################################################
# iNAT TTA
#######################################################################
# Query Params
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)  # "" = Test all tax
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out
# TTA PARAMS
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)  # Whether to execute TTA mask updates
STEPS_PER_TTA = 20  # no. steps before each TTA series
NUM_TTA_STEPS = 1  # no. of TTA steps during each series
INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str)  # "image", "text", "combined"
MODALITY = getenv("MODALITY", default="image", cast_type=str)  # "image", "text", "combined"
# Boolean flag (the previous comment listing modalities did not apply to it).
# NOTE(review): its exact effect is not visible in this file - confirm and document.
QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool)
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5  # 1e-5
GAMMA_EXPONENT = 2  # 2
# Paths related to taxabind (TRAIN w/ TARGETS)
TAXABIND_TTA = True  # Whether to init TTA classes - FOR NOW: Always True
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px'  # sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json'  # no filter needed
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json'  # no filter needed
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv("TAXABIND_SAT_TO_IMG_IDS_JSON_PATH", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json", cast_type=str)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
TAXABIND_PATCH_SIZE = 14
TAXABIND_SAT_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt"  # "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5, 5)
TAXABIND_SAMPLE_INDEX = 8  # DEBUG (Starting point) 5, 6, 8
# Sound
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
# # Paths related to taxabind (TRAIN w/ TARGETS) | |
# TAXABIND_TTA = True | |
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21' | |
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px | |
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed | |
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed | |
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json' | |
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json" | |
# TAXABIND_PATCH_SIZE=14 | |
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt | |
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141 | |
# # Paths related to taxabind (VAL) | |
# TAXABIND_TTA = True | |
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21' | |
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px | |
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed | |
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed | |
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json' | |
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json" | |
# TAXABIND_PATCH_SIZE=14 | |
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt | |
# TAXABIND_SAMPLE_INDEX = 45 # TEMP | |
#######################################################################
# Pretraining
#######################################################################
# TODO: Get rid of the LISA stuff...
# If LISA trained clss
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3"  # original_LISA, finetuned_LISA_v3_original_losses
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss"  # If empty, then targets assumed to be on MASK_SET_DIR
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv"  # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
####################################################################### | |
# Robot, sensor and network-size constants.
NUM_ROBOTS = 1
NUM_COORDS_WIDTH = 24  # How many node coords across width?
NUM_COORDS_HEIGHT = 24  # How many node coords across height?
HIGH_INFO_REWARD_RATIO = 0.75  # Ratio of rewards for moving to uncertain area (high info vs low info)
SENSOR_RANGE = 80  # Only applicable to 'circle' sensor model
SENSOR_MODEL = "rectangular"  # "rectangular", "circle" (NOTE: no collision check for rectangular)
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8  # 8
USE_GPU = False  # do you want to use GPUS?
# Process counts (overridable through the environment) and output locations,
# all rooted at FOLDER_NAME.
NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int)  # the number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int)  # the number of processes
FOLDER_NAME = 'inference'
model_path = f'{FOLDER_NAME}/model'
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
length_path = f'{FOLDER_NAME}/test_results/length'
log_path = f'{FOLDER_NAME}/test_results/log'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'
# COLORS (for printing): ANSI escape sequences
RED = '\033[1;31m'
GREEN = '\033[1;32m'
YELLOW = '\033[1;93m'
NC_BOLD = '\033[1m'  # Bold, No Color
NC = '\033[0m'  # No Color