import argparse
import glob
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

import onnxruntime as ort
import axengine as axe


def load_image(imfile):
    # Resize to the model input size (W=512, H=384) and drop any alpha channel.
    img = np.array(Image.open(imfile).resize((512, 384))).astype(np.uint8)[..., :3]
    img = torch.from_numpy(img).permute(2, 0, 1).float()  # HWC -> CHW
    return img[None]  # add batch dimension


def visualize_disparity(disparity_map, title, name="test"):
    plt.figure(figsize=(10, 6))
    plt.imshow(disparity_map, cmap='jet')
    plt.colorbar(label="Disparity")
    plt.title(title)
    plt.axis('off')
    # plt.show()
    plt.savefig(f"{title}-rt-{name}.png")
    plt.close()


def demo(args):
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))

    if args.target_chip == "AX637":
        print("\033[91mWarning: AX637 uses a quantized axmodel, which cannot be run by "
              "onnxruntime, so the AX650 ONNX model is used for comparison\033[0m")
    ort_session = ort.InferenceSession("./models/AX650.onnx")
    ax_session = axe.InferenceSession(f"./models/{args.target_chip}_RTIGEV.axmodel")

    for imfile1, imfile2 in tqdm(list(zip(left_images, right_images))):
        image1 = load_image(imfile1)
        image2 = load_image(imfile2)
        img_name = Path(imfile1).parent.name

        input_l_np = image1.cpu().numpy()
        input_r_np = image2.cpu().numpy()

        # The axmodel expects NHWC uint8 input (quantization is handled inside the model).
        ax_inputs = {"left": input_l_np.transpose(0, 2, 3, 1).astype(np.uint8),
                     "right": input_r_np.transpose(0, 2, 3, 1).astype(np.uint8)}

        # The ONNX model expects NCHW float input normalized to [-1, 1].
        input_l_np = 2 * (input_l_np / 255.0) - 1.0
        input_r_np = 2 * (input_r_np / 255.0) - 1.0
        onnx_inputs = {"left": input_l_np, "right": input_r_np}

        onnx_outputs = ort_session.run(None, onnx_inputs)
        disp_onnx = onnx_outputs[0].squeeze()

        ax_outputs = ax_session.run(None, ax_inputs)
        disp_ax = ax_outputs[0].squeeze()

        # print("disp_onnx", disp_onnx)
        # print("disp_ax", disp_ax)
        visualize_disparity(disp_onnx, title="ONNX_Disparity_Map", name=img_name)
        visualize_disparity(disp_ax, title="AXModel_Disparity_Map", name=img_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames",
                        default="./demo-imgs/*/im0.png")
    parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames",
                        default="./demo-imgs/*/im1.png")
    parser.add_argument('-t', '--target_chip', help="target chip for inference",
                        default="AX650", choices=["AX637", "AX650"])
    args = parser.parse_args()
    demo(args)