import torch
import torch.nn.functional as F

from .geom import rnd_sample, interpolate, get_dist_mat


def make_detector_loss(
    pos0,
    pos1,
    dense_feat_map0,
    dense_feat_map1,
    score_map0,
    score_map1,
    batch_size,
    num_corr,
    loss_type,
    config,
):
    joint_loss = 0.0
    accuracy = 0.0
    all_valid_pos0 = []
    all_valid_pos1 = []
    all_valid_match = []
    for i in range(batch_size):
        # Randomly sample up to num_corr correspondences for this image pair.
        valid_pos0, valid_pos1 = rnd_sample([pos0[i], pos1[i]], num_corr)
        valid_num = valid_pos0.shape[0]

        # Sample descriptors from the dense feature maps (positions are scaled
        # by 1/4 to match the feature-map resolution) and detection scores from
        # the full-resolution score maps.
        valid_feat0 = interpolate(valid_pos0 / 4, dense_feat_map0[i])
        valid_feat1 = interpolate(valid_pos1 / 4, dense_feat_map1[i])
        valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)
        valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1)
        valid_score0 = interpolate(
            valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False
        )
        valid_score1 = interpolate(
            valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False
        )

        if config["network"]["det"]["corr_weight"]:
            corr_weight = valid_score0 * valid_score1
        else:
            corr_weight = None

        # Keypoints within safe_radius pixels of a ground-truth correspondence
        # are masked out so they are not mined as hard negatives.
        safe_radius = config["network"]["det"]["safe_radius"]
        if safe_radius > 0:
            radius_mask_row = get_dist_mat(
                valid_pos1, valid_pos1, "euclidean_dist_no_norm"
            )
            radius_mask_row = torch.le(radius_mask_row, safe_radius)
            radius_mask_col = get_dist_mat(
                valid_pos0, valid_pos0, "euclidean_dist_no_norm"
            )
            radius_mask_col = torch.le(radius_mask_col, safe_radius)
            radius_mask_row = radius_mask_row.float() - torch.eye(
                valid_num, device=radius_mask_row.device
            )
            radius_mask_col = radius_mask_col.float() - torch.eye(
                valid_num, device=radius_mask_col.device
            )
        else:
            radius_mask_row = None
            radius_mask_col = None

        if valid_num < 32:
            # Too few correspondences to form a meaningful structured loss.
            si_loss, si_accuracy, matched_mask = (
                0.0,
                1.0,
                torch.zeros((1, valid_num)).bool(),
            )
        else:
            si_loss, si_accuracy, matched_mask = make_structured_loss(
                torch.unsqueeze(valid_feat0, 0),
                torch.unsqueeze(valid_feat1, 0),
                loss_type=loss_type,
                radius_mask_row=radius_mask_row,
                radius_mask_col=radius_mask_col,
                corr_weight=torch.unsqueeze(corr_weight, 0)
                if corr_weight is not None
                else None,
            )

        joint_loss += si_loss / batch_size
        accuracy += si_accuracy / batch_size
        all_valid_match.append(torch.squeeze(matched_mask, dim=0))
        all_valid_pos0.append(valid_pos0)
        all_valid_pos1.append(valid_pos1)

    return joint_loss, accuracy
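
# Hypothetical usage sketch (not part of the original module). The tensor
# layouts expected by the geom helpers (rnd_sample / interpolate) are
# assumptions here: per-image keypoint coordinates as [N, 2] tensors, dense
# descriptor maps at 1/4 of the image resolution, and score maps of shape
# [batch_size, H, W, 1]. The config values shown are illustrative only and
# mirror the keys read above.
#
#   det_loss, det_acc = make_detector_loss(
#       pos0, pos1,
#       dense_feat_map0, dense_feat_map1,
#       score_map0, score_map1,
#       batch_size=2,
#       num_corr=1000,
#       loss_type="HARD_CONTRASTIVE",
#       config={"network": {"det": {"corr_weight": True, "safe_radius": 3}}},
#   )
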
def make_structured_loss(
    feat_anc,
    feat_pos,
    loss_type="RATIO",
    inlier_mask=None,
    radius_mask_row=None,
    radius_mask_col=None,
    corr_weight=None,
    dist_mat=None,
):
    """
    Structured loss construction.

    Args:
        feat_anc, feat_pos: Feature matrices, [batch_size, num_corr, feat_dim].
        loss_type: Loss type ("HARD_TRIPLET", "HARD_CONTRASTIVE" or "CIRCLE").
        inlier_mask: Optional [batch_size, num_corr] boolean mask of valid
            correspondences; defaults to all ones.
        radius_mask_row, radius_mask_col: Optional masks excluding spatially
            close negatives.
        corr_weight: Optional per-correspondence weights.
        dist_mat: Optional precomputed [batch_size, num_corr, num_corr]
            distance matrix.

    Returns:
        loss, accuracy, matched_mask.
    """
    batch_size = feat_anc.shape[0]
    num_corr = feat_anc.shape[1]
    if inlier_mask is None:
        inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool()
    inlier_num = torch.count_nonzero(inlier_mask.float(), dim=-1)

    if loss_type == "L2NET" or loss_type == "CIRCLE":
        dist_type = "cosine_dist"
    elif loss_type.find("HARD") >= 0:
        dist_type = "euclidean_dist"
    else:
        raise NotImplementedError()

    if dist_mat is None:
        dist_mat = get_dist_mat(
            feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type
        ).unsqueeze(0)
    # Distances of the ground-truth (diagonal) correspondences.
    pos_vec = dist_mat[0].diag().unsqueeze(0)

    if loss_type.find("HARD") >= 0:
        neg_margin = 1
        # Push the diagonal (and near-duplicate entries) far away so they are
        # never selected during hardest-negative mining.
        dist_mat_without_min_on_diag = dist_mat + 10 * torch.unsqueeze(
            torch.eye(num_corr, device=dist_mat.device), dim=0
        )
        mask = torch.le(dist_mat_without_min_on_diag, 0.008).float()
        dist_mat_without_min_on_diag += mask * 10

        if radius_mask_row is not None:
            hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row
        else:
            hard_neg_dist_row = dist_mat_without_min_on_diag
        if radius_mask_col is not None:
            hard_neg_dist_col = dist_mat_without_min_on_diag + 10 * radius_mask_col
        else:
            hard_neg_dist_col = dist_mat_without_min_on_diag
        hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0]
        hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0]

        if loss_type == "HARD_TRIPLET":
            loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0)
            loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0)
        elif loss_type == "HARD_CONTRASTIVE":
            pos_margin = 0.2
            pos_loss = torch.clamp(pos_vec - pos_margin, min=0)
            loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0)
            loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0)
        else:
            raise NotImplementedError()
    elif loss_type == "CIRCLE":
        log_scale = 512
        m = 0.1
        neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
        if radius_mask_row is not None:
            neg_mask_row += radius_mask_row
        neg_mask_col = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
        if radius_mask_col is not None:
            neg_mask_col += radius_mask_col

        pos_margin = 1 - m
        neg_margin = m
        pos_optimal = 1 + m
        neg_optimal = -m

        neg_mat_row = dist_mat - 128 * neg_mask_row
        neg_mat_col = dist_mat - 128 * neg_mask_col

        lse_positive = torch.logsumexp(
            -log_scale
            * (pos_vec[..., None] - pos_margin)
            * torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(),
            dim=-1,
        )
        lse_negative_row = torch.logsumexp(
            log_scale
            * (neg_mat_row - neg_margin)
            * torch.clamp(neg_mat_row - neg_optimal, min=0).detach(),
            dim=-1,
        )
        lse_negative_col = torch.logsumexp(
            log_scale
            * (neg_mat_col - neg_margin)
            * torch.clamp(neg_mat_col - neg_optimal, min=0).detach(),
            dim=-2,
        )
        loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale
        loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale
    else:
        raise NotImplementedError()

    # A correspondence counts as matched when its positive pair scores better
    # than every negative in both its row and its column.
    if dist_type == "cosine_dist":
        err_row = dist_mat - torch.unsqueeze(pos_vec, -1)
        err_col = dist_mat - torch.unsqueeze(pos_vec, -2)
    elif dist_type == "euclidean_dist" or dist_type == "euclidean_dist_no_norm":
        err_row = torch.unsqueeze(pos_vec, -1) - dist_mat
        err_col = torch.unsqueeze(pos_vec, -2) - dist_mat
    else:
        raise NotImplementedError()
    if radius_mask_row is not None:
        err_row = err_row - 10 * radius_mask_row
    if radius_mask_col is not None:
        err_col = err_col - 10 * radius_mask_col
    err_row = torch.sum(torch.clamp(err_row, min=0), dim=-1)
    err_col = torch.sum(torch.clamp(err_col, min=0), dim=-2)

    loss = 0
    accuracy = 0
    tot_loss = (loss_row + loss_col) / 2
    if corr_weight is not None:
        tot_loss = tot_loss * corr_weight

    for i in range(batch_size):
        if corr_weight is not None:
            loss += torch.sum(tot_loss[i][inlier_mask[i]]) / (
                torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6
            )
        else:
            loss += torch.mean(tot_loss[i][inlier_mask[i]])
        cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float()
        cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float()
        tot_err = cnt_err_row + cnt_err_col
        if inlier_num[i] != 0:
            accuracy += 1.0 - tot_err / inlier_num[i] / batch_size / 2.0
        else:
            accuracy += 1.0

    matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0))
    matched_mask = torch.logical_and(matched_mask, inlier_mask)

    loss /= batch_size
    accuracy /= batch_size
    return loss, accuracy, matched_mask
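

def _example_structured_loss():
    # Illustrative sketch (not part of the original code): exercise
    # make_structured_loss on random unit-norm descriptors. A distance matrix
    # is precomputed with torch.cdist and passed in explicitly so the example
    # does not depend on the geom helpers; treating plain Euclidean distance as
    # the metric for the "HARD" branch is an assumption of this sketch.
    feat_anc = F.normalize(torch.randn(1, 64, 128), p=2, dim=-1)
    feat_pos = F.normalize(torch.randn(1, 64, 128), p=2, dim=-1)
    dist_mat = torch.cdist(feat_anc, feat_pos)  # [1, 64, 64]
    return make_structured_loss(
        feat_anc, feat_pos, loss_type="HARD_CONTRASTIVE", dist_mat=dist_mat
    )

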
# For the neighborhood of each keypoint extracted from the normal image, the
# score from noise_score_map should stay close to the normal score; elsewhere,
# the noise image's score should be less than the normal image's score.
# input: score_map [batch_size, H, W, 1]; indices [2, k, 2]
# output: loss [scalar]
def make_noise_score_map_loss(
    score_map, noise_score_map, indices, batch_size, thld=0.0
):
    H, W = score_map.shape[1:3]
    loss = 0
    for i in range(batch_size):
        kpts_coords = indices[i].T  # (2, num_kpts)
        mask = torch.zeros([H, W], device=score_map.device)
        # Mark keypoint pixels; index with (row, col) coordinate pairs.
        mask[kpts_coords[0], kpts_coords[1]] = 1
        # Use a 3x3 kernel to grow the mask to each keypoint's neighborhood.
        kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
        mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0
        loss1 = torch.sum(
            torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask
        ) / torch.sum(mask)
        loss2 = torch.sum(
            torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
            * torch.logical_not(mask)
        ) / (H * W - torch.sum(mask))
        loss += loss1
        loss += loss2
        if i == 0:
            first_mask = mask
    return loss, first_mask
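

def _example_noise_score_map_loss():
    # Illustrative sketch (not part of the original code): random score maps of
    # shape [batch_size, H, W, 1] and a few integer (row, col) keypoint indices
    # per image, matching the input convention documented above.
    B, H, W, k = 2, 32, 32, 8
    score_map = torch.rand(B, H, W, 1)
    noise_score_map = torch.rand(B, H, W, 1)
    indices = torch.randint(0, H, (B, k, 2))
    return make_noise_score_map_loss(
        score_map, noise_score_map, indices, batch_size=B
    )

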
def make_noise_score_map_loss_labelmap(
    score_map, noise_score_map, labelmap, batch_size, thld=0.0
):
    H, W = score_map.shape[1:3]
    loss = 0
    for i in range(batch_size):
        # Use a 3x3 kernel to grow the label map to each keypoint's neighborhood.
        kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
        mask = (
            F.conv2d(
                labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1
            )[0, 0]
            > 0
        )
        loss1 = torch.sum(
            torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask
        ) / torch.sum(mask)
        loss2 = torch.sum(
            torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
            * torch.logical_not(mask)
        ) / (H * W - torch.sum(mask))
        loss += loss1
        loss += loss2
        if i == 0:
            first_mask = mask
    return loss, first_mask
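

def _example_noise_score_map_loss_labelmap():
    # Illustrative sketch (not part of the original code): the label map is
    # assumed here to be a binary keypoint mask of shape [batch_size, 1, H, W],
    # which matches the unsqueeze/conv2d usage above.
    B, H, W = 2, 32, 32
    score_map = torch.rand(B, H, W, 1)
    noise_score_map = torch.rand(B, H, W, 1)
    labelmap = torch.rand(B, 1, H, W) > 0.95
    return make_noise_score_map_loss_labelmap(
        score_map, noise_score_map, labelmap, batch_size=B
    )

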
def make_score_map_peakiness_loss(score_map, scores, batch_size):
    H, W = score_map.shape[1:3]
    loss = 0
    for i in range(batch_size):
        # Encourage keypoint scores to stand out above the score-map mean.
        loss += torch.mean(scores[i]) - torch.mean(score_map[i])
    loss /= batch_size
    return 1 - loss
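

def _example_score_map_peakiness_loss():
    # Illustrative sketch (not part of the original code): `scores` is assumed
    # here to be a list of per-image score values sampled at keypoint locations.
    B, H, W = 2, 32, 32
    score_map = torch.rand(B, H, W, 1)
    scores = [torch.rand(100) for _ in range(B)]
    return make_score_map_peakiness_loss(score_map, scores, batch_size=B)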