|
import numpy as np
|
|
import cv2
|
|
from typing import List, Tuple
|
|
from dnafiber.postprocess.skan import find_endpoints, compute_points_angle
|
|
from scipy.spatial.distance import cdist
|
|
|
|
from scipy.sparse.csgraph import connected_components
|
|
from scipy.sparse import csr_array
|
|
from skimage.morphology import skeletonize
|
|
from dnafiber.postprocess.skan import find_line_intersection
|
|
from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox
|
|
from itertools import compress
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.colors import ListedColormap
|
|
|
|
# Three-colour map (black, red, green) used by ``refine_segmentation`` to show
# the labelled skeleton when ``show=True``; presumably meant for images whose
# pixel values are 0, 1 and 2 — confirm against the segmentation labels.
cmlabel = ListedColormap(colors=["black", "red", "green"])
|
|
|
|
# Largest allowed difference (degrees, after ``np.rad2deg``) between the angles
# that ``compute_points_angle`` reports for two endpoints; pairs above it are
# never re-linked. Despite the ``MIN_`` prefix it is used as an upper bound.
MIN_ANGLE: int = 20
# Branches (after junction removal) with fewer pixels than this are ignored.
MIN_BRANCH_LENGTH: int = 10
# Endpoints farther apart than this (Euclidean, in pixels) are never re-linked;
# like ``MIN_ANGLE`` it acts as an upper bound.
MIN_BRANCH_DISTANCE: int = 30
|
|
|
|
|
|
def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates):
    """Split one skeleton component that contains junctions into separate fibers.

    Every junction pixel and its 3x3 neighbourhood is erased, which cuts the
    component into branches. Endpoints lying within a few pixels of an erased
    junction are then re-linked pairwise to their nearest compatible partner:
    closer than ``MIN_BRANCH_DISTANCE`` and with an angle difference of at most
    ``MIN_ANGLE`` degrees. Branches shorter than ``MIN_BRANCH_LENGTH`` pixels
    are ignored. Each group of linked branches becomes one output fiber. The
    gaps are bridged with a one-pixel ``cv2.line``, drawn with the shared pixel
    value of both endpoints, or 2 when the two values differ.

    Args:
        fiber: 2D image of the component cropped to its bounding box. Non-zero
            pixels are skeleton pixels, and their values are carried into the
            output. The array is NOT modified; a copy is used internally.
        junctions_fiber: ``(row, col)`` junction positions (``(J, 2)``), in the
            same local frame as ``fiber``.
        coordinates: sequence whose first two items are the ``(x, y)`` offset of
            ``fiber`` inside the full image.

    Returns:
        A list of ``Fiber`` objects, with bounding boxes in full-image
        coordinates. The list is empty when no usable endpoint survives the
        filtering.
    """
    # Endpoints closer than this (pixels) to a junction were created by the
    # cut and may be re-linked; farther endpoints are left alone.
    junction_radius = 5
    # ``steps`` argument forwarded to ``compute_points_angle``.
    angle_steps = 15

    # Work on a copy so the caller's array is left untouched.
    fiber = np.array(fiber, copy=True)
    junctions_fiber = np.asarray(junctions_fiber).reshape(-1, 2)

    # Erase each junction with its 3x3 neighbourhood. Clamp the window start at
    # 0: a negative start index would wrap to the far edge and yield an empty
    # slice, silently leaving junctions on the top/left border in place.
    for y, x in junctions_fiber:
        fiber[max(y - 1, 0) : y + 2, max(x - 1, 0) : x + 2] = 0

    # reshape(-1, 2) keeps the broadcasting below valid when no endpoint exists.
    endpoints = np.asarray(find_endpoints(fiber > 0)).reshape(-1, 2)
    offsets = endpoints[:, None, :] - junctions_fiber[None, :, :]
    near_junction = (np.linalg.norm(offsets, axis=2) < junction_radius).any(axis=1)
    endpoints = endpoints[near_junction]
    if len(endpoints) == 0:
        return []

    _, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
        fiber, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
    )
    branches_bboxes = branches_stats[
        :,
        [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
    ]
    # Pixel count per branch label (label 0 is the background).
    branches_areas = branches_stats[:, cv2.CC_STAT_AREA]

    rows, cols = endpoints[:, 0], endpoints[:, 1]
    endpoints_ids = branches[rows, cols]
    endpoints_color = fiber[rows, cols].astype(np.uint8)

    # Keep only endpoints sitting on a branch that is long enough.
    keep = (endpoints_ids != 0) & (branches_areas[endpoints_ids] >= MIN_BRANCH_LENGTH)
    if not keep.any():
        return []
    endpoints = endpoints[keep]
    endpoints_ids = endpoints_ids[keep]
    endpoints_color = endpoints_color[keep]
    n_endpoints = len(endpoints)
    diagonal = np.arange(n_endpoints)

    angles = np.rad2deg(compute_points_angle(fiber, endpoints, steps=angle_steps))
    angles_diff = cdist(angles[:, None], angles[:, None], metric="cityblock")

    # Pairwise distances, with incompatible pairs (too far, too different in
    # angle, or an endpoint with itself) pushed to +inf.
    distances = cdist(endpoints, endpoints, metric="euclidean")
    distances[diagonal, diagonal] = np.inf
    distances[(distances > MIN_BRANCH_DISTANCE) | (angles_diff > MIN_ANGLE)] = np.inf

    best = np.argmin(distances, axis=1)
    linked = np.where(distances[diagonal, best] < np.inf)[0]
    partners = best[linked]
    added_edges = dict(zip(linked.tolist(), partners.tolist()))

    # Endpoints are connected when they lie on the same branch or were matched.
    adjacency = endpoints_ids[:, None] == endpoints_ids[None, :]
    adjacency[linked, partners] = True
    adjacency[partners, linked] = True
    n_groups, groups = connected_components(
        csr_array(adjacency.astype(np.uint8)), directed=False, return_labels=True
    )

    results = []
    for group in range(n_groups):
        members = np.where(groups == group)[0]
        group_ids = np.unique(endpoints_ids[members])

        group_bboxes = branches_bboxes[group_ids]
        min_x = np.min(group_bboxes[:, 0])
        min_y = np.min(group_bboxes[:, 1])
        max_x = np.max(group_bboxes[:, 0] + group_bboxes[:, 2])
        max_y = np.max(group_bboxes[:, 1] + group_bboxes[:, 3])

        window = (slice(min_y, max_y), slice(min_x, max_x))
        new_fiber = fiber[window] * np.isin(branches[window], group_ids)

        for a in members:
            b = added_edges.get(int(a))
            if b is None:
                continue
            # cv2 expects plain-int (x, y) points in the cropped frame.
            point_a = (int(endpoints[a, 1] - min_x), int(endpoints[a, 0] - min_y))
            point_b = (int(endpoints[b, 1] - min_x), int(endpoints[b, 0] - min_y))
            color_a, color_b = endpoints_color[a], endpoints_color[b]
            color = int(color_a) if color_a == color_b else 2
            new_fiber = cv2.line(new_fiber, point_a, point_b, color=color, thickness=1)

        bbox = Bbox(
            x=coordinates[0] + min_x,
            y=coordinates[1] + min_y,
            width=max_x - min_x,
            height=max_y - min_y,
        )
        results.append(Fiber(bbox=bbox, data=new_fiber))

    return results
|
|
|
|
|
|
def handle_ccs_with_junctions(
    ccs: List[np.ndarray],
    junctions: List[List[Tuple[int, int]]],
    coordinates: List[Tuple[int, int]],
):
    """Split every junction-bearing connected component into individual fibers.

    The three inputs are walked in lockstep, as with ``zip``, so iteration
    stops at the shortest one:

    - ``ccs``: the local (cropped) skeleton image of each component;
    - ``junctions``: the ``(row, col)`` junction positions of each component,
      expressed in that component's local frame;
    - ``coordinates``: the position of each component's top-left corner in the
      full image, read as ``(x, y, ...)``.

    Each triple is handed to ``handle_multiple_fiber_in_cc``. The fibers it
    returns are concatenated, in input order, into a single list.
    """
    return [
        piece
        for cc, cc_junctions, corner in zip(ccs, junctions, coordinates)
        for piece in handle_multiple_fiber_in_cc(cc, cc_junctions, corner)
    ]
|
|
|
|
|
|
def _fibers_from_crops(crops, boxes):
    """Wrap each crop in a ``Fiber`` whose box is read as ``(x, y, width, height)``."""
    fibers = []
    for crop, (left, top, width, height) in zip(crops, boxes):
        box = Bbox(x=left, y=top, width=width, height=height)
        fibers.append(Fiber(bbox=box, data=crop))
    return fibers


def refine_segmentation(segmentation, fix_junctions=True, show=False):
    """Skeletonize ``segmentation`` and return one ``FiberProps`` per fiber.

    The non-zero pixels of ``segmentation`` are thinned with skimage's Lee
    method. The resulting skeleton keeps the original segmentation values.
    Every 8-connected skeleton component is cropped to its bounding box, and
    its line intersections are located with ``find_line_intersection``.

    Args:
        segmentation: 2D label image; non-zero pixels are foreground. The values
            are presumably the labels drawn by ``cmlabel`` (0/1/2) — confirm.
        fix_junctions: when true, components without intersections are kept
            as-is. Components with intersections are split via
            ``handle_ccs_with_junctions`` and appended after them. When false,
            every component becomes exactly one fiber.
        show: when true, display the skeleton with matplotlib and overlay the
            detected junctions before building the fibers.

    Returns:
        A list of ``FiberProps``, with ``fiber_id`` numbered from 0 in list
        order.
    """
    skeleton = skeletonize(segmentation > 0, method="lee").astype(np.uint8)
    skeleton_gt = skeleton * segmentation
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
        skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
    )
    box_columns = [
        cv2.CC_STAT_LEFT,
        cv2.CC_STAT_TOP,
        cv2.CC_STAT_WIDTH,
        cv2.CC_STAT_HEIGHT,
    ]
    bboxes = stats[:, box_columns]

    local_fibers, coordinates, junctions = [], [], []
    # Label 0 is the background, so component labels start at 1.
    for label in range(1, n_labels):
        x1, y1, w, h = bboxes[label]
        rows = slice(y1, y1 + h)
        cols = slice(x1, x1 + w)
        own_pixels = (labels[rows, cols] == label).astype(np.uint8)
        # Mask out other components that overlap this bounding box.
        crop = skeleton_gt[rows, cols] * own_pixels
        local_fibers.append(crop)
        coordinates.append(np.asarray([x1, y1, w, h]))
        crossing_mask = find_line_intersection(crop > 0)
        junctions.append(np.array(np.where(crossing_mask)).transpose())

    if show:
        for (x, y, _w, _h), junction in zip(coordinates, junctions):
            # Shift local (row, col) junctions into full-image coordinates.
            global_junctions = np.array(junction) + np.array([y, x])
            plt.scatter(
                global_junctions[:, 1],
                global_junctions[:, 0],
                color="white",
                s=30,
                alpha=0.35,
            )
        plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest")
        plt.axis("off")
        plt.xticks([])
        plt.yticks([])
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.show()

    if not fix_junctions:
        fibers = _fibers_from_crops(local_fibers, coordinates)
    else:
        has_junctions = [len(j) > 0 for j in junctions]
        no_junctions = np.logical_not(has_junctions)
        fibers = _fibers_from_crops(
            compress(local_fibers, no_junctions),
            compress(coordinates, no_junctions),
        )
        fibers += handle_ccs_with_junctions(
            compress(local_fibers, has_junctions),
            compress(junctions, has_junctions),
            compress(coordinates, has_junctions),
        )

    return [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)]
|
|
|