Spaces:
Runtime error
Runtime error
| from datetime import timedelta | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset, | |
| PascalVocDataset, SBDEvaluationDataset) | |
| from isegm.utils.serialization import load_model | |
def get_time_metrics(all_ious, elapsed_time):
    """Return the average time spent per click and per image.

    Args:
        all_ious: Sized collection with one sized entry per image; the length
            of each entry is counted as that image's number of clicks.
        elapsed_time: Total time to distribute (same unit as the results).

    Returns:
        Tuple ``(mean_spc, mean_spi)``: ``elapsed_time`` divided by the total
        number of clicks and by the number of images, respectively.

    Raises:
        ZeroDivisionError: If there are no clicks or no images.
    """
    image_count = len(all_ious)
    click_count = 0
    for per_image_ious in all_ious:
        click_count += len(per_image_ious)
    return elapsed_time / click_count, elapsed_time / image_count
def load_is_model(checkpoint, device, **kwargs):
    """Build one model, or a group of models, from a checkpoint.

    Args:
        checkpoint: Either a path (``str`` or ``Path``), which is read with
            ``torch.load`` onto the CPU, or an already-loaded checkpoint
            object that is used as-is.
        device: Device forwarded to ``load_single_is_model``.
        **kwargs: Extra options forwarded to ``load_single_is_model``.

    Returns:
        If the loaded checkpoint is a list: a tuple ``(model, models)`` where
        ``model`` is built from the first entry and ``models`` holds one model
        per entry. Otherwise a single model.
    """
    if isinstance(checkpoint, (str, Path)):
        loaded = torch.load(checkpoint, map_location="cpu")
    else:
        loaded = checkpoint

    if not isinstance(loaded, list):
        return load_single_is_model(loaded, device, **kwargs)

    # The first entry is instantiated on its own, so `model` is a separate
    # object from `models[0]` even though both come from the same entry.
    model = load_single_is_model(loaded[0], device, **kwargs)
    models = [load_single_is_model(entry, device, **kwargs) for entry in loaded]
    return model, models
def load_single_is_model(state_dict, device, **kwargs):
    """Instantiate a frozen, eval-mode model from one checkpoint entry.

    Args:
        state_dict: Mapping with a ``"config"`` entry (passed to
            ``load_model``) and a ``"state_dict"`` entry (the weights).
        device: Target passed to the model's ``to`` method.
        **kwargs: Extra options forwarded to ``load_model``.

    Returns:
        The model with weights loaded, gradients disabled on every parameter,
        moved to ``device`` and switched to evaluation mode.
    """
    net = load_model(state_dict["config"], **kwargs)
    # strict=False is passed through as-is; presumably this tolerates
    # missing/unexpected keys as in torch.nn.Module — confirm for the model type.
    net.load_state_dict(state_dict["state_dict"], strict=False)

    # Inference only: no parameter should track gradients.
    for weight in net.parameters():
        weight.requires_grad = False

    net.to(device)
    net.eval()
    return net
def get_dataset(dataset_name, cfg):
    """Create the evaluation dataset registered under ``dataset_name``.

    Args:
        dataset_name: One of ``"GrabCut"``, ``"Berkeley"``, ``"DAVIS"``,
            ``"SBD"``, ``"SBD_Train"``, ``"PascalVOC"`` or ``"COCO_MVal"``.
        cfg: Configuration object exposing the matching ``*_PATH`` attribute.

    Returns:
        The constructed dataset, or ``None`` when the name is not recognised.
    """
    # Factories are lazy so that only the path attribute belonging to the
    # requested dataset is ever read from cfg.
    factories = {
        "GrabCut": lambda: GrabCutDataset(cfg.GRABCUT_PATH),
        "Berkeley": lambda: BerkeleyDataset(cfg.BERKELEY_PATH),
        "DAVIS": lambda: DavisDataset(cfg.DAVIS_PATH),
        "SBD": lambda: SBDEvaluationDataset(cfg.SBD_PATH),
        "SBD_Train": lambda: SBDEvaluationDataset(cfg.SBD_PATH, split="train"),
        "PascalVOC": lambda: PascalVocDataset(cfg.PASCALVOC_PATH, split="test"),
        # COCO_MVal is loaded through the DAVIS dataset class.
        "COCO_MVal": lambda: DavisDataset(cfg.COCO_MVAL_PATH),
    }
    factory = factories.get(dataset_name)
    if factory is None:
        return None
    return factory()
