import React, { useRef, useState } from 'react';
import Webcam from 'react-webcam';
import * as ort from 'onnxruntime-web';
/**
 * Webcam object-detection demo: captures a frame from the webcam, runs an
 * ONNX model on it via onnxruntime-web, and lists the detections.
 */
function ObjectDetection() {
  const [results, setResults] = useState([]);
  const [loading, setLoading] = useState(false);
  const webcamRef = useRef(null);
  // Cache the inference session so './model.onnx' is fetched and compiled
  // only once instead of on every button click.
  const sessionRef = useRef(null);

  /**
   * Captures the current webcam frame, runs the model on it, and stores the
   * filtered detections in state. Errors are logged, never rethrown.
   */
  const runInference = async () => {
    if (!webcamRef.current) return;
    setLoading(true);
    try {
      // Capture image from webcam (getScreenshot() returns null until the
      // video stream is ready — bail out instead of decoding null).
      const imageSrc = webcamRef.current.getScreenshot();
      if (!imageSrc) return;
      // Load the ONNX model on first use, then reuse the cached session.
      if (!sessionRef.current) {
        sessionRef.current = await ort.InferenceSession.create('./model.onnx');
      }
      const model = sessionRef.current;
      // Preprocess the image into the tensor layout the model expects
      const inputTensor = await preprocessImage(imageSrc);
      // 'input' must match the model's declared input name
      const feeds = { input: inputTensor };
      // Run inference
      const output = await model.run(feeds);
      // Postprocess the output into {box, score, class} records
      const detections = postprocessOutput(output);
      setResults(detections);
    } catch (error) {
      console.error('Error running inference:', error);
    } finally {
      // Clear the busy flag even when capture/inference throws.
      setLoading(false);
    }
  };

  /**
   * Decodes a data-URL image, resizes it to the model's input size, and packs
   * it into an ort.Tensor: uint8 NHWC or normalized float32 NCHW.
   * @param {string} imageSrc - Base64 data URL from the webcam screenshot.
   * @returns {Promise<ort.Tensor>} input tensor ready for session.run().
   */
  const preprocessImage = async (imageSrc) => {
    const img = new Image();
    // Reject on decode failure instead of awaiting forever on a bad source.
    const loaded = new Promise((resolve, reject) => {
      img.onload = resolve;
      img.onerror = () => reject(new Error('Failed to decode screenshot'));
    });
    img.src = imageSrc;
    await loaded;
    const canvas = document.createElement('canvas');
    const context = canvas.getContext('2d');
    // Resize to model input size
    const modelInputWidth = 300; // Replace with your model's input width
    const modelInputHeight = 300; // Replace with your model's input height
    canvas.width = modelInputWidth;
    canvas.height = modelInputHeight;
    context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
    const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
    const numPixels = modelInputWidth * modelInputHeight;
    // Check the required data type
    const isUint8 = true; // Set to true if your model expects uint8, false for float32
    if (isUint8) {
      // getImageData() yields RGBA; drop the alpha byte so the buffer length
      // matches the declared [1, H, W, 3] NHWC shape. (The previous version
      // passed the 4-channel buffer directly, mismatching the dims.)
      const rgbData = new Uint8Array(numPixels * 3);
      for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
        rgbData[j++] = imageData.data[i]; // R
        rgbData[j++] = imageData.data[i + 1]; // G
        rgbData[j++] = imageData.data[i + 2]; // B
      }
      return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
    }
    // Normalize to [0, 1] and lay channels out planar (RRR...GGG...BBB...) to
    // match the declared [1, 3, H, W] NCHW shape. (The previous version packed
    // pixels interleaved under NCHW dims, silently scrambling the channels.)
    const floatData = new Float32Array(numPixels * 3);
    for (let i = 0; i < numPixels; i++) {
      floatData[i] = imageData.data[i * 4] / 255; // R plane
      floatData[numPixels + i] = imageData.data[i * 4 + 1] / 255; // G plane
      floatData[2 * numPixels + i] = imageData.data[i * 4 + 2] / 255; // B plane
    }
    return new ort.Tensor('float32', floatData, [1, 3, modelInputHeight, modelInputWidth]);
  };

  /**
   * Converts raw model outputs into detections above a 0.5 confidence
   * threshold.
   * @param {Object} output - session.run() result, keyed by output name.
   * @returns {{box: *, score: number, class: *}[]} kept detections.
   */
  const postprocessOutput = (output) => {
    const boxes = output['boxes'].data; // Replace 'boxes' with your model's output name
    const scores = output['scores'].data; // Replace 'scores' with your model's output name
    const classes = output['classes'].data; // Replace 'classes' with your model's output name
    const detections = [];
    for (let i = 0; i < scores.length; i++) {
      if (scores[i] > 0.5) { // Confidence threshold
        detections.push({
          // Each detection owns 4 consecutive coordinates in the boxes buffer.
          box: boxes.slice(i * 4, i * 4 + 4),
          score: scores[i],
          class: classes[i],
        });
      }
    }
    return detections;
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection with Webcam'),
    React.createElement(Webcam, {
      audio: false,
      ref: webcamRef,
      screenshotFormat: 'image/jpeg',
      width: 300,
      height: 300,
    }),
    React.createElement(
      'button',
      { onClick: runInference, disabled: loading },
      loading ? 'Detecting...' : 'Capture & Detect'
    ),
    React.createElement(
      'div',
      null,
      React.createElement('h2', null, 'Results:'),
      React.createElement(
        'ul',
        null,
        results.map((result, index) =>
          React.createElement(
            'li',
            { key: index },
            `Class: ${result.class}, Score: ${result.score.toFixed(2)}, Box: ${result.box.join(', ')}`
          )
        )
      )
    )
  );
}
export default ObjectDetection;
/**
 * Module-level preprocessing helper: draws a data-URL image onto a 320x320
 * canvas and packs its RGB bytes into a uint8 NHWC tensor.
 *
 * NOTE(review): this is never referenced in this file and duplicates the
 * helper defined inside ObjectDetection (which resizes to 300x300). Confirm
 * which copy is intended and delete the other.
 *
 * @param {string} imageSrc - Image source (e.g. base64 data URL from the webcam).
 * @returns {Promise<ort.Tensor>} uint8 tensor with shape [1, 320, 320, 3].
 */
