Update convert_tflite_2_onnx.py
Browse files- convert_tflite_2_onnx.py +54 -0
convert_tflite_2_onnx.py
CHANGED
|
@@ -40,3 +40,57 @@ output = session.run(None, {input_name: image_data})
|
|
| 40 |
print(output)
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
print(output)
|
| 41 |
|
| 42 |
|
| 43 |
+
import onnxruntime as ort
|
| 44 |
+
import numpy as np
|
| 45 |
+
from PIL import Image
|
| 46 |
+
|
| 47 |
+
# ---- Model loading ----
# Path to the ONNX model produced by the TFLite -> ONNX conversion step.
onnx_model_path = 'model.onnx'
# Build the inference session once up front; it is reused for every batch.
session = ort.InferenceSession(onnx_model_path)
|
| 50 |
+
|
| 51 |
+
# Preprocess a single image into a model-ready batch-of-one array.
def preprocess_image(image_path, input_size=(320, 320)):
    """Load an image and prepare it for the ONNX detector.

    Args:
        image_path: Path to the image file to load.
        input_size: (width, height) to resize to; defaults to the model's
            expected 320x320 input.

    Returns:
        numpy.ndarray of shape (1, height, width, 3), dtype float32.
    """
    # Force 3-channel RGB so grayscale or RGBA files don't change the
    # channel dimension — otherwise np.vstack batching and the model's
    # expected input shape both break.
    image = Image.open(image_path).convert('RGB').resize(input_size)
    image_data = np.array(image).astype('float32')  # Convert to float32
    # Add batch dimension: (H, W, C) -> (1, H, W, C).
    image_data = np.expand_dims(image_data, axis=0)
    return image_data
|
| 57 |
+
|
| 58 |
+
# ---- Batch assembly ----
# Images to run through the detector in a single forward pass.
image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']
batch_size = len(image_paths)

# Each preprocessed image is (1, H, W, C); stacking along axis 0 yields
# one (batch, H, W, C) input tensor.
preprocessed = [preprocess_image(p) for p in image_paths]
batch_images = np.vstack(preprocessed)
|
| 64 |
+
|
| 65 |
+
# ---- Inference ----
# Discover the input tensor name from the session rather than hard-coding
# it, so this works for any single-input model.
input_name = session.get_inputs()[0].name

# Run one forward pass over the whole batch; passing None for the output
# names requests every model output.
feeds = {input_name: batch_images}
outputs = session.run(None, feeds)
|
| 70 |
+
|
| 71 |
+
# ---- Post-processing ----
# The detector returns three parallel per-image arrays: confidence
# scores, bounding boxes, and class labels.
scores_batch, bboxes_batch, labels_batch = outputs[0], outputs[1], outputs[2]

# Detections at or below this confidence are discarded.
score_threshold = 0.5

for i in range(batch_size):
    # Per-image slices of the batched outputs.
    scores, bboxes, labels = scores_batch[i], bboxes_batch[i], labels_batch[i]

    # Indices of detections that clear the confidence threshold.
    keep = np.where(scores > score_threshold)

    filtered_scores = scores[keep]
    filtered_bboxes = bboxes[keep]
    filtered_labels = labels[keep]

    print(f"Image {i+1}:")
    print("Filtered Scores:", filtered_scores)
    print("Filtered Bounding Boxes:", filtered_bboxes)
    print("Filtered Labels:", filtered_labels)
    print('---')
|
| 95 |
+
|
| 96 |
+
|