import os
import numpy as np
import torch
import gradio
import functools
from reloc3r.utils.image import load_images, check_images_shape_format
from reloc3r.reloc3r_relpose import setup_reloc3r_relpose_model, inference_relpose
from reloc3r.utils.device import to_numpy
import cv2
import trimesh
import PIL
from scipy.spatial.transform import Rotation


# from dust3r: OpenGL camera convention (flip y and z axes)
OPENGL = np.array([[1, 0, 0, 0],
                   [0, -1, 0, 0],
                   [0, 0, -1, 0],
                   [0, 0, 0, 1]])


# func from dust3r
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a homography)
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
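# Example (illustrative): with a 4x4 camera-to-world pose, geotrf maps
# camera-frame points to world coordinates, e.g.
#   geotrf(pose_c2w, np.zeros((1, 3)))  # -> the camera center, pose_c2w[:3, 3]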
# func from dust3r
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None,
                  imsize=None, screen_width=0.11, marker=None):
    if image is not None:
        image = np.asarray(image)
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255 * image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        H = W = focal / 1.1
    else:
        H = W = 1

    if isinstance(focal, np.ndarray):
        focal = focal[0]
    if not focal:
        focal = min(H, W) * 1.1  # default value

    # create fake camera
    height = max(screen_width / 10, focal * screen_width / H)
    width = screen_width * 0.5**0.5
    rot45 = np.eye(4)
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W / H
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)

    # this is the image
    if image is not None:
        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
        img = trimesh.Trimesh(vertices=vertices, faces=faces)
        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
        scene.add_geometry(img)

    # this is the camera mesh
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
    vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        if 0 in face:
            continue
        a, b, c = face
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2 * len(cam.vertices)

        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # no culling
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)

    if marker == 'o':
        marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
        marker.vertices += pose_c2w[:3, 3]
        marker.visual.face_colors[:, :3] = edge_color
        scene.add_geometry(marker)


# save relpose to .glb file
def vis_pose2to1(pose2to1, images):
    poses = [np.identity(4), pose2to1]
    colors = [(255, 0, 0), (0, 0, 255)]  # camera 1 in red, camera 2 in blue
    scene = trimesh.Scene()

    # add each camera
    for i, pose_c2w in enumerate(poses):
        camera_edge_color = colors[i]
        add_scene_cam(scene, pose_c2w, camera_edge_color, images[i])

    # coord transform for vis
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(poses[0] @ OPENGL @ rot))

    path = '_tmp_vis/scene.glb'
    scene.export(file_obj=path)
    print('Scene saved to', path)
    return path


# draw matches in images
def vis_ca_match(img1, img2, ca2to1_path, topkq=5, topkk=4):
    img1 = np.ascontiguousarray(img1)
    img2 = np.ascontiguousarray(img2)
    attn_map = np.loadtxt(ca2to1_path)

    h1, w1, _ = img1.shape
    h2, w2, _ = img2.shape
    assert w1 == w2
    hp1, wp1 = h1 // 16, w1 // 16  # patch grid of image 1 (16x16 patches)
    hp2, wp2 = h2 // 16, w2 // 16  # patch grid of image 2

    # stack the two images vertically and wash them out for drawing
    vis = np.concatenate((img1, img2), axis=0)
    alpha = 0.5
    overlay = vis.copy()
    overlay[:, :] = (255, 255, 255)
    cv2.addWeighted(overlay, alpha, vis, 1 - alpha, 0, vis)
    cv2.rectangle(vis, (1, 1), (w1 - 2, h1 - 2),
                  color=(255, 0, 0, 255), thickness=1, lineType=cv2.LINE_AA)
    cv2.rectangle(vis, (1, h1 + 1), (w2 - 2, h1 + h2 - 2),
                  color=(0, 0, 255, 255), thickness=1, lineType=cv2.LINE_AA)

    colors = [(245, 67, 62, 255),
              (93, 141, 253, 255),
              (94, 128, 64, 255),
              (245, 168, 61, 255),
              (0, 0, 0, 255)]

    def find_top_k_indices(arr, k):
        sorted_indices = np.argsort(arr)
        top_k_indices = sorted_indices[-k:]
        return top_k_indices

    # select the topkq query patches with the strongest mean top-k responses,
    # discarding queries whose matches touch the image border
    response_list = []
    for id_v2 in range(attn_map.shape[0]):
        out_of_bound = False
        topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
        response = attn_map[id_v2][topk_list].mean()
        response_list.append(response)
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            if x_v1 == 0 or x_v1 == w1 - 16:
                out_of_bound = True
            if y_v1 == 0 or y_v1 == h1 - 16:
                out_of_bound = True
        if out_of_bound:
            response_list[-1] = 0
    top_match_ids = np.argsort(response_list)[-topkq:]

    # draw the responses as matches
    for i in range(len(top_match_ids)):
        id_v2 = top_match_ids[i]
        color = colors[i] if i < len(colors) else colors[-1]

        # query
        y_v2 = (id_v2 // wp2) * 16
        x_v2 = (id_v2 % wp2) * 16
        cv2.rectangle(vis, (x_v2, y_v2 + h1), (x_v2 + 15, y_v2 + h1 + 15),
                      color=color, thickness=1, lineType=cv2.LINE_AA)

        # keys
        topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            cv2.rectangle(vis, (x_v1, y_v1), (x_v1 + 15, y_v1 + 15),
                          color=color, thickness=1, lineType=cv2.LINE_AA)

        # lines
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            cv2.line(vis, (x_v1 + 7, y_v1 + 7), (x_v2 + 7, y_v2 + h1 + 7),
                     color=color, thickness=1, lineType=cv2.LINE_AA)

    return vis
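# Note (inferred from the indexing above): the attention file is expected to
# hold one row per 16x16 query patch of image 2 and one column per key patch
# of image 1, i.e. a (h2//16 * w2//16, h1//16 * w1//16) matrix, written e.g. as
#   np.savetxt(path, attn.detach().cpu().numpy())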
# run the whole process
def run_reloc3r_rpr(reloc3r_relpose, img_reso, device, imgs):
    if len(imgs) != 2:
        print('Expected exactly 2 images, got {}; running with the first 2...'.format(len(imgs)))

    # load images
    print('Loading images...')
    images = load_images(imgs[0:2], size=int(img_reso))
    images = check_images_shape_format(images, device)
    # undo the [-1, 1] normalization for visualization
    img1 = ((images[0]['img'].detach().cpu().numpy().squeeze().transpose(1, 2, 0) + 1) / 2 * 255).astype(np.uint8)
    img2 = ((images[1]['img'].detach().cpu().numpy().squeeze().transpose(1, 2, 0) + 1) / 2 * 255).astype(np.uint8)

    # estimate relpose
    print('Running relative pose estimation...')
    batch = [images[0], images[1]]
    pose2to1 = to_numpy(inference_relpose(batch, reloc3r_relpose, device)[0])
    pose2to1[0:3, 3] = pose2to1[0:3, 3] / np.linalg.norm(pose2to1[0:3, 3])  # normalize the scale to 1 meter
    pose_vis = vis_pose2to1(pose2to1, [img1, img2])
    path = '_tmp_vis/pose2to1.txt'
    np.savetxt(path, pose2to1)
    print('Pose saved to', path)

    # patch matches from cross-attn
    print('Visualizing patch matches...')
    block_id = 5
    head_id = 0
    match_vis = vis_ca_match(img1, img2, '_tmp_vis/_ca_block{}_head{}.txt'.format(block_id, head_id))
    path = '_tmp_vis/match.png'
    cv2.imwrite(path, match_vis[:, :, [2, 1, 0]])  # RGB -> BGR for OpenCV
    print('Match visualization saved to', path)

    return pose_vis, [match_vis]
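# Example (illustrative): the saved pose can be read back and split into
# rotation and translation; the convention is camera-2-to-camera-1 with the
# translation normalized to unit length above.
#   pose2to1 = np.loadtxt('_tmp_vis/pose2to1.txt')
#   R, t = pose2to1[:3, :3], pose2to1[:3, 3]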

# gradio interface
def main_demo(reloc3r_relpose, device, img_reso, server_name, server_port):
    run = functools.partial(run_reloc3r_rpr, reloc3r_relpose, img_reso, device)

    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
                       title="Reloc3r relative camera pose demo") as demo:
        gradio.HTML('<h2 style="text-align: center;">Reloc3r relative camera pose demo</h2>')

        # components
        with gradio.Row():
            with gradio.Column():
                inputfiles = gradio.File(file_count="multiple", file_types=["image"],
                                         scale=2, height=200, label="Upload a pair of images")
                run_btn = gradio.Button("Run")
            outmodel = gradio.Model3D(camera_position=(-90, 45, 2), height=400,
                                      label="Camera poses")
            outgallery = gradio.Gallery(preview=True, height=400,
                                        label="Cross-attention responses (top-5 queries with top-4 keys)")

        # events
        run_btn.click(fn=run, inputs=[inputfiles], outputs=[outmodel, outgallery])

    # demo.launch(share=False, server_name=server_name, server_port=server_port)
    demo.launch()


if __name__ == '__main__':
    print('Note: This demo runs slowly because it operates on CPU and saves intermediate data.')

    os.environ["GRADIO_TEMP_DIR"] = '_tmp_gradio'
    if not os.path.exists('_tmp_gradio'):
        os.mkdir('_tmp_gradio')
    if not os.path.exists('_tmp_vis'):
        os.mkdir('_tmp_vis')

    server_name = '127.0.0.1'
    server_port = 7867
    img_reso = '512'
    device = 'cpu'
    device = torch.device(device)

    print('Loading Reloc3r-512 RPR model...')
    reloc3r_relpose = setup_reloc3r_relpose_model(model_args=img_reso, device=device)

    main_demo(reloc3r_relpose, device, img_reso, server_name, server_port)
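# Usage (illustrative; the filename is assumed): run `python demo.py` and open
# the local URL Gradio prints, then upload two overlapping images. Since the
# parameterized `demo.launch(...)` call above is commented out, Gradio's
# defaults apply (typically http://127.0.0.1:7860), and `server_name` /
# `server_port` are currently unused.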