const preprocessImage = async (imageSrc) => {
  const img = new Image();
  // Reject on decode failure instead of awaiting forever on a bad source.
  const loaded = new Promise((resolve, reject) => {
    img.onload = resolve;
    img.onerror = () => reject(new Error('Failed to load image for preprocessing'));
  });
  img.src = imageSrc;
  await loaded;
  const canvas = document.createElement('canvas');
  const context = canvas.getContext('2d');
  // Resize to model input size
  const modelInputWidth = 320; // Replace with your model's input width
  const modelInputHeight = 320; // Replace with your model's input height
  canvas.width = modelInputWidth;
  canvas.height = modelInputHeight;
  context.drawImage(img, 0, 0, modelInputWidth, modelInputHeight);
  const imageData = context.getImageData(0, 0, modelInputWidth, modelInputHeight);
  // Convert RGBA to RGB by dropping every fourth (alpha) byte.
  const rgbData = new Uint8Array((imageData.data.length / 4) * 3); // 3 channels for RGB
  for (let i = 0, j = 0; i < imageData.data.length; i += 4) {
    rgbData[j++] = imageData.data[i]; // R
    rgbData[j++] = imageData.data[i + 1]; // G
    rgbData[j++] = imageData.data[i + 2]; // B
    // Skip A (alpha) channel
  }
  // NHWC layout: [batch, height, width, channels].
  return new ort.Tensor('uint8', rgbData, [1, modelInputHeight, modelInputWidth, 3]);
};