# Uploaded by ClementP ("Upload 55 files", commit 69591a9, verified).
from itertools import compress
from typing import Iterable, List, Sequence, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from scipy.sparse import csr_array
from scipy.sparse.csgraph import connected_components
from scipy.spatial.distance import cdist
from skimage.morphology import skeletonize

from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox
from dnafiber.postprocess.skan import find_endpoints, compute_points_angle
from dnafiber.postprocess.skan import find_line_intersection
# Colormap for the debug plot in refine_segmentation(show=True). Presumably
# intended to render skeleton values 0 / 1 / 2 as black / red / green.
cmlabel = ListedColormap(["black", "red", "green"])
# Thresholds used when resolving junctions in handle_multiple_fiber_in_cc.
# NOTE: despite the MIN_ prefix, MIN_ANGLE and MIN_BRANCH_DISTANCE act as upper
# bounds: two endpoints can only be linked when their angle difference is at
# most MIN_ANGLE and their distance is at most MIN_BRANCH_DISTANCE.
# Maximum angle difference (degrees) between two endpoints that may be linked.
MIN_ANGLE = 20
# Branches with fewer pixels than this are discarded after junction removal.
MIN_BRANCH_LENGTH = 10
# Maximum Euclidean distance (pixels) between two endpoints that may be linked.
MIN_BRANCH_DISTANCE = 30
def _erase_junctions(fiber, junctions_fiber):
    """Return a copy of ``fiber`` with the 3x3 window around each junction set to 0.

    The caller's array is left untouched. Window starts are clamped at 0: for a
    junction on row/column 0, ``y - 1`` would be ``-1`` and the slice ``-1:2``
    would select an empty or truncated range, leaving the junction in place.
    """
    erased = np.array(fiber, copy=True)
    for y, x in junctions_fiber:
        erased[max(y - 1, 0) : y + 2, max(x - 1, 0) : x + 2] = 0
    return erased


def _link_endpoints(fiber, endpoints, endpoints_ids):
    """Cluster endpoints that belong to the same physical fiber.

    Endpoints on the same branch are always linked. Each endpoint is also linked
    to its nearest other endpoint, provided that endpoint lies within
    ``MIN_BRANCH_DISTANCE`` pixels and their angles (``compute_points_angle``,
    converted to degrees) differ by at most ``MIN_ANGLE``.

    Returns:
        ``(labels, added_edges)``: ``labels[k]`` is the cluster index of endpoint
        ``k``. ``added_edges`` maps an endpoint index to the endpoint it was
        bridged to (only for endpoints that found a valid partner).
    """
    n_points = len(endpoints)
    angles = np.rad2deg(compute_points_angle(fiber, endpoints, steps=15))
    angle_diff = cdist(angles[:, None], angles[:, None], metric="cityblock")
    np.fill_diagonal(angle_diff, np.inf)

    distances = cdist(endpoints, endpoints, metric="euclidean")
    np.fill_diagonal(distances, np.inf)
    # Disqualify pairs that are too far apart or not aligned enough.
    distances[distances > MIN_BRANCH_DISTANCE] = np.inf
    distances[angle_diff > MIN_ANGLE] = np.inf
    nearest = np.argmin(distances, axis=1)
    nearest_dist = distances[np.arange(n_points), nearest]

    # Endpoints of the same branch are connected (this also sets the diagonal).
    adjacency = (endpoints_ids[:, None] == endpoints_ids[None, :]).astype(np.uint8)
    added_edges = {}
    for i, j in enumerate(nearest):
        # An all-inf row means "no valid partner"; argmin then returns 0 arbitrarily.
        if nearest_dist[i] < np.inf:
            added_edges[i] = int(j)
            adjacency[i, j] = 1
            adjacency[j, i] = 1

    _, labels = connected_components(
        csr_array(adjacency), directed=False, return_labels=True
    )
    return labels, added_edges


