<!--
SAM3-Tracker-WebGPU / index.html
(Scraped HuggingFace page chrome, preserved as a comment so the file remains valid HTML:
"Xenova's picture / Xenova HF Staff / Update index.html / bede8f0 verified")
-->
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>SAM3 WebGPU | Transformers.js</title>
<!-- Tailwind CDN runtime build; supplies the utility classes used in the markup below. -->
<script src="https://cdn.tailwindcss.com"></script>
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
body {
font-family: 'Inter', sans-serif;
}
/* Mask canvas overlay: absolutely positioned over the image (JS sets its
   size/offset to match the letterboxed image); semi-transparent, and
   pointer-events off so clicks fall through to the container. */
canvas {
position: absolute;
top: 0;
left: 0;
opacity: 0.6;
pointer-events: none;
}
/* Emoji point markers (star = positive, cross = negative), centered on the
   click position and non-interactive. */
.icon {
position: absolute;
transform: translate(-50%, -50%);
font-size: 24px;
user-select: none;
pointer-events: none;
text-shadow: 0 0 4px white;
}
/* Aspect-ratio helpers via the classic padding-bottom trick, driven by the
   --aspect-w / --aspect-h custom properties set below. */
.aspect-w-3 {
position: relative;
width: 100%;
}
.aspect-w-3::before {
content: '';
display: block;
padding-bottom: calc(var(--aspect-h) / var(--aspect-w) * 100%);
}
.aspect-w-3> :first-child {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
}
.aspect-w-3 {
--aspect-w: 3;
}
.aspect-h-2 {
--aspect-h: 2;
}
.aspect-w-4 {
--aspect-w: 4;
}
.aspect-h-3 {
--aspect-h: 3;
}
</style>
</head>
<body class="bg-gray-100 text-gray-800 min-h-screen flex flex-col items-center justify-center p-4 sm:p-8">
<div class="w-full max-w-3xl bg-white rounded-xl shadow-2xl overflow-hidden">
<div class="p-6 sm:p-10">
<h1 class="text-3xl sm:text-4xl font-bold text-center text-gray-900">SAM3 WebGPU</h1>
<h3 class="text-lg sm:text-xl text-gray-500 text-center mb-6">
In-browser image segmentation w/
<a href="https://hf.co/docs/transformers.js" target="_blank" class="text-blue-600 hover:underline">🤗
Transformers.js</a>
</h3>
<!-- Interactive area: the upload prompt overlays the image slot until an image is encoded. -->
<div id="container"
class="relative w-full max-w-2xl mx-auto border border-gray-200 rounded-lg overflow-hidden cursor-pointer bg-gray-50 shadow-sm transition-all aspect-w-3 aspect-h-2">
<label id="upload-area" for="upload"
class="absolute inset-0 z-10 flex flex-col justify-center items-center p-10 transition-all cursor-pointer border-2 border-dashed border-gray-300 rounded-lg hover:bg-gray-50/50 hover:border-blue-500">
<div class="flex flex-col items-center justify-center p-6 transition-colors w-full max-w-sm">
<svg class="w-12 h-12 text-gray-400" fill="currentColor" viewBox="0 0 25 25"
xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
</path>
</svg>
<span class="text-lg font-medium text-gray-500 mt-2">Click to upload image</span>
<span class="text-sm text-gray-400">or drag and drop</span>
</div>
<p class="text-gray-500 text-sm my-4">...or try an example:</p>
<div id="example-gallery" class="flex gap-4">
<img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity"
alt="Example of a truck">
<img src="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg"
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity"
alt="Example of a corgi">
<img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity"
alt="Example of groceries">
</div>
</label>
<!-- The uploaded/example image; un-hidden by JS once encoding starts. alt added for a11y. -->
<img id="image-display" class="absolute inset-0 w-full h-full object-contain block hidden z-0" alt="Image to segment" />
<!-- Mask overlay canvas; sized and positioned by JS to match the letterboxed image. -->
<canvas id="mask-output"></canvas>
</div>
<!-- Status line: role="status"/aria-live so screen readers announce async model progress.
     (Was a <label> with no associated control, which is invalid; JS only uses the id.) -->
