# IGEV-plusplus / infer.py
# Author: Willaaaaaaa
# Commit: "fix models and update images" (b5e505f)
import sys
import argparse
import glob
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt
import os
import onnxruntime as ort
import axengine as axe
def load_image(imfile):
    """Read an image file and return it as a float tensor of shape
    (1, 3, 384, 512) with values in [0, 255].

    The image is resized to 512x384 and any alpha channel is dropped.
    """
    resized = Image.open(imfile).resize((512, 384))
    # Keep only the first three (RGB) channels in case of RGBA input.
    pixels = np.asarray(resized, dtype=np.uint8)[..., :3]
    # HWC -> CHW, then prepend a batch dimension.
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return tensor.unsqueeze(0)
def visualize_disparity(disparity_map, title, name="test"):
    """Render a disparity map with a jet colormap and save it to disk.

    Parameters
    ----------
    disparity_map : 2-D array of per-pixel disparity values.
    title : figure title; also used as the output filename prefix.
    name : identifier for the image pair (default "test").

    The figure is written to '<title>-rt-<name>.png' in the working directory.
    """
    fig = plt.figure(figsize=(10, 6))
    plt.imshow(disparity_map, cmap='jet')
    plt.colorbar(label="Disparity")
    plt.title(title)
    plt.axis('off')
    # plt.show()
    plt.savefig(f"{title}-rt-{name}.png")
    # Close the figure explicitly: this function is called twice per image
    # pair inside a loop, and unclosed figures accumulate, leaking memory
    # (matplotlib warns once more than 20 figures are open).
    plt.close(fig)
def demo(args):
    """Run stereo disparity inference with both the reference ONNX model and
    the compiled axmodel, saving a visualization of each result.

    Parameters
    ----------
    args : argparse.Namespace with attributes
        left_imgs / right_imgs : glob patterns for the left/right frames
        target_chip : "AX637" or "AX650", selects the axmodel file

    For every matched pair this writes two PNGs named
    'ONNX_Disparity_Map-rt-<dir>.png' and 'AXModel_Disparity_Map-rt-<dir>.png',
    where <dir> is the pair's parent directory name.
    """
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))
    if args.target_chip == "AX637":
        print("\033[91mWarning: AX637 uses quant_axmodel, which can not be run by onnxruntime, \
so we use AX650's onnx model for comparison\033[0m")
    # The quantized AX637 model cannot run under onnxruntime, so the AX650
    # float ONNX model always serves as the reference implementation.
    ort_session = ort.InferenceSession("./models/AX650.onnx")
    ax_session = axe.InferenceSession(f"./models/{args.target_chip}_RTIGEV.axmodel")
    for imfile1, imfile2 in tqdm(list(zip(left_images, right_images))):
        image1 = load_image(imfile1)
        image2 = load_image(imfile2)
        # The pair's directory name labels the saved figures.
        img_name = Path(imfile1).parent.name
        input_l_np = image1.cpu().numpy()
        input_r_np = image2.cpu().numpy()
        # The axmodel expects NHWC uint8 input (quantization happens on-chip).
        ax_inputs = {
            "left": input_l_np.transpose(0, 2, 3, 1).astype(np.uint8),
            "right": input_r_np.transpose(0, 2, 3, 1).astype(np.uint8),
        }
        # The ONNX model expects NCHW float input normalized to [-1, 1].
        onnx_inputs = {
            "left": 2 * (input_l_np / 255.0) - 1.0,
            "right": 2 * (input_r_np / 255.0) - 1.0,
        }
        disp_onnx = ort_session.run(None, onnx_inputs)[0].squeeze()
        disp_ax = ax_session.run(None, ax_inputs)[0].squeeze()
        # print("disp_onnx",disp_onnx)
        # print("disp_ax",disp_ax)
        visualize_disparity(disp_onnx, title="ONNX_Disparity_Map", name=img_name)
        visualize_disparity(disp_ax, title="AXModel_Disparity_Map", name=img_name)
if __name__ == '__main__':
    # CLI entry point: glob patterns for the stereo image pairs plus the
    # target chip whose compiled axmodel is compared against the ONNX model.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--left_imgs',
        default="./demo-imgs/*/im0.png",
        help="path to all first (left) frames",
    )
    cli.add_argument(
        '-t', '--target_chip',
        default="AX650",
        choices=["AX637", "AX650"],
        help="target chip for inference",
    )
    cli.add_argument(
        '-r', '--right_imgs',
        default="./demo-imgs/*/im1.png",
        help="path to all second (right) frames",
    )
    demo(cli.parse_args())