def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates):
    """Split a skeleton component containing junctions into separate fibers.

    The steps are:

    1. Erase the junction pixels and their 8-neighbourhood, which cuts the
       skeleton into branches.
    2. Drop branches with fewer than ``MIN_BRANCH_LENGTH`` pixels.
    3. Regroup the remaining branches by linking endpoints that are close
       (``MIN_BRANCH_DISTANCE``) and similarly oriented (``MIN_ANGLE``).

    Each group becomes one ``Fiber``. Gaps between linked endpoints are bridged
    with a 1-pixel line. The line takes the endpoints' shared pixel value, or 2
    when the two endpoint values differ.

    Args:
        fiber: 2D image of the component (0 = background, non-zero = skeleton
            pixel value). Not modified.
        junctions_fiber: ``(K, 2)`` array-like of junction ``(row, col)``
            positions in ``fiber`` coordinates.
        coordinates: position of ``fiber`` in the full image. Only
            ``coordinates[0]`` (x) and ``coordinates[1]`` (y) are used.

    Returns:
        A list of ``Fiber`` objects (possibly empty) whose bboxes are expressed
        in full-image coordinates.
    """
    junctions_fiber = np.asarray(junctions_fiber).reshape(-1, 2)
    fiber = _erase_junctions(fiber, junctions_fiber)

    endpoints = np.asarray(find_endpoints(fiber > 0))
    # Nothing can survive the filters below without endpoints or junctions.
    # Exit early, because empty arrays would also break the broadcasting below.
    if endpoints.size == 0 or len(junctions_fiber) == 0:
        return []
    endpoints = endpoints.reshape(-1, 2)

    # Keep only the endpoints lying within 5 px of some junction.
    dist_to_junctions = np.linalg.norm(
        endpoints[:, None, :] - junctions_fiber[None, :, :], axis=2
    )
    endpoints = endpoints[(dist_to_junctions < 5).any(axis=1)]

    # Label the branches. OpenCV expects an 8-bit single-channel image, so pass
    # a binary mask rather than ``fiber`` (whose dtype follows the caller's).
    _, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
        (fiber != 0).astype(np.uint8),
        connectivity=8,
        ccltype=cv2.CCL_BOLELLI,
        ltype=cv2.CV_16U,
    )
    branches_bboxes = branches_stats[
        :,
        [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
    ]
    # Label 0 is the background. CC_STAT_AREA is each branch's pixel count,
    # which avoids one full-image scan per branch.
    kept_labels = np.flatnonzero(
        branches_stats[:, cv2.CC_STAT_AREA] >= MIN_BRANCH_LENGTH
    )
    kept_labels = kept_labels[kept_labels > 0]

    rows, cols = endpoints[:, 0], endpoints[:, 1]
    endpoints_ids = branches[rows, cols]  # 0 if the endpoint is on background
    endpoints_color = fiber[rows, cols]

    # Drop endpoints on background or on discarded (too short) branches.
    keep = np.isin(endpoints_ids, kept_labels)
    if not keep.any():
        return []
    endpoints = endpoints[keep]
    endpoints_ids = endpoints_ids[keep]
    endpoints_color = endpoints_color[keep]

    clusters, added_edges = _link_endpoints(fiber, endpoints, endpoints_ids)

    results = []
    for cluster in np.unique(clusters):
        members = np.flatnonzero(clusters == cluster)
        branch_ids = np.unique(endpoints_ids[members])

        # Union of the bounding boxes of the cluster's branches.
        boxes = branches_bboxes[branch_ids]
        min_x = boxes[:, 0].min()
        min_y = boxes[:, 1].min()
        max_x = (boxes[:, 0] + boxes[:, 2]).max()
        max_y = (boxes[:, 1] + boxes[:, 3]).max()

        # Keep only this cluster's branch pixels inside the union bbox.
        in_cluster = np.isin(branches[min_y:max_y, min_x:max_x], branch_ids)
        new_fiber = in_cluster * fiber[min_y:max_y, min_x:max_x]

        # Bridge the gaps between linked endpoints (cv2 points are (x, y)).
        for a in members:
            b = added_edges.get(int(a))
            if b is None:
                continue
            point_a = (
                int(endpoints[a, 1]) - int(min_x),
                int(endpoints[a, 0]) - int(min_y),
            )
            point_b = (
                int(endpoints[b, 1]) - int(min_x),
                int(endpoints[b, 0]) - int(min_y),
            )
            col_a, col_b = endpoints_color[a], endpoints_color[b]
            new_fiber = cv2.line(
                new_fiber,
                point_a,
                point_b,
                color=2 if col_a != col_b else int(col_a),
                thickness=1,
            )

        # Express the bbox in the original image.
        bbox = Bbox(
            x=coordinates[0] + min_x,
            y=coordinates[1] + min_y,
            width=max_x - min_x,
            height=max_y - min_y,
        )
        results.append(Fiber(bbox=bbox, data=new_fiber))
    return results
def handle_ccs_with_junctions(
    ccs: Iterable[np.ndarray],
    junctions: Iterable[np.ndarray],
    coordinates: Iterable[Sequence[int]],
) -> List["Fiber"]:
    """
    Split every connected component that contains junctions into fibers.

    The three iterables are consumed in lockstep with ``zip``, so iteration
    stops at the shortest one. Each component is passed to
    ``handle_multiple_fiber_in_cc``, and the resulting fibers are concatenated
    in input order.

    Args:
        ccs: cropped skeleton images, one per connected component.
        junctions: for each component, its junction positions as ``(row, col)``
            pairs in the component's local (cropped) coordinates.
        coordinates: for each component, its ``(x, y, width, height)`` box in
            the full image, as built by ``refine_segmentation``. Only ``x`` and
            ``y`` (the top-left corner) are used, to map the resulting fibers
            back to full-image coordinates.

    Returns:
        The fibers produced for all components, in input order.
    """
    fibers = []
    for cc, cc_junctions, coordinate in zip(ccs, junctions, coordinates):
        fibers.extend(handle_multiple_fiber_in_cc(cc, cc_junctions, coordinate))
    return fibers
def refine_segmentation(segmentation, fix_junctions=True, show=False):
    """Split a fiber segmentation into individual ``FiberProps``.

    The steps are:

    1. Skeletonize the foreground (``segmentation > 0``).
    2. Give each skeleton pixel the segmentation value found under it.
    3. Turn every 8-connected skeleton component into a candidate fiber,
       cropped to its bounding box.

    With ``fix_junctions``, components containing line intersections are
    resolved by ``handle_ccs_with_junctions``. Otherwise every component is
    kept as-is.

    Args:
        segmentation: 2D label image; values > 0 are foreground.
        fix_junctions: whether to resolve components that contain junctions.
        show: if true, plot the labelled skeleton with its junctions and call
            ``plt.show()``.

    Returns:
        ``FiberProps`` numbered consecutively. When ``fix_junctions`` is true,
        junction-free components come first, followed by the fibers produced
        from components with junctions.
    """
    # Binarize explicitly so the skeleton is 0/1 whatever foreground value
    # ``skeletonize`` uses. NOTE(review): some scikit-image releases returned
    # 0/255 for method="lee", and ``255 * segmentation`` would overflow uint8.
    skeleton = (skeletonize(segmentation > 0, method="lee") > 0).astype(np.uint8)
    skeleton_gt = skeleton * segmentation

    n_labels, labels, stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
        skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
    )
    bboxes = stats[
        :,
        [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
    ]

    local_fibers, coordinates, junctions = [], [], []
    for label in range(1, n_labels):  # label 0 is the background
        x1, y1, w, h = bboxes[label]
        window = (slice(y1, y1 + h), slice(x1, x1 + w))
        # Mask by label so other components overlapping the bbox are excluded.
        local_fiber = skeleton_gt[window] * (labels[window] == label).astype(np.uint8)
        local_fibers.append(local_fiber)
        coordinates.append(np.asarray([x1, y1, w, h]))
        junction_mask = find_line_intersection(local_fiber > 0)
        # (K, 2) array of junction (row, col) positions, local coordinates.
        junctions.append(np.array(np.where(junction_mask)).transpose())

    if show:
        for (x, y, _w, _h), junction in zip(coordinates, junctions):
            junction_to_global = np.array(junction) + np.array([y, x])
            plt.scatter(
                junction_to_global[:, 1],
                junction_to_global[:, 0],
                color="white",
                s=30,
                alpha=0.35,
            )
        plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest")
        plt.axis("off")
        plt.xticks([])
        plt.yticks([])
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.show()

    def to_fiber(data, coordinate):
        # coordinate is the (x, y, width, height) vector built above.
        bbox = Bbox(
            x=coordinate[0],
            y=coordinate[1],
            width=coordinate[2],
            height=coordinate[3],
        )
        return Fiber(bbox=bbox, data=data)

    if fix_junctions:
        has_junctions = [len(j) > 0 for j in junctions]
        no_junctions = [not flag for flag in has_junctions]
        fibers = [
            to_fiber(data, coordinate)
            for data, coordinate in zip(
                compress(local_fibers, no_junctions),
                compress(coordinates, no_junctions),
            )
        ]
        fibers += handle_ccs_with_junctions(
            compress(local_fibers, has_junctions),
            compress(junctions, has_junctions),
            compress(coordinates, has_junctions),
        )
    else:
        fibers = [
            to_fiber(data, coordinate)
            for data, coordinate in zip(local_fibers, coordinates)
        ]

    return [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)]