import os
import numpy as np
import torch
import gradio
import functools
from reloc3r.utils.image import load_images, check_images_shape_format
from reloc3r.reloc3r_relpose import setup_reloc3r_relpose_model, inference_relpose
from reloc3r.utils.device import to_numpy
import cv2
import trimesh
import PIL
from scipy.spatial.transform import Rotation

# from dust3r: OpenGL camera convention (x right, y up, z backward),
# i.e. flip the y and z axes relative to the OpenCV convention
OPENGL = np.array([[1, 0, 0, 0],
                   [0, -1, 0, 0],
                   [0, 0, -1, 0],
                   [0, 0, 0, 1]])

# func from dust3r
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a homography)
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns the transformed points, reshaped to match the leading shape of `pts`.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {Trf.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
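
# Usage sketch (illustrative, not part of the demo): a batch of rigid 4x4
# transforms applied to per-pixel 3-D points takes the einsum fast path above;
# the tensor names here are assumptions for the example only.
#   Trf = torch.eye(4).expand(2, 4, 4)   # (B, 4, 4) poses
#   pts = torch.rand(2, 32, 32, 3)       # (B, H, W, 3) points
#   out = geotrf(Trf, pts)               # -> (B, 32, 32, 3)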

# func from dust3r
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.11, marker=None):
    if image is not None:
        image = np.asarray(image)
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255 * image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        H = W = focal / 1.1
    else:
        H = W = 1

    if isinstance(focal, np.ndarray):
        focal = focal[0]
    if not focal:
        focal = min(H, W) * 1.1  # default value
    # create fake camera
    height = max(screen_width / 10, focal * screen_width / H)
    width = screen_width * 0.5**0.5
    rot45 = np.eye(4)
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W / H
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)
    # this is the image
    if image is not None:
        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
        img = trimesh.Trimesh(vertices=vertices, faces=faces)
        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
        scene.add_geometry(img)

    # this is the camera mesh
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
    vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        if 0 in face:
            continue
        a, b, c = face
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2 * len(cam.vertices)

        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # no culling
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)

    if marker == 'o':
        marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
        marker.vertices += pose_c2w[:3, 3]
        marker.visual.face_colors[:, :3] = edge_color
        scene.add_geometry(marker)
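
# Usage sketch (illustrative): drop a red frustum at the origin of a fresh
# scene; the identity pose and focal value here are assumptions for the
# example, not demo defaults.
#   scene = trimesh.Scene()
#   add_scene_cam(scene, np.eye(4), edge_color=(255, 0, 0), focal=400.0)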

# save relpose to a .glb file
def vis_pose2to1(pose2to1, images):
    poses = [np.identity(4), pose2to1]
    colors = [(255, 0, 0), (0, 0, 255)]
    scene = trimesh.Scene()

    # add each camera
    for i, pose_c2w in enumerate(poses):
        camera_edge_color = colors[i]
        add_scene_cam(scene, pose_c2w, camera_edge_color, images[i])

    # coord transform for vis
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(poses[0] @ OPENGL @ rot))

    path = '_tmp_vis/scene.glb'
    scene.export(file_obj=path)
    print('Scene saved to', path)
    return path
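
# Note on conventions (as used above): camera 1 sits at the identity and
# pose2to1 is passed as the camera-to-world pose of camera 2 in camera 1's
# frame, so a point X2 in camera-2 coordinates maps to camera 1 via
#   X1 = pose2to1[:3, :3] @ X2 + pose2to1[:3, 3]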

# draw matches in images
def vis_ca_match(img1, img2, ca2to1_path, topkq=5, topkk=4):
    img1 = np.ascontiguousarray(img1)
    img2 = np.ascontiguousarray(img2)
    attn_map = np.loadtxt(ca2to1_path)

    h1, w1, _ = img1.shape
    h2, w2, _ = img2.shape
    assert w1 == w2
    hp1, wp1 = h1 // 16, w1 // 16  # 16x16-pixel patch grid
    hp2, wp2 = h2 // 16, w2 // 16

    # stack the two images vertically, slightly whitened, with colored borders
    vis = np.concatenate((img1, img2), axis=0)
    alpha = 0.5
    overlay = vis.copy()
    overlay[:, :] = (255, 255, 255)
    cv2.addWeighted(overlay, alpha, vis, 1 - alpha, 0, vis)
    cv2.rectangle(vis, (1, 1), (w1 - 2, h1 - 2), color=(255, 0, 0, 255), thickness=1, lineType=cv2.LINE_AA)
    cv2.rectangle(vis, (1, h1 + 1), (w2 - 2, h1 + h2 - 2), color=(0, 0, 255, 255), thickness=1, lineType=cv2.LINE_AA)
    colors = [(245, 67, 62, 255), (93, 141, 253, 255), (94, 128, 64, 255), (245, 168, 61, 255), (0, 0, 0, 255)]

    def find_top_k_indices(arr, k):
        sorted_indices = np.argsort(arr)
        top_k_indices = sorted_indices[-k:]
        return top_k_indices
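
    # e.g. find_top_k_indices(np.array([0.1, 0.9, 0.4]), k=2) -> array([2, 1]);
    # indices come back in ascending score order, with the argmax last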

    # select the topkq queries with the strongest responses; zero out queries
    # whose top keys touch the image border
    response_list = []
    for id_v2 in range(attn_map.shape[0]):
        out_of_bound = False
        topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
        response = attn_map[id_v2][topk_list].mean()
        response_list.append(response)
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            if x_v1 == 0 or x_v1 == w1 - 16:
                out_of_bound = True
            if y_v1 == 0 or y_v1 == h1 - 16:
                out_of_bound = True
        if out_of_bound:
            response_list[-1] = 0
            continue
    top_match_ids = np.argsort(response_list)[-topkq:]

    # draw the responses as matches
    for i in range(len(top_match_ids)):
        id_v2 = top_match_ids[i]
        color = colors[i] if i < len(colors) else colors[-1]

        # query
        y_v2 = (id_v2 // wp2) * 16
        x_v2 = (id_v2 % wp2) * 16
        cv2.rectangle(vis, (x_v2, y_v2 + h1), (x_v2 + 15, y_v2 + h1 + 15), color=color, thickness=1, lineType=cv2.LINE_AA)

        # keys
        topk_list = find_top_k_indices(attn_map[id_v2], k=topkk)
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            cv2.rectangle(vis, (x_v1, y_v1), (x_v1 + 15, y_v1 + 15), color=color, thickness=1, lineType=cv2.LINE_AA)

        # lines
        for id_v1 in topk_list:
            y_v1 = (id_v1 // wp1) * 16
            x_v1 = (id_v1 % wp1) * 16
            cv2.line(vis, (x_v1 + 7, y_v1 + 7), (x_v2 + 7, y_v2 + h1 + 7), color=color, thickness=1, lineType=cv2.LINE_AA)

    return vis
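
# Usage sketch (illustrative): render matches from a saved cross-attention
# map; the file path mirrors the one built in run_reloc3r_rpr below.
#   vis = vis_ca_match(img1, img2, '_tmp_vis/_ca_block5_head0.txt')
#   cv2.imwrite('_tmp_vis/match.png', vis[:, :, ::-1])  # RGB -> BGR for OpenCV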

# run the whole process
def run_reloc3r_rpr(reloc3r_relpose, img_reso, device, imgs):
    if imgs is None or len(imgs) < 2:
        raise gradio.Error('Please upload a pair of images.')
    if len(imgs) > 2:
        print('There are >2 images uploaded, running with the first 2 images...')

    # load images
    print('Loading images...')
    images = load_images(imgs[0:2], size=int(img_reso))
    images = check_images_shape_format(images, device)
    img1 = ((images[0]['img'].detach().cpu().numpy().squeeze().transpose(1, 2, 0) + 1) / 2 * 255).astype(np.uint8)
    img2 = ((images[1]['img'].detach().cpu().numpy().squeeze().transpose(1, 2, 0) + 1) / 2 * 255).astype(np.uint8)

    # estimate relpose
    print('Running relative pose estimation...')
    batch = [images[0], images[1]]
    pose2to1 = to_numpy(inference_relpose(batch, reloc3r_relpose, device)[0])
    pose2to1[0:3, 3] = pose2to1[0:3, 3] / np.linalg.norm(pose2to1[0:3, 3])  # normalize the scale to 1 meter
    pose_vis = vis_pose2to1(pose2to1, [img1, img2])
    path = '_tmp_vis/pose2to1.txt'
    np.savetxt(path, pose2to1)
    print('Pose saved to', path)

    # patch matches from cross-attn
    print('Visualizing patch matches...')
    block_id = 5
    head_id = 0
    match_vis = vis_ca_match(img1, img2, '_tmp_vis/_ca_block{}_head{}.txt'.format(block_id, head_id))
    path = '_tmp_vis/match.png'
    cv2.imwrite(path, match_vis[:, :, [2, 1, 0]])  # RGB -> BGR
    print('Match visualization saved to', path)

    return pose_vis, [match_vis]

# gradio interface
def main_demo(reloc3r_relpose, device, img_reso, server_name, server_port):
    run = functools.partial(run_reloc3r_rpr, reloc3r_relpose, img_reso, device)

    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%}""",
                       title="Reloc3r relative camera pose demo") as demo:
        gradio.HTML('<h2 style="text-align: center;">Reloc3r relative camera pose demo</h2>')

        # components
        with gradio.Row():
            with gradio.Column():
                inputfiles = gradio.File(file_count="multiple", file_types=["image"],
                                         scale=2,
                                         height=200,
                                         label="Upload a pair of images")
                run_btn = gradio.Button("Run")
            outmodel = gradio.Model3D(camera_position=(-90, 45, 2),
                                      height=400,
                                      label="Camera poses")
            outgallery = gradio.Gallery(preview=True,
                                        height=400,
                                        label="Cross-attention responses (top-5 queries with top-4 keys)")

        # events
        run_btn.click(fn=run,
                      inputs=[inputfiles],
                      outputs=[outmodel, outgallery])

    # demo.launch(share=False, server_name=server_name, server_port=server_port)
    demo.launch()

if __name__ == '__main__':
    print('Note: this demo runs slowly because it operates on CPU and saves intermediate data.')

    os.environ["GRADIO_TEMP_DIR"] = '_tmp_gradio'
    os.makedirs('_tmp_gradio', exist_ok=True)
    os.makedirs('_tmp_vis', exist_ok=True)

    server_name = '127.0.0.1'
    server_port = 7867
    img_reso = '512'
    device = torch.device('cpu')

    print('Loading Reloc3r-512 RPR model...')
    reloc3r_relpose = setup_reloc3r_relpose_model(model_args=img_reso, device=device)

    main_demo(reloc3r_relpose, device, img_reso, server_name, server_port)