This model is a fork of facebook/levit-256, where:
- nn.BatchNorm2d and nn.Conv2d are fused,
- nn.BatchNorm1d and nn.Linear are fused,
- and the optimized model is converted to the ONNX format.
The fusion of layers leverages torch.fx, using the FuseBatchNorm2dInConv2d and FuseBatchNorm1dInLinear transformations, soon to be available out-of-the-box with 🤗 Optimum. Check out the transformation guide: https://huggingface.co/docs/optimum/main/en/fx/optimization#the-transformation-guide
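For reference, below is a minimal sketch of how such a fused ONNX model could be produced. It assumes the two transformations are exposed under optimum.fx.optimization as described in the guide above; the exact import path and export arguments may differ.

import torch
from transformers import AutoModelForImageClassification
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import FuseBatchNorm2dInConv2d, FuseBatchNorm1dInLinear, compose

model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
model.eval()

# Trace the model with torch.fx using the transformers tracer
traced = symbolic_trace(model, input_names=["pixel_values"])

# Fold BatchNorm2d into Conv2d, and BatchNorm1d into Linear
fuse = compose(FuseBatchNorm2dInConv2d(), FuseBatchNorm1dInLinear())
fused = fuse(traced)

# Export the fused graph to ONNX (assumed export arguments)
torch.onnx.export(
    fused,
    (torch.rand(1, 3, 224, 224),),
    "model.onnx",
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch_size"}, "logits": {0: "batch_size"}},
    opset_version=13,
)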
How to use
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoFeatureExtractor
from PIL import Image
import requests

# Load the preprocessor and the ONNX model from the Hub
preprocessor = AutoFeatureExtractor.from_pretrained("fxmarty/levit-256-onnx")
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

# Download a test image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess and run inference with ONNX Runtime
inputs = preprocessor(images=image, return_tensors="pt")
outputs = ort_model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", ort_model.config.id2label[predicted_class_idx])
To be safe, also check that the ONNX model returns the same logits as the PyTorch model. Batch-norm folding and ONNX Runtime slightly change the floating-point arithmetic, so compare with a small absolute tolerance rather than expecting exact equality:
import torch
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

# Reference PyTorch model and the ONNX model under test
pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

# Run both models on the same random input and compare the logits
inp = {"pixel_values": torch.rand(1, 3, 224, 224)}

with torch.no_grad():
    res = pt_model(**inp)
res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
Benchmarking
More than 2x throughput with batch normalization folding and ONNX Runtime 🔥
Below you can find latency percentiles and mean (in ms), as well as the models' throughput (in iterations/s).
PyTorch runtime:
{'latency_50': 22.3024695,
'latency_90': 23.1230725,
'latency_95': 23.2653985,
'latency_99': 23.60095705,
'latency_999': 23.865580469999998,
'latency_mean': 22.442956878923766,
'latency_std': 0.46544295612971265,
'nb_forwards': 446,
'throughput': 44.6}
Optimum-onnxruntime runtime:
{'latency_50': 9.302445,
'latency_90': 9.782875,
'latency_95': 9.9071944,
'latency_99': 11.084606999999997,
'latency_999': 12.035858692000001,
'latency_mean': 9.357703552853133,
'latency_std': 0.4018553286992142,
'nb_forwards': 1069,
'throughput': 106.9}
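Note that the reported throughput matches the number of forward passes divided by the 10-second benchmark duration (duration=10 in the snippet below), so the speedup can be read directly from the two dictionaries:

# Speedup computed from the reported numbers
pt_throughput = 446 / 10      # 44.6 it/s
ort_throughput = 1069 / 10    # 106.9 it/s
print(f"speedup: {ort_throughput / pt_throughput:.2f}x")  # ~2.40x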
Run on your own machine with:
import torch
from optimum.runs_base import TimeBenchmark
from pprint import pprint

# pt_model and ort_model are the models loaded in the snippet above

# Benchmark the ONNX Runtime model
time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10,
)
results_ort = time_benchmark_ort.execute()

# Benchmark the PyTorch model (no gradients needed for inference)
with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10,
    )
    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)

print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)