Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def edgeLoss(preds_edges, edges):
    """
    Class-balanced binary cross-entropy between edge logits and edge targets.

    Target pixels strictly greater than 0.5 count as positives. Within each
    sample, positive pixels are weighted by the fraction of negative pixels
    and negative pixels by the fraction of positive pixels, so the rarer
    class contributes more to the loss.

    Args:
        preds_edges: edge logits with shape [b, c, h, w]; they are passed to
            ``F.binary_cross_entropy_with_logits``, so no sigmoid is expected.
        edges: edge targets with shape [b, c, h, w].
    Returns:
        Scalar tensor: the mean of the weighted per-pixel losses.
    """
    positive = (edges > 0.5).float()
    batch, channels, height, width = positive.shape
    pixels_per_sample = channels * height * width

    pos_count = positive.sum(dim=[1, 2, 3]).float()
    neg_count = pixels_per_sample - pos_count
    total = pos_count + neg_count

    # Per-sample class fractions reshaped to [b, 1, 1, 1] so they broadcast
    # over each sample's whole map.
    neg_frac = (neg_count / total).view(batch, 1, 1, 1)
    pos_frac = (pos_count / total).view(batch, 1, 1, 1)

    # Swap the fractions: positives get the negative share and vice versa.
    weight = neg_frac * positive + pos_frac * (1 - positive)

    per_pixel = F.binary_cross_entropy_with_logits(
        preds_edges.float(), edges.float(), weight=weight, reduction='none'
    )
    return per_pixel.mean()
class EdgeAcc(nn.Module):
    """
    Measure the precision and recall of a predicted edge map against GT edges.

    Both maps are binarized with the same ``threshold``: a pixel is positive
    when its value is strictly greater than the threshold. Counts are pooled
    over the whole batch.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    # Implemented as ``forward`` (not ``__call__``) so nn.Module.__call__ still
    # dispatches here and registered hooks run; ``acc(pred, gt)`` is unchanged.
    def forward(self, pred_edge, gt_edge):
        """
        Args:
            pred_edge: Predicted edges, with shape [b, c, h, w].
                NOTE(review): thresholded directly, so presumably probabilities
                rather than the logits that edgeLoss consumes; confirm with callers.
            gt_edge: GT edges, with shape [b, c, h, w]
        Returns: ``(precision, recall)`` as 0-dim float tensors. Both are 1
            when neither map contains a positive pixel.
        """
        labels = gt_edge > self.threshold
        preds = pred_edge > self.threshold
        relevant = labels.sum().float()
        selected = preds.sum().float()
        if relevant == 0 and selected == 0:
            # Nothing to find and nothing predicted: treat as perfect. Build the
            # result with the same dtype/device as the regular path (the old
            # code returned CPU int64 tensors here).
            ones = dict(dtype=relevant.dtype, device=relevant.device)
            return torch.ones((), **ones), torch.ones((), **ones)
        true_positive = (preds & labels).sum().float()
        # The epsilon only matters when one of the counts is zero.
        recall = true_positive / (relevant + 1e-8)
        precision = true_positive / (selected + 1e-8)
        return precision, recall
if __name__ == '__main__':
    # Hand check of the class-balanced weights that edgeLoss builds, on two
    # toy maps: a 6x6 square and a 4x4 square inside a 10x10 canvas.
    demo_edges = torch.zeros([2, 1, 10, 10])  # [b, 1, h, w] -> the extracted edges
    demo_edges[0, :, 2:8, 2:8] = 1
    demo_edges[1, :, 3:7, 3:7] = 1

    is_edge = (demo_edges > 0.5).float()
    _, chans, rows, cols = is_edge.shape
    pos_count = is_edge.sum(dim=[1, 2, 3]).float()
    neg_count = chans * rows * cols - pos_count
    print(pos_count, neg_count)

    denom = pos_count + neg_count
    neg_w = (neg_count / denom).view(-1, 1, 1, 1)
    pos_w = (pos_count / denom).view(-1, 1, 1, 1)
    print(neg_w * is_edge + pos_w * (1 - is_edge))