<p id="status" role="status" aria-live="polite" class="text-base text-center text-gray-600 min-h-[1.5rem] mt-6 mb-4 block w-full">Loading
model...</p>
<div id="controls" class="flex flex-col sm:flex-row justify-center gap-3">
<button id="reset-image" disabled
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none"
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
class="w-4 h-4 mr-2" aria-hidden="true">
<path d="M3 12a9 9 0 1 0 9-9 9.75 9.75 0 0 0-6.74 2.74L3 8" />
<path d="M3 3v5h5" />
</svg>
Reset image
</button>
<button id="clear-points" disabled
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none"
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
class="w-4 h-4 mr-2" aria-hidden="true">
<line x1="18" y1="6" x2="6" y2="18"></line>
<line x1="6" y1="6" x2="18" y2="18"></line>
</svg>
Clear points
</button>
<button id="cut-mask" disabled
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-blue-600 text-white font-medium rounded-lg shadow-sm hover:bg-blue-700 transition-colors disabled:bg-gray-300 disabled:text-gray-500 disabled:cursor-not-allowed">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none"
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
class="w-4 h-4 mr-2" aria-hidden="true">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" />
<polyline points="7 10 12 15 17 10" />
<line x1="12" y1="15" x2="12" y2="3" />
</svg>
Cut &amp; Download
</button>
</div>
<p id="information" class="text-sm text-gray-500 mt-4 text-center">
Left click = positive (⭐), Right click = negative (❌).
</p>
</div>
</div>
<!-- Hidden file input driven by the upload-area <label>; enabled once the model loads. -->
<input id="upload" type="file" accept="image/*" disabled class="hidden" />
<script type="module">
// Browser build of Transformers.js, loaded directly from the jsDelivr CDN.
import {
Sam3TrackerModel,
AutoProcessor,
RawImage,
Tensor,
} from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]";
// --- DOM references ---
const statusLabel = document.getElementById("status");
const fileUpload = document.getElementById("upload");
const imageContainer = document.getElementById("container");
const uploadArea = document.getElementById("upload-area");
const exampleImages = document.querySelectorAll(".example-image");
const resetButton = document.getElementById("reset-image");
const clearButton = document.getElementById("clear-points");
const cutButton = document.getElementById("cut-mask");
const imageDisplay = document.getElementById("image-display");
const maskCanvas = document.getElementById("mask-output");
const maskContext = maskCanvas.getContext("2d");
// --- Mutable application state ---
let isEncoding = false; // true while an image embedding is being computed
let isDecoding = false; // true while a mask decode is in flight
let decodePending = false; // a decode was requested while one was already running
let lastPoints = null; // array of { position: [x, y] in 0..1, label: 1 (positive) | 0 (negative) }
let isMultiMaskMode = false; // true once the user has clicked (vs. hover preview)
let imageInput = null; // RawImage of the current input image
let imageProcessed = null; // processor output for the current image
let imageEmbeddings = null; // cached vision-encoder output for the current image
/**
 * Loads the image at `url`, displays it, and computes its SAM3 image
 * embeddings (cached in `imageEmbeddings` for subsequent decodes).
 * Re-entrant calls are ignored while an encode is already in progress.
 * On failure, reports the error in the status line and resets the UI.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";
  try {
    imageInput = await RawImage.fromURL(url);

    // Show the image and hide the upload prompt; the canvas overlay is
    // re-fitted once the <img> element has actually loaded.
    imageDisplay.onload = updateCanvasGeometry;
    imageDisplay.src = url;
    imageDisplay.classList.remove('hidden');
    uploadArea.classList.add("hidden");
    cutButton.disabled = true;

    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);
    // (Removed a leftover debug `console.log({ imageEmbeddings })`.)

    statusLabel.textContent = "Embedding extracted! Click on the image.";
    resetButton.disabled = false;
    clearButton.disabled = false;
  } catch (error) {
    console.error("Error during encoding:", error);
    statusLabel.textContent = "Error loading image. Please try again.";
    resetUI();
  } finally {
    isEncoding = false;
  }
}
/**
 * Decodes the mask based on the current points.
 * If a decode is already in flight, records that another pass is needed and
 * re-runs once it finishes (coalescing rapid mousemove-driven requests).
 */
