import os
import numpy as np
import torch
import gradio
import functools
from reloc3r.utils.image import load_images, check_images_shape_format
from reloc3r.reloc3r_relpose import setup_reloc3r_relpose_model, inference_relpose
from reloc3r.utils.device import to_numpy
import cv2
import trimesh
import PIL
from scipy.spatial.transform import Rotation
# from dust3r: OpenGL camera convention (flips y and z)
OPENGL = np.array([[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]])
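# note: camera-to-world poses below are right-multiplied by OPENGL
# (pose_c2w @ OPENGL) to match the OpenGL-style camera convention
# used by the exported .glb scene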
# func from dust3r
def geotrf(Trf, pts, ncol=None, norm=False):
""" Apply a geometric transformation to a list of 3-D points.
H: 3x3 or 4x4 projection matrix (typically a Homography)
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
ncol: int. number of columns of the result (2 or 3)
norm: float. if != 0, the resut is projected on the z=norm plane.
Returns an array of projected 2d points.
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
Trf.ndim == 3 and pts.ndim == 4):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d + 1:
pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
else:
raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
else:
if Trf.ndim >= 3:
n = Trf.ndim - 2
assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1] + 1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
        pts = pts / pts[..., -1:]  # don't use in-place /= here (PyTorch aliasing issue)
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
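# usage sketch for geotrf (shapes are illustrative, not from the original source):
#   Trf = torch.eye(4)[None]          # (B, 4, 4) rigid transform
#   pts = torch.rand(1, 32, 32, 3)    # (B, H, W, 3) point map
#   out = geotrf(Trf, pts)            # (B, 32, 32, 3) transformed points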
# func from dust3r
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.11, marker=None):
if image is not None:
image = np.asarray(image)
H, W, THREE = image.shape
assert THREE == 3
if image.dtype != np.uint8:
image = np.uint8(255*image)
elif imsize is not None:
W, H = imsize
elif focal is not None:
H = W = focal / 1.1
else:
H = W = 1
if isinstance(focal, np.ndarray):
focal = focal[0]
if not focal:
focal = min(H,W) * 1.1 # default value
# create fake camera
    height = max(screen_width / 10, focal * screen_width / H)
width = screen_width * 0.5**0.5
rot45 = np.eye(4)
rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
rot45[2, 3] = -height # set the tip of the cone = optical center
aspect_ratio = np.eye(4)
aspect_ratio[0, 0] = W/H
transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
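    # the 4-sided cone serves as the camera frustum: rot45 above places its tip
    # at the optical center and its square base becomes the image plane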
# this is the image
if image is not None:
vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
img = trimesh.Trimesh(vertices=vertices, faces=faces)
uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
scene.add_geometry(img)
# this is the camera mesh
rot2 = np.eye(4)
rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
vertices = geotrf(transform, vertices)
faces = []
for face in cam.faces:
if 0 in face:
continue
a, b, c = face
a2, b2, c2 = face + len(cam.vertices)
a3, b3, c3 = face + 2*len(cam.vertices)
# add 3 pseudo-edges
faces.append((a, b, b2))
faces.append((a, a2, c))
faces.append((c2, b, c))
faces.append((a, b, b3))
faces.append((a, a3, c))
faces.append((c3, b, c))
# no culling
faces += [(c, b, a) for a, b, c in faces]
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
cam.visual.face_colors[:, :3] = edge_color
scene.add_geometry(cam)
if marker == 'o':
marker = trimesh.creation.icosphere(3, radius=screen_width/4)
marker.vertices += pose_c2w[:3,3]
marker.visual.face_colors[:,:3] = edge_color
scene.add_geometry(marker)
# save relpose to .glb file
def vis_pose2to1(pose2to1, images):
poses = [np.identity(4), pose2to1]
colors = [(255, 0, 0), (0, 0, 255)]
scene = trimesh.Scene()
# add each camera
for i, pose_c2w in enumerate(poses):
camera_edge_color = colors[i]
add_scene_cam(scene, pose_c2w, camera_edge_color, images[i])
# coord transform for vis
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
scene.apply_transform(np.linalg.inv(poses[0] @ OPENGL @ rot))
path = '_tmp_vis/scene.glb'
scene.export(file_obj=path)
print('Scene saved to', path)
return path
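# sketch: the exported scene can be previewed offline via
#   trimesh.load('_tmp_vis/scene.glb').show()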
# draw matches in images
def vis_ca_match(img1, img2, ca2to1_path, topkq=5, topkk=4):
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
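    # assumed file layout: a (hp2*wp2, hp1*wp1) cross-attention matrix,
    # one row per query patch of img2, one column per key patch of img1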
attn_map = np.loadtxt(ca2to1_path)
h1, w1, _ = img1.shape
h2, w2, _ = img2.shape
assert w1 == w2
hp1, wp1 = h1//16, w1//16
hp2, wp2 = h2//16, w2//16
vis = np.concatenate((img1, img2), axis=0)
alpha = 0.5
overlay = vis.copy()
overlay[:,:] = (255, 255, 255)
cv2.addWeighted(overlay, alpha, vis, 1 - alpha, 0, vis)
cv2.rectangle(vis, (1, 1), (w1-2, h1-2), color=(255, 0, 0, 255), thickness=1, lineType=cv2.LINE_AA)
cv2.rectangle(vis, (1, h1+1), (w2-2, h1+h2-2), color=(0, 0, 255, 255), thickness=1, lineType=cv2.LINE_AA)
colors = [(245, 67, 62, 255), (93, 141, 253, 255), (94, 128, 64, 255), (245, 168, 61, 255), (0, 0, 0, 255)]
def find_top_k_indices(arr, k):
sorted_indices = np.argsort(arr)
top_k_indices = sorted_indices[-k:]
return top_k_indices
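    # note: np.argsort sorts ascending, so top_k_indices runs from the k-th
    # largest response up to the largest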
# select topkq responses
response_list = []
for id_v2 in range(attn_map.shape[0]):
out_of_bound = False
topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
response = attn_map[id_v2][topk_list].mean()
response_list.append(response)
for id_v1 in topk_list:
y_v1 = (id_v1//wp1)*16
x_v1 = (id_v1%wp1)*16
            if x_v1 == 0 or x_v1 == w1 - 16:
                out_of_bound = True
            if y_v1 == 0 or y_v1 == h1 - 16:
                out_of_bound = True
        if out_of_bound:
            response_list[-1] = 0  # suppress queries whose top keys touch the image border
top_match_ids = np.argsort(response_list)[-topkq:]
# draw the responses as matches
for i in range(len(top_match_ids)):
id_v2 = top_match_ids[i]
color = colors[i] if i < len(colors) else colors[-1]
# query
y_v2 = (id_v2//wp2)*16
x_v2 = (id_v2%wp2)*16
cv2.rectangle(vis, (x_v2, y_v2+h1), (x_v2+15, y_v2+h1+15), color=color, thickness=1, lineType=cv2.LINE_AA)
# keys
topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
for id_v1 in topk_list:
y_v1 = (id_v1//wp1)*16
x_v1 = (id_v1%wp1)*16
cv2.rectangle(vis, (x_v1, y_v1), (x_v1+15, y_v1+15), color=color, thickness=1, lineType=cv2.LINE_AA)
# lines
for id_v1 in topk_list:
y_v1 = (id_v1//wp1)*16
x_v1 = (id_v1%wp1)*16
cv2.line(vis, (x_v1+7, y_v1+7), (x_v2+7, y_v2+h1+7), color=color, thickness=1, lineType=cv2.LINE_AA)
return vis
# full pipeline: load an image pair, estimate relative pose, visualize matches
def run_reloc3r_rpr(reloc3r_relpose, img_reso, device, imgs):
    if len(imgs) < 2:
        raise gradio.Error('Please upload at least 2 images.')
    if len(imgs) > 2:
        print('More than 2 images uploaded, running with the first 2 images...')
# load images
print('Loading images...')
images = load_images(imgs[0:2], size=int(img_reso))
images = check_images_shape_format(images, device)
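    # un-normalize from [-1, 1] back to uint8 RGB for visualization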
img1 = ((images[0]['img'].detach().cpu().numpy().squeeze().transpose(1,2,0) + 1) / 2 * 255).astype(np.uint8)
img2 = ((images[1]['img'].detach().cpu().numpy().squeeze().transpose(1,2,0) + 1) / 2 * 255).astype(np.uint8)
# estimate relpose
print('Running relative pose estimation...')
batch = [images[0], images[1]]
pose2to1 = to_numpy(inference_relpose(batch, reloc3r_relpose, device)[0])
pose2to1[0:3,3] = pose2to1[0:3,3] / np.linalg.norm(pose2to1[0:3,3]) # normalize the scale to 1 meter
pose_vis = vis_pose2to1(pose2to1, [img1, img2])
path = '_tmp_vis/pose2to1.txt'
np.savetxt(path, pose2to1)
print('Pose saved to', path)
# patch matches from cross-attn
print('Visualizing patch matches...')
block_id = 5
head_id = 0
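    # assumption: inference_relpose dumps per-block / per-head cross-attention
    # maps to _tmp_vis/ as text files; block 5, head 0 is used for visualization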
match_vis = vis_ca_match(img1, img2, '_tmp_vis/_ca_block{}_head{}.txt'.format(block_id, head_id))
path = '_tmp_vis/match.png'
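    # match_vis is RGB; reorder channels to BGR for cv2.imwrite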
cv2.imwrite(path, match_vis[:,:,[2,1,0]])
print('Match visualization saved to', path)
return pose_vis, [match_vis]
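# direct-call sketch, bypassing the UI (file paths are hypothetical):
#   pose_glb, gallery = run_reloc3r_rpr(reloc3r_relpose, '512',
#                                       torch.device('cpu'),
#                                       ['view1.jpg', 'view2.jpg'])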
# gradio interface
def main_demo(reloc3r_relpose, device, img_reso, server_name, server_port):
run = functools.partial(run_reloc3r_rpr, reloc3r_relpose, img_reso, device)
with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="Reloc3r relative camera pose demo") as demo:
gradio.HTML('<h2 style="text-align: center;">Reloc3r relative camera pose demo</h2>')
# components
with gradio.Row():
with gradio.Column():
inputfiles = gradio.File(file_count="multiple", file_types=["image"],
scale=2,
height=200,
label="Upload a pair of images")
run_btn = gradio.Button("Run")
outmodel = gradio.Model3D(camera_position=(-90, 45, 2),
height=400,
label="Camera poses")
outgallery = gradio.Gallery(preview=True,
height=400,
label="Cross-attention responses (top-5 queries with top-4 keys)")
# events
run_btn.click(fn=run,
inputs=[inputfiles],
outputs=[outmodel, outgallery])
    # demo.launch(share=False, server_name=server_name, server_port=server_port)
    demo.launch()  # note: server_name / server_port are currently unused
if __name__ == '__main__':
print('Note: This demo runs slowly because it operates on CPU and saves intermediate data.')
os.environ["GRADIO_TEMP_DIR"] = '_tmp_gradio'
    os.makedirs('_tmp_gradio', exist_ok=True)
    os.makedirs('_tmp_vis', exist_ok=True)
server_name = '127.0.0.1'
server_port = 7867
img_reso = '512'
    device = torch.device('cpu')
print('Loading Reloc3r-512 RPR model...')
reloc3r_relpose = setup_reloc3r_relpose_model(model_args=img_reso, device=device)
main_demo(reloc3r_relpose, device, img_reso, server_name, server_port)