def get_iou(gt_mask, pred_mask, ignore_label=-1):
    """Compute intersection-over-union of a prediction against ground truth.

    Ground-truth elements equal to ``ignore_label`` are excluded from both the
    intersection and the union. The object region is where ``gt_mask == 1``;
    ``pred_mask`` is interpreted by truthiness.

    Args:
        gt_mask: Ground-truth label array.
        pred_mask: Predicted mask array, combined element-wise with ``gt_mask``.
        ignore_label: Ground-truth value marking elements to leave out.

    Returns:
        Intersection count divided by union count. The division is not guarded
        against an empty union.
    """
    counted = gt_mask != ignore_label
    target = gt_mask == 1

    overlap = np.logical_and(pred_mask, target)
    combined = np.logical_or(pred_mask, target)

    intersection = np.logical_and(overlap, counted).sum()
    union = np.logical_and(combined, counted).sum()
    return intersection / union
def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    """Compute the Number-of-Clicks (NoC) statistics for several IoU thresholds.

    For each image the NoC is the 1-based index of the first click whose IoU
    reaches the threshold, or ``max_clicks`` when no click reaches it.

    Args:
        all_ious: Iterable with one IoU sequence per image (array or list),
            ordered by click.
        iou_thrs: Iterable of IoU thresholds to evaluate.
        max_clicks: Value assigned to images that never reach the threshold.

    Returns:
        Tuple ``(noc_list, over_max_list)`` with one entry per threshold:
        the mean NoC over all images, and the number of images whose NoC
        equals ``max_clicks``.
    """

    def _get_noc(iou_arr, iou_thr):
        # asarray lets plain Python lists be compared element-wise as well.
        reached = np.asarray(iou_arr) >= iou_thr
        # argmax on a boolean array yields the index of the first True.
        return np.argmax(reached) + 1 if np.any(reached) else max_clicks

    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        # Builtin `int` replaces the `np.int` alias, which was removed from
        # NumPy (AttributeError on current releases); the dtype is the same.
        scores_arr = np.array(
            [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=int
        )
        noc_list.append(scores_arr.mean())
        over_max_list.append((scores_arr == max_clicks).sum())
    return noc_list, over_max_list
def find_checkpoint(weights_folder, checkpoint_name):
    """Resolve a checkpoint specifier to a path string.

    Accepted forms of ``checkpoint_name``:

    * ``"<model>:<name>"`` - the search is restricted to the single directory
      inside ``weights_folder`` whose name starts with ``<model>``.
    * a name ending in ``.pth`` - used directly if that path exists, otherwise
      taken relative to ``weights_folder`` (existence is not checked).
    * any other name - treated as a prefix; the model folder is searched
      recursively for exactly one ``<name>*.pth`` file.

    Args:
        weights_folder: Root directory containing weights (``str`` or ``Path``).
        checkpoint_name: Checkpoint specifier as described above.

    Returns:
        The resolved checkpoint path as a string.

    Raises:
        AssertionError: If the model prefix or the checkpoint prefix does not
            match exactly one candidate. The exception type is kept for
            compatibility with existing callers, but it is raised explicitly
            so the check is not stripped under ``python -O`` and the message
            names what was searched for.
    """
    weights_folder = Path(weights_folder)

    if ":" in checkpoint_name:
        model_name, checkpoint_name = checkpoint_name.split(":")
        models_candidates = [
            x for x in weights_folder.glob(f"{model_name}*") if x.is_dir()
        ]
        if len(models_candidates) != 1:
            raise AssertionError(
                f"Expected exactly one model folder matching '{model_name}*' "
                f"in '{weights_folder}', found {len(models_candidates)}: "
                f"{sorted(str(x) for x in models_candidates)}"
            )
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith(".pth"):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            # Explicit file names are resolved against the weights root,
            # not against the selected model folder.
            checkpoint_path = weights_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth"))
        if len(model_checkpoints) != 1:
            raise AssertionError(
                f"Expected exactly one checkpoint matching "
                f"'{checkpoint_name}*.pth' under '{model_folder}', found "
                f"{len(model_checkpoints)}: "
                f"{sorted(str(x) for x in model_checkpoints)}"
            )
        checkpoint_path = model_checkpoints[0]

    return str(checkpoint_path)
| def get_results_table( | |
| noc_list, | |
| over_max_list, | |
| brs_type, | |
| dataset_name, | |
| mean_spc, | |
| elapsed_time, | |
| n_clicks=20, | |
| model_name=None, | |
| ): | |
| table_header = ( | |
| f'|{"BRS Type":^13}|{"Dataset":^11}|' | |
| f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' | |
| f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' | |
| f'{"SPC,s":^7}|{"Time":^9}|' | |
| ) | |
| row_width = len(table_header) | |
| header = f"Eval results for model: {model_name}\n" if model_name is not None else "" | |
| header += "-" * row_width + "\n" | |
| header += table_header + "\n" + "-" * row_width | |
| eval_time = str(timedelta(seconds=int(elapsed_time))) | |
| table_row = f"|{brs_type:^13}|{dataset_name:^11}|" | |
| table_row += f"{noc_list[0]:^9.2f}|" | |
| table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|' | |
| table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|' | |
| table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|' | |
| table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|' | |
| table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|" | |
| return header, table_row | |