async function decode() {
// No-op unless we have embeddings and at least one prompt point.
if (isDecoding || !imageEmbeddings || !lastPoints || lastPoints.length === 0) {
if (isDecoding) {
decodePending = true;
}
return;
}
isDecoding = true;
try {
// Points are stored normalized (0..1); scale to the model's input resolution.
// x is multiplied by reshaped[1] and y by reshaped[0], i.e. reshaped is [height, width].
const reshaped = imageProcessed.reshaped_input_sizes[0];
const points = lastPoints
.map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
.flat(Infinity);
const labels = lastPoints.map((x) => BigInt(x.label)).flat(Infinity);
const num_points = lastPoints.length;
// Tensor layouts: [batch, object, point, xy] for points; int64 labels.
const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
const input_labels = new Tensor("int64", labels, [1, 1, num_points]);
const { pred_masks, iou_scores } = await model({
...imageEmbeddings,
input_points,
input_labels,
});
// Rescale the predicted masks back to the original image size.
const masks = await processor.post_process_masks(
pred_masks,
imageProcessed.original_sizes,
imageProcessed.reshaped_input_sizes,
);
updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
} catch (error) {
console.error("Error during decoding:", error);
statusLabel.textContent = "Error generating mask.";
} finally {
isDecoding = false;
}
// Run the coalesced follow-up decode; setTimeout avoids unbounded recursion.
if (decodePending) {
decodePending = false;
setTimeout(decode, 0);
}
}
/**
 * Sizes and positions the mask canvas so it exactly overlays the displayed
 * image, mirroring the `object-contain` letterboxing done by CSS.
 */
function updateCanvasGeometry() {
  // Nothing to align until an image is actually visible.
  if (!imageDisplay.src || imageDisplay.classList.contains('hidden')) return;

  const { naturalWidth, naturalHeight } = imageDisplay;
  const box = imageContainer.getBoundingClientRect();

  const imageRatio = naturalWidth / naturalHeight;
  const boxRatio = box.width / box.height;

  let width, height, top, left;
  if (imageRatio > boxRatio) {
    // Image is relatively wider: full width, letterboxed top/bottom.
    width = box.width;
    height = width / imageRatio;
    left = 0;
    top = (box.height - height) / 2;
  } else {
    // Image is relatively taller: full height, pillarboxed left/right.
    height = box.height;
    width = height * imageRatio;
    top = 0;
    left = (box.width - width) / 2;
  }

  maskCanvas.style.width = `${width}px`;
  maskCanvas.style.height = `${height}px`;
  maskCanvas.style.top = `${top}px`;
  maskCanvas.style.left = `${left}px`;
}
/**
 * Draws the best-scoring mask proposal onto the overlay canvas in blue.
 * @param mask   RawImage whose data interleaves one channel per proposal
 *               (index `numMasks * pixel + maskIndex`).
 * @param scores IoU score per proposal; the argmax is rendered.
 */
function updateMaskOverlay(mask, scores) {
  // Match the canvas backing store to the mask resolution (CSS handles display size).
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Select the proposal with the highest IoU score.
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Paint selected pixels solid blue; overall translucency comes from the
  // canvas `opacity: 0.6` CSS rule.
  const pixelData = imageData.data;
  // Fix: iterate over pixels (w*h), not RGBA bytes (4*w*h). The old bound read
  // past `mask.data` (harmlessly, via undefined) and did 4x the work.
  const numPixels = maskCanvas.width * maskCanvas.height;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0;
      pixelData[offset + 1] = 114;
      pixelData[offset + 2] = 189;
      pixelData[offset + 3] = 255;
    }
  }
  maskContext.putImageData(imageData, 0, 0);
}
/**
 * Discards all prompt points, their emoji markers, and the rendered mask,
 * returning the view to hover-preview mode.
 */
