# Snapshot: commit f4ba42f ("add dependencies") by lch01.
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import numpy as np
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv, geotrf
def reproject_view(pts3d, view2):
    """Project a point map into the pixel grid of another view.

    Args:
        pts3d: point map forwarded unchanged to ``reproject``.
        view2: mapping that provides ``"pts3d"``, ``"camera_intrinsics"`` and
            ``"camera_pose"`` for the target view. Only the first two
            dimensions of ``view2["pts3d"]`` are used, as the target grid size.

    Returns:
        Whatever ``reproject`` returns for these arguments.
    """
    target_hw = view2["pts3d"].shape[:2]
    intrinsics = view2["camera_intrinsics"]
    # ``reproject`` expects a world->camera transform, so the stored pose is
    # inverted here (presumably camera_pose is camera->world — confirm).
    world_to_cam = inv(view2["camera_pose"])
    return reproject(pts3d, intrinsics, world_to_cam, target_hw)
def reproject(pts3d, K, world2cam, shape):
    """Project an (H, W, 3) point map through a camera and quantize to pixels.

    Args:
        pts3d: (H, W, 3) array of 3D points (the trailing 3 is asserted).
        K: camera intrinsics, composed with the first three rows of
            ``world2cam`` into a single projection.
        world2cam: transform mapping points into the camera frame.
        shape: (H2, W2) of the target pixel grid used for quantization.

    Returns:
        ``((H, W), flat_idx)``. ``(H, W)`` is the *source* grid shape.
        ``flat_idx`` is ``ravel_xy`` applied to the projected coordinates with
        the *target* shape.
    """
    H, W, THREE = pts3d.shape
    assert THREE == 3

    # Degenerate points (e.g. on the camera plane) may produce inf/nan during
    # projection; those warnings are silenced and ravel_xy clips the result.
    with np.errstate(divide="ignore", invalid="ignore"):
        projection = K @ world2cam[:3]
        pixel_xy = geotrf(projection, pts3d, norm=1, ncol=2)

    flat_idx = ravel_xy(pixel_xy, shape)
    return (H, W), flat_idx
def ravel_xy(pos, shape):
    """Quantize continuous (x, y) pixel coordinates to flat indices ``x + W * y``.

    Args:
        pos: array whose trailing dimension holds (x, y); reshaped to (-1, 2).
        shape: (H, W) of the pixel grid.

    Returns:
        1-D int32 array of flat indices. Coordinates are rounded to the nearest
        integer and clipped into ``[0, W-1] x [0, H-1]``, so points outside the
        image land on the border.
    """
    H, W = shape
    # Casting NaN to int would emit an "invalid value" warning; silence it.
    with np.errstate(invalid="ignore"):
        xy_int = pos.reshape(-1, 2).round().astype(np.int32)
    qx = xy_int[:, 0]
    qy = xy_int[:, 1]
    np.clip(qx, 0, W - 1, out=qx)
    np.clip(qy, 0, H - 1, out=qy)
    return qx + W * qy
def unravel_xy(pos, shape):
    """Convert flat indices ``x + W * y`` back to (x, y) pixel coordinates.

    This is the inverse of ``ravel_xy`` for in-range indices.

    Args:
        pos: integer flat index, or integer array of flat indices, into an
            (H, W) grid.
        shape: (H, W) of the pixel grid.

    Returns:
        Integer array of shape ``np.shape(pos) + (2,)`` whose last axis is
        (x, y). A 1-D ``pos`` of length N gives an (N, 2) array.
    """
    # np.unravel_index returns (row, col) == (y, x). Stack them explicitly as
    # (x, y) instead of reaching into the outputs' ``.base``, which is an
    # undocumented NumPy layout detail and breaks for scalar / N-D ``pos``.
    y, x = np.unravel_index(pos, shape)
    return np.stack([x, y], axis=-1)
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Keep only the mutually consistent entries of two index mappings.

    Entry ``i`` is reciprocal when ``corres_2_to_1[corres_1_to_2[i]] == i``,
    i.e. going 1 -> 2 -> 1 returns to the starting index.

    Args:
        corres_1_to_2: integer array mapping each index of set 1 to set 2.
        corres_2_to_1: integer array mapping each index of set 2 to set 1.
        ret_recip: if True, also return the boolean reciprocity mask.

    Returns:
        ``(pos1, pos2)``: the reciprocal indices in set 1 and their matches in
        set 2. With ``ret_recip``, ``(mask, pos1, pos2)`` instead.
    """
    round_trip = corres_2_to_1[corres_1_to_2]
    identity = np.arange(len(corres_1_to_2))
    is_reciprocal1 = round_trip == identity
    pos1 = np.nonzero(is_reciprocal1)[0]
    pos2 = corres_1_to_2[pos1]
    return (is_reciprocal1, pos1, pos2) if ret_recip else (pos1, pos2)
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Find pixel correspondences between two views from their 3D point maps.

    Every pixel's 3D point in each view is projected into the other view
    (nearest pixel, clipped to the image). A pixel pair is a positive
    correspondence when the projection is reciprocal: 1 -> 2 -> 1 returns to
    the same pixel.

    Args:
        view1, view2: view mappings with ``"pts3d"``, ``"camera_intrinsics"``
            and ``"camera_pose"``; converted with ``to_numpy`` first.
        target_n_corres: exact number of pairs to return, or ``None`` to
            return every reciprocal pair.
        rng: random source providing ``permutation`` and ``choice``
            (default: the ``np.random`` module).
        ret_xy: if True, return (x, y) pixel coordinates; otherwise return
            flat indices ``x + W * y``.
        nneg: requested fraction of ``target_n_corres`` to fill with negative
            (non-reciprocal) pairs.

    Returns:
        If ``target_n_corres`` is None: ``(pos1, pos2)`` for all reciprocal
        pairs. Otherwise ``(pos1, pos2, valid)`` with exactly
        ``target_n_corres`` entries. Positives come first; ``valid`` is False
        for the appended negatives.

    Raises:
        AssertionError: if positives and negatives together cannot reach
            ``target_n_corres``.
    """
    view1, view2 = to_numpy((view1, view2))

    # For each pixel of one image: flat index of the pixel it lands on in the other.
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # pos1: reciprocal pixels of image1; pos2: their matching pixels in image2.
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    # Non-reciprocal pixels form the pool that negatives are drawn from.
    not_recip1 = ~is_reciprocal1
    not_recip2 = ~is_reciprocal2
    n_avail_neg1 = not_recip1.sum()
    n_avail_neg2 = not_recip2.sum()
    available_negatives = min(n_avail_neg1, n_avail_neg2)

    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)
    if n_negatives + n_positives != target_n_corres:
        # Should be rare: there are not enough negatives, so top up with extra
        # positives instead (this breaks the requested nneg ratio).
        n_positives = target_n_corres - n_negatives
    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= n_avail_neg1
    assert n_negatives <= n_avail_neg2
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # Random sub-sampling of the positive correspondences.
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # Draw distinct non-reciprocal pixels in each image, independently
        # (view 1 first, then view 2, so the RNG stream order is kept).
        neg1 = rng.choice(
            shape1[0] * shape1[1],
            size=n_negatives,
            replace=False,
            p=not_recip1 / not_recip1.sum(),
        )
        neg2 = rng.choice(
            shape2[0] * shape2[1],
            size=n_negatives,
            replace=False,
            p=not_recip2 / not_recip2.sum(),
        )
        pos1 = np.r_[pos1, neg1]
        pos2 = np.r_[pos2, neg2]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    if ret_xy:
        # Convert flat (x + W*y) indices back to 2D (x, y) coordinates.
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid