"""Data pipeline for fiber segmentation and detection.

Builds nntools datasets from paired image/annotation folders, converts RGB
annotation masks into class label maps, optionally extracts per-fiber
instance masks and bounding boxes, and exposes Lightning dataloaders for
both settings.
"""

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from skimage.measure import label, regionprops
from skimage.morphology import dilation, skeletonize
from skimage.segmentation import expand_labels
from torch.utils.data import DataLoader

import nntools.dataset as D
from nntools.dataset.composer import CacheBullet


@D.nntools_wrapper
def convert_mask(mask):
    """Convert an RGB annotation into a 2-class label map.

    Pixels bright in channel 0 (red) become class 1, pixels bright in
    channel 1 (green) become class 2. The map is then skeletonized and
    re-expanded so every fiber keeps a thin, uniform footprint.
    """
    output = np.zeros(mask.shape[:2], dtype=np.uint8)
    output[mask[:, :, 0] > 200] = 1
    output[mask[:, :, 1] > 200] = 2
    binary_mask = output > 0
    # Reduce each fiber to its skeleton while carrying its class label.
    skeleton = skeletonize(binary_mask) * output
    # Grow the skeleton back to a fixed thickness; clip defensively in case
    # the expansion produces values outside the expected label range.
    output = expand_labels(skeleton, 3)
    output = np.clip(output, 0, 2)
    return {"mask": output}


@D.nntools_wrapper
def extract_bbox(mask):
    """Split a label map into per-fiber instance masks and bounding boxes."""
    binary_mask = mask > 0
    labelled = label(binary_mask)
    props = regionprops(labelled, intensity_image=mask)
    # Replace the mask with a thin, dilated skeleton so downstream consumers
    # see a uniform fiber width.
    skeleton = skeletonize(binary_mask) * mask
    mask = dilation(skeleton, np.ones((3, 3)))
    bboxes = []
    masks = []

    for prop in props:
        minr, minc, maxr, maxc = prop.bbox
        # pascal_voc order: x_min, y_min, x_max, y_max.
        bboxes.append([minc, minr, maxc, maxr])
        masks.append((labelled == prop.label).astype(np.uint8))
    if not masks:
        # No fiber found: emit a single empty mask so the output keeps its
        # (H, W, N) layout.
        masks = np.zeros_like(mask)[np.newaxis, :, :]
    masks = np.array(masks)
    masks = np.moveaxis(masks, 0, -1)

    return {
        # reshape keeps an empty box list as shape (0, 4) instead of (0,).
        "bboxes": np.array(bboxes).reshape(-1, 4),
        "mask": masks,
        "fiber_ids": np.array([p.label for p in props]),
    }


class FiberDatamodule(LightningDataModule):
    """Lightning datamodule for the fiber dataset.

    When ``use_bbox`` is True, the pipeline also produces per-fiber instance
    masks and boxes, and the dataloaders switch to a detection-style collate
    function.
    """

    def __init__(
        self,
        root_img,
        crop_size=(256, 256),
        shape=1024,
        batch_size=32,
        num_workers=8,
        use_bbox=False,
        **kwargs,
    ):
        super().__init__()
        self.shape = shape
        self.root_img = str(root_img)
        self.crop_size = crop_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs
        self.use_bbox = use_bbox

    def setup(self, *args, **kwargs):
        def _get_dataset(version):
            dataset = D.MultiImageDataset(
                {
                    "image": f"{self.root_img}/{version}/images/",
                    "mask": f"{self.root_img}/{version}/annotations/",
                },
                shape=(self.shape, self.shape),
                use_cache=self.kwargs.get("use_cache", False),
                cache_option=self.kwargs.get("cache_option", None),
            )
            # Sort both file lists with the same key so images and masks
            # stay aligned index by index.
            for key in ("image", "mask"):
                dataset.img_filepath[key] = np.asarray(
                    sorted(
                        dataset.img_filepath[key],
                        key=lambda x: (x.parent.stem, x.stem),
                    )
                )
            dataset.composer = D.Composition()
            dataset.composer << convert_mask
            if self.use_bbox:
                dataset.composer << extract_bbox

            return dataset

        self.train = _get_dataset("train")
        # Validation is a held-out split of the training folder (see the
        # stratified split below); only the test folder is truly separate.
        self.val = _get_dataset("train")
        self.test = _get_dataset("test")
        # Cache the composition up to this point for the training set only.
        self.train.composer << CacheBullet()
        self.val.use_cache = False
        self.test.use_cache = False

        # Stratify the split by the parent-folder id of "tile" images;
        # everything else is lumped into a single extra stratum (25).
        stratify = []
        for f in self.train.img_filepath["image"]:
            if "tile" in f.stem:
                stratify.append(int(f.parent.stem))
            else:
                stratify.append(25)
        train_idx, val_idx = train_test_split(
            np.arange(len(self.train)),
            stratify=stratify,
            test_size=0.2,
            random_state=42,
        )
        self.train.subset(train_idx)
        self.val.subset(val_idx)

        self.train.composer.add(*self.get_train_composer())
        self.val.composer.add(*self.cast_operators())
        self.test.composer.add(*self.cast_operators())

    def get_train_composer(self):
        transforms = []
        if self.crop_size is not None:
            transforms.append(
                A.CropNonEmptyMaskIfExists(
                    width=self.crop_size[0], height=self.crop_size[1]
                )
            )
        # With min_visibility=0.95, boxes are dropped once a crop hides more
        # than 5% of their area.
        bbox_params = (
            A.BboxParams(
                format="pascal_voc",
                label_fields=["fiber_ids"],
                min_visibility=0.95,
            )
            if self.use_bbox
            else None
        )
        return [
            A.Compose(
                transforms
                + [
                    A.HorizontalFlip(),
                    A.VerticalFlip(),
                    A.Affine(),
                    A.ElasticTransform(),
                    A.RandomRotate90(),
                    A.OneOf(
                        [
                            A.RandomBrightnessContrast(
                                brightness_limit=(-0.2, 0.1),
                                contrast_limit=(-0.2, 0.1),
                                p=0.5,
                            ),
                            A.HueSaturationValue(
                                hue_shift_limit=(-5, 5),
                                sat_shift_limit=(-20, 20),
                                val_shift_limit=(-20, 20),
                                p=0.5,
                            ),
                        ]
                    ),
                    A.GaussNoise(std_range=(0.0, 0.1), p=0.5),
                ],
                bbox_params=bbox_params,
            ),
            *self.cast_operators(),
        ]

    def cast_operators(self):
        if self.use_bbox:
            # Detection path: identity normalization to [0, 1] only,
            # presumably because the downstream detector normalizes
            # internally.
            normalize = A.Normalize(
                mean=(0.0, 0.0, 0.0),
                std=(1.0, 1.0, 1.0),
                max_pixel_value=255,
            )
        else:
            # Segmentation path: standard ImageNet statistics.
            normalize = A.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            )
        return [normalize, ToTensorV2()]

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
            collate_fn=bbox_collate_fn if self.use_bbox else None,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
            collate_fn=bbox_collate_fn if self.use_bbox else None,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
            collate_fn=bbox_collate_fn if self.use_bbox else None,
        )


def bbox_collate_fn(batch):
    """Collate samples into a torchvision-style detection batch."""
    images = []
    targets = []

    for b in batch:
        images.append(b["image"])
        target = dict()

        # Detection heads expect float boxes; guard the empty case so the
        # tensor keeps its (N, 4) shape.
        target["boxes"] = torch.from_numpy(b["bboxes"]).float()
        if target["boxes"].shape[0] == 0:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
        # (H, W, N) -> (N, H, W): one binary mask per fiber.
        target["masks"] = b["mask"].permute(2, 0, 1)
        # Single foreground class: every box is labelled 1. An image with no
        # boxes gets an empty label tensor so labels and boxes stay aligned.
        target["labels"] = torch.ones(
            target["boxes"].shape[0], dtype=torch.int64
        )

        targets.append(target)

    return {
        "image": torch.stack(images),
        "targets": targets,
    }
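

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than a training entry point. It
    # assumes a dataset laid out as <root>/{train,test}/{images,annotations}/,
    # which is what setup() expects; "data/fibers" below is a placeholder
    # path, not part of the module.
    dm = FiberDatamodule(
        root_img="data/fibers",
        crop_size=(256, 256),
        batch_size=2,
        num_workers=2,  # must be > 0 because persistent_workers is enabled
        use_bbox=True,
    )
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch["image"].shape)  # (B, 3, H, W) float tensor
    for t in batch["targets"]:
        # One dict per image: boxes (N, 4), masks (N, H, W), labels (N,).
        print(t["boxes"].shape, t["masks"].shape, t["labels"].shape)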