function clearPointsAndMask() {
  isMultiMaskMode = false;
  lastPoints = null;
  // Remove every star/cross marker from the container.
  for (const marker of document.querySelectorAll(".icon")) {
    marker.remove();
  }
  // Nothing to cut any more.
  cutButton.disabled = true;
  // Wipe the mask overlay.
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
  statusLabel.textContent = "Points cleared. Click to add new points.";
}
/**
 * Returns the page to its initial "no image" state: drops all per-image
 * state, clears points/mask, and shows the upload prompt again.
 */
function resetUI() {
  // Forget everything derived from the previous image.
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  decodePending = false;
  clearPointsAndMask();
  // Disable all actions until a new image has been encoded.
  for (const button of [cutButton, resetButton, clearButton]) {
    button.disabled = true;
  }
  // Swap the image view back to the upload prompt.
  imageDisplay.src = '';
  imageDisplay.classList.add('hidden');
  uploadArea.classList.remove("hidden");
  // Collapse the mask overlay.
  maskCanvas.style.width = '0px';
  maskCanvas.style.height = '0px';
  // Make the example thumbnails clickable again.
  for (const img of exampleImages) {
    img.style.pointerEvents = "auto";
  }
  statusLabel.textContent = "Ready";
}
/** Restricts `x` to the inclusive range [min, max]; defaults to the unit interval. */
function clamp(x, min = 0, max = 1) {
  if (x < min) return min;
  if (x > max) return max;
  return x;
}
/**
 * Converts a mouse event into a prompt point.
 * Coordinates are normalized to 0..1 relative to the mask canvas, which is
 * sized/positioned (by updateCanvasGeometry) to exactly overlay the image.
 * @returns {{position: [number, number], label: 0|1}} label 0 = negative
 *          (right-click), 1 = positive (any other button).
 */
