| import sys |
| import argparse |
| import glob |
| import numpy as np |
| import torch |
| from tqdm import tqdm |
| from pathlib import Path |
| from PIL import Image |
| from matplotlib import pyplot as plt |
| import os |
| import onnxruntime as ort |
| import axengine as axe |
|
|
|
|
def load_image(imfile, size=(512, 384)):
    """Load an image file as a float tensor batch of shape (1, 3, H, W).

    The image is resized to ``size`` (PIL ``(width, height)`` order) and then
    forced to 3-channel RGB. This lets grayscale, palette and RGBA inputs all
    produce the same channel layout. Pixel values stay in the 0-255 range;
    only the dtype becomes float32.

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)``. Defaults to ``(512, 384)``, the
            resolution this demo feeds to both models.

    Returns:
        ``torch.float32`` tensor of shape ``(1, 3, size[1], size[0])``.
    """
    # The context manager releases the file handle.
    # convert("RGB") replaces the old ``[..., :3]`` slice: for single-channel
    # images that slice cut the width axis and the permute below then failed.
    # Resizing before converting keeps results identical for RGB/RGBA inputs.
    with Image.open(imfile) as im:
        rgb = im.resize(size).convert("RGB")
    # np.array (not asarray) gives a writable copy, so torch.from_numpy does
    # not warn about non-writable memory.
    img = np.array(rgb, dtype=np.uint8)
    # HWC -> CHW, then add a leading batch dimension.
    return torch.from_numpy(img).permute(2, 0, 1).float()[None]
|
|
|
|
def visualize_disparity(disparity_map, title, name="test"):
    """Render a disparity map with the 'jet' colormap and save it as a PNG.

    The image is written to the current working directory as
    ``f"{title}-rt-{name}.png"``.

    Args:
        disparity_map: Array-like passed to ``plt.imshow``. Presumably 2-D
            (H, W): the demo passes a squeezed model output. TODO confirm.
        title: Plot title, also used as the filename prefix.
        name: Filename suffix distinguishing outputs. The demo passes the
            name of the image's parent directory.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.imshow(disparity_map, cmap='jet')
        plt.colorbar(label="Disparity")
        plt.title(title)
        plt.axis('off')
        fig.savefig(f"{title}-rt-{name}.png")
    finally:
        # This is called twice per image pair. Without an explicit close,
        # pyplot keeps every figure alive and memory grows with the dataset.
        plt.close(fig)
|
|
|
|
def demo(args):
    """Compare ONNX Runtime and axmodel disparity outputs on stereo pairs.

    For every (left, right) image pair, run the ONNX model and the axmodel
    for ``args.target_chip``, then save one disparity visualization per
    runtime.

    Args:
        args: Parsed CLI namespace with attributes:
            ``left_imgs`` and ``right_imgs``: glob patterns (``**`` allowed)
            for the left and right frames.
            ``target_chip``: chip name used to select the axmodel file.

    Raises:
        ValueError: If the two glob patterns match different numbers of
            files. Pairs are formed by sorted position, so a count mismatch
            would silently pair the wrong frames.
    """
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))
    if len(left_images) != len(right_images):
        raise ValueError(
            f"Found {len(left_images)} left images ({args.left_imgs!r}) but "
            f"{len(right_images)} right images ({args.right_imgs!r}); "
            "the patterns must match the same number of frames."
        )
    if not left_images:
        print(f"No images matched {args.left_imgs!r}; nothing to do.")
        return

    if args.target_chip == "AX637":
        print("\033[91mWarning: AX637 uses quant_axmodel, which can not be run by onnxruntime, "
              "so we use AX650's onnx model for comparison\033[0m")
    # The ONNX reference is always the AX650 export, whatever chip is chosen.
    ort_session = ort.InferenceSession("./models/AX650.onnx")
    ax_session = axe.InferenceSession(f"./models/{args.target_chip}_RTIGEV.axmodel")

    for imfile1, imfile2 in tqdm(list(zip(left_images, right_images))):
        img_name = Path(imfile1).parent.name
        left_np = load_image(imfile1).cpu().numpy()
        right_np = load_image(imfile2).cpu().numpy()

        # axmodel inputs: raw 0-255 pixels as uint8, transposed NCHW -> NHWC.
        ax_inputs = {
            "left": left_np.transpose(0, 2, 3, 1).astype(np.uint8),
            "right": right_np.transpose(0, 2, 3, 1).astype(np.uint8),
        }
        # ONNX inputs: NCHW, pixel values rescaled from [0, 255] to [-1, 1].
        onnx_inputs = {
            "left": 2 * (left_np / 255.0) - 1.0,
            "right": 2 * (right_np / 255.0) - 1.0,
        }

        disp_onnx = ort_session.run(None, onnx_inputs)[0].squeeze()
        disp_ax = ax_session.run(None, ax_inputs)[0].squeeze()

        visualize_disparity(disp_onnx, title="ONNX_Disparity_Map", name=img_name)
        visualize_disparity(disp_ax, title="AXModel_Disparity_Map", name=img_name)
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: glob patterns for the stereo frames plus the
    # chip whose axmodel should be compared against the ONNX reference.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--left_imgs',
        default="./demo-imgs/*/im0.png",
        help="path to all first (left) frames",
    )
    cli.add_argument(
        '-t', '--target_chip',
        choices=["AX637", "AX650"],
        default="AX650",
        help="target chip for inference",
    )
    cli.add_argument(
        '-r', '--right_imgs',
        default="./demo-imgs/*/im1.png",
        help="path to all second (right) frames",
    )
    demo(cli.parse_args())
|
|