Spaces:
Configuration error
Configuration error
| from __future__ import absolute_import | |
| import warnings | |
| import numpy as np | |
| import torch | |
| from torchvision.transforms import functional as F | |
| from ..core.transforms_interface import BasicTransform | |
# Public API of this module: only ToTensorV2 is exported; ToTensor is kept solely to
# raise a migration error, and the *_to_tensor helpers are module-internal.
__all__ = ["ToTensorV2"]
def img_to_tensor(im, normalize=None):
    """Convert a numpy image to a float32 channel-first (``CHW``) torch tensor.

    ``uint8`` images are scaled to ``[0, 1]`` by dividing by 255; images of any
    other dtype are only cast to float32, without rescaling.

    Args:
        im (numpy.ndarray): image in ``HWC`` layout, or ``HW`` for a
            single-channel (grayscale) image.
        normalize (dict, optional): keyword arguments (e.g. ``mean``, ``std``)
            forwarded to ``torchvision.transforms.functional.normalize``.

    Returns:
        torch.Tensor: float32 tensor of shape ``(C, H, W)``; ``(1, H, W)`` for
        ``HW`` input.
    """
    scale = 255.0 if im.dtype == np.uint8 else 1
    scaled = im / scale
    if scaled.ndim == 2:
        # A grayscale HW image has no channel axis; without adding one,
        # moveaxis(-1, 0) would swap H and W instead of producing CHW.
        scaled = scaled[..., np.newaxis]
    tensor = torch.from_numpy(np.moveaxis(scaled, -1, 0).astype(np.float32))
    if normalize is not None:
        return F.normalize(tensor, **normalize)
    return tensor
def mask_to_tensor(mask, num_classes, sigmoid):
    """Convert a numpy segmentation mask to a torch tensor.

    Three modes, selected by ``num_classes`` and ``sigmoid``:

    * ``num_classes > 1`` and ``sigmoid``: the last axis is moved to the front
      and values are cast to float32 (``uint8`` input is divided by 255).
    * ``num_classes > 1`` and not ``sigmoid`` (softmax): an int64 label map of
      the mask's first two dimensions is built. For a 3-D mask, each pixel gets
      the index of the last channel that is > 0 there (0 otherwise); for a 2-D
      mask, pixels > 127 become 1 and all others 0.
    * otherwise: a leading axis is added and values are cast to float32
      (``uint8`` input is divided by 255).

    Args:
        mask (numpy.ndarray): input mask.
        num_classes (int): number of segmentation classes.
        sigmoid (bool): whether the multi-class output is per-channel (sigmoid)
            or a single label map (softmax).

    Returns:
        torch.Tensor: the converted mask.
    """
    if num_classes > 1:
        if sigmoid:
            divisor = 255.0 if mask.dtype == np.uint8 else 1
            channels_first = np.moveaxis(mask / divisor, -1, 0)
            return torch.from_numpy(channels_first.astype(np.float32))

        # softmax: collapse the mask into a single map of class indices.
        labels = np.zeros(mask.shape[:2], dtype=np.int64)
        if mask.ndim != 3:
            labels[mask > 127] = 1
            labels[mask == 0] = 0
        else:
            for channel_idx in range(mask.shape[2]):
                labels[mask[..., channel_idx] > 0] = channel_idx
        return torch.from_numpy(labels)

    divisor = 255.0 if mask.dtype == np.uint8 else 1
    with_leading_axis = np.expand_dims(mask / divisor, 0)
    return torch.from_numpy(with_leading_axis.astype(np.float32))
class ToTensor(BasicTransform):
    """Removed legacy transform, kept only to fail loudly with a migration hint.

    Creating an instance always raises ``RuntimeError`` that points users to
    ``ToTensorV2``. The old ``ToTensor`` converted images and masks to
    ``torch.Tensor`` (dividing ``uint8`` data by 255); if that behaviour is
    needed, downgrade the library to version 0.5.2.

    Args:
        num_classes (int): unused; accepted only so old call sites fail with
            the migration message rather than a ``TypeError``.
        sigmoid (bool, optional): unused; see ``num_classes``.
        normalize (dict, optional): unused; see ``num_classes``.

    Raises:
        RuntimeError: always, from ``__init__``.
    """

    def __init__(self, num_classes=1, sigmoid=True, normalize=None):
        docs_url = (
            "https://albumentations.ai/docs/api_reference/pytorch/transforms/"
            "#albumentations.pytorch.transforms.ToTensorV2. "
        )
        message = (
            "`ToTensor` is obsolete and it was removed from custom_albumentations. "
            "Please use `ToTensorV2` instead - "
            + docs_url
            + "\n\nIf you need `ToTensor` downgrade Albumentations to version 0.5.2."
        )
        raise RuntimeError(message)
class ToTensorV2(BasicTransform):
    """Convert image and mask to `torch.Tensor`. The numpy `HWC` image is converted to pytorch `CHW` tensor.

    If the image is in `HW` format (grayscale image), a channel axis is added and
    it is converted to a pytorch `1HW` tensor. Values and dtype are kept as-is:
    unlike the old `ToTensor`, no division by 255 is performed.

    This is a simplified and improved version of the old `ToTensor`
    transform (`ToTensor` was deprecated, and now it is not present in Albumentations. You should use `ToTensorV2`
    instead).

    Args:
        transpose_mask (bool): If True and an input mask has three dimensions, this transform will transpose dimensions
            so the shape `[height, width, num_channels]` becomes `[num_channels, height, width]`. The latter format is a
            standard format for PyTorch Tensors. Default: False.
        always_apply (bool): Indicates whether this transformation should be always applied. Default: True.
        p (float): Probability of applying the transform. Default: 1.0.
    """

    def __init__(self, transpose_mask=False, always_apply=True, p=1.0):
        super().__init__(always_apply=always_apply, p=p)
        self.transpose_mask = transpose_mask

    @property
    def targets(self):
        """Mapping from target name to the method that converts that target.

        Must be a property: BasicTransform (defined elsewhere) is presumed to
        read ``self.targets`` as a mapping (``self.targets.get(...)``), as
        upstream Albumentations does; a plain method would break that lookup.
        """
        return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks}

    def apply(self, img, **params):  # skipcq: PYL-W0613
        """Convert an ``HWC`` (or ``HW``) numpy image to a ``CHW`` tensor.

        Raises:
            ValueError: if ``img`` is neither 2-D nor 3-D.
        """
        if len(img.shape) not in [2, 3]:
            raise ValueError("Albumentations only supports images in HW or HWC format")

        if len(img.shape) == 2:
            # Grayscale HW -> HW1 so the transpose below yields a 1HW tensor.
            img = np.expand_dims(img, 2)

        return torch.from_numpy(img.transpose(2, 0, 1))

    def apply_to_mask(self, mask, **params):  # skipcq: PYL-W0613
        """Convert a mask to a tensor, moving the channel axis first if ``transpose_mask`` is set."""
        if self.transpose_mask and mask.ndim == 3:
            mask = mask.transpose(2, 0, 1)
        return torch.from_numpy(mask)

    def apply_to_masks(self, masks, **params):
        """Convert each mask in ``masks`` with :meth:`apply_to_mask`; returns a list."""
        return [self.apply_to_mask(mask, **params) for mask in masks]

    def get_transform_init_args_names(self):
        """Names of ``__init__`` arguments (beyond ``always_apply``/``p``) to serialize."""
        return ("transpose_mask",)

    def get_params_dependent_on_targets(self, params):
        """No target-dependent parameters are needed for this transform."""
        return {}