function getPoint(e) {
  // Fix: removed an unused `imageDisplay.getBoundingClientRect()` local (`imgBB`).
  const canvasBB = maskCanvas.getBoundingClientRect();
  const mouseX = clamp((e.clientX - canvasBB.left) / canvasBB.width);
  const mouseY = clamp((e.clientY - canvasBB.top) / canvasBB.height);
  return {
    position: [mouseX, mouseY],
    label: e.button === 2 ? 0 : 1,
  };
}
// File upload: read the chosen file as a data URL and encode it.
fileUpload.addEventListener("change", function (e) {
const file = e.target.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = (e2) => encode(e2.target.result);
reader.readAsDataURL(file);
});
// Example thumbnails: stop the surrounding <label> from opening the file
// picker, disable further example clicks, and encode the chosen image.
exampleImages.forEach((img) => {
img.addEventListener("click", (e) => {
e.preventDefault();
e.stopPropagation();
exampleImages.forEach(i => i.style.pointerEvents = "none");
encode(img.src);
});
});
// Keep the canvas overlay aligned with the letterboxed image when the window resizes.
window.addEventListener("resize", updateCanvasGeometry);
resetButton.addEventListener("click", resetUI);
clearButton.addEventListener("click", clearPointsAndMask);
// Clicking adds a prompt point: left button = positive, right button = negative.
imageContainer.addEventListener("mousedown", (e) => {
  // Ignore clicks until an image has been encoded and the upload prompt is gone.
  const uploadVisible = !uploadArea.classList.contains('hidden');
  if (!imageEmbeddings || uploadVisible) return;
  // Only left (0) and right (2) buttons place points.
  if (e.button !== 0 && e.button !== 2) return;

  // The first click leaves hover-preview mode and starts collecting points.
  if (!isMultiMaskMode) {
    isMultiMaskMode = true;
    lastPoints = [];
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // Drop a star (positive) or cross (negative) marker at the click position,
  // translating canvas-relative coordinates into container-relative pixels.
  const marker = document.createElement('span');
  marker.className = 'icon';
  marker.textContent = point.label === 1 ? '⭐' : '❌';
  const canvasBox = maskCanvas.getBoundingClientRect();
  const containerBox = imageContainer.getBoundingClientRect();
  marker.style.left = `${canvasBox.left - containerBox.left + point.position[0] * canvasBox.width}px`;
  marker.style.top = `${canvasBox.top - containerBox.top + point.position[1] * canvasBox.height}px`;
  imageContainer.appendChild(marker);

  decode();
});
// Right-click places negative points, so suppress the browser context menu.
imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());
// Until the first click, hovering previews a single-point segmentation.
imageContainer.addEventListener("mousemove", (e) => {
  if (!imageEmbeddings || isMultiMaskMode || !uploadArea.classList.contains('hidden')) return;
  lastPoints = [getPoint(e)];
  decode();
});
// "Cut & Download": copies the mask-selected pixels of the original image
// into a transparent PNG and triggers a browser download.
cutButton.addEventListener("click", async () => {
if (!imageInput || !maskCanvas) return;
const [w, h] = [maskCanvas.width, maskCanvas.height];
// The overlay's alpha channel marks which pixels belong to the mask.
const maskImageData = maskContext.getImageData(0, 0, w, h);
const maskPixelData = maskImageData.data;
const cutCanvas = new OffscreenCanvas(w, h);
const cutContext = cutCanvas.getContext("2d");
const cutImageData = cutContext.createImageData(w, h);
const cutPixelData = cutImageData.data;
// NOTE(review): the `3 * i` indexing assumes imageInput.data is tightly
// packed 3-channel RGB and has the same w x h as the mask canvas; neither is
// established in this file — verify RawImage's channel count for this flow.
const imagePixelData = imageInput.data;
for (let i = 0; i < w * h; ++i) {
const maskOffset = 4 * i; // RGBA stride of canvas ImageData
const imageOffset = 3 * i;
if (maskPixelData[maskOffset + 3] > 0) {
cutPixelData[maskOffset] = imagePixelData[imageOffset];
cutPixelData[maskOffset + 1] = imagePixelData[imageOffset + 1];
cutPixelData[maskOffset + 2] = imagePixelData[imageOffset + 2];
cutPixelData[maskOffset + 3] = 255; // fully opaque where masked; transparent elsewhere
}
}
cutContext.putImageData(cutImageData, 0, 0);
// Programmatically download the generated PNG.
const link = document.createElement("a");
link.download = "mask-cutout.png";
link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
link.click();
link.remove();
});
/**
 * Loads the SAM3 tracker model (WebGPU, quantized vision encoder) and its
 * processor, then enables the file input.
 * @returns {{model, processor}} on success.
 * @throws Re-throws the load error after reporting it, so the module-level
 *         destructuring below fails with the real cause instead of a
 *         confusing "cannot destructure undefined" TypeError.
 */
async function loadModel() {
  try {
    const model_id = "onnx-community/sam3-tracker-ONNX";
    const model = await Sam3TrackerModel.from_pretrained(model_id, {
      dtype: {
        vision_encoder: "q4",
        prompt_encoder_mask_decoder: "fp32",
      },
      device: "webgpu",
    });
    const processor = await AutoProcessor.from_pretrained(model_id);
    statusLabel.textContent = "Ready";
    fileUpload.disabled = false;
    return { model, processor };
  } catch (error) {
    console.error("Error loading model:", error);
    statusLabel.textContent = "Error loading model. Please refresh the page.";
    throw error; // fix: previously returned undefined, breaking the destructure below
  }
}
const { model, processor } = await loadModel();
</script>
</body>
</html>