|
|
<!doctype html> |
|
|
<html lang="en"> |
|
|
|
|
|
<head> |
|
|
<meta charset="UTF-8" /> |
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> |
|
|
<title>SAM3 WebGPU | Transformers.js</title> |
|
|
|
|
|
<script src="https://cdn.tailwindcss.com"></script> |
|
|
|
|
|
<style> |
|
|
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); |
|
|
|
|
|
body { |
|
|
font-family: 'Inter', sans-serif; |
|
|
} |
|
|
|
|
|
|
|
|
canvas { |
|
|
position: absolute; |
|
|
top: 0; |
|
|
left: 0; |
|
|
opacity: 0.6; |
|
|
pointer-events: none; |
|
|
} |
|
|
|
|
|
|
|
|
.icon { |
|
|
position: absolute; |
|
|
transform: translate(-50%, -50%); |
|
|
font-size: 24px; |
|
|
user-select: none; |
|
|
pointer-events: none; |
|
|
text-shadow: 0 0 4px white; |
|
|
} |
|
|
|
|
|
.aspect-w-3 { |
|
|
position: relative; |
|
|
width: 100%; |
|
|
} |
|
|
|
|
|
.aspect-w-3::before { |
|
|
content: ''; |
|
|
display: block; |
|
|
padding-bottom: calc(var(--aspect-h) / var(--aspect-w) * 100%); |
|
|
} |
|
|
|
|
|
.aspect-w-3> :first-child { |
|
|
position: absolute; |
|
|
top: 0; |
|
|
left: 0; |
|
|
width: 100%; |
|
|
height: 100%; |
|
|
} |
|
|
|
|
|
.aspect-w-3 { |
|
|
--aspect-w: 3; |
|
|
} |
|
|
|
|
|
.aspect-h-2 { |
|
|
--aspect-h: 2; |
|
|
} |
|
|
|
|
|
.aspect-w-4 { |
|
|
--aspect-w: 4; |
|
|
} |
|
|
|
|
|
.aspect-h-3 { |
|
|
--aspect-h: 3; |
|
|
} |
|
|
</style> |
|
|
</head> |
|
|
|
|
|
<body class="bg-gray-100 text-gray-800 min-h-screen flex flex-col items-center justify-center p-4 sm:p-8"> |
|
|
|
|
|
<div class="w-full max-w-3xl bg-white rounded-xl shadow-2xl overflow-hidden"> |
|
|
|
|
|
<div class="p-6 sm:p-10"> |
|
|
<h1 class="text-3xl sm:text-4xl font-bold text-center text-gray-900">SAM3 WebGPU</h1> |
|
|
<h3 class="text-lg sm:text-xl text-gray-500 text-center mb-6"> |
|
|
In-browser image segmentation w/ |
|
|
<a href="https://hf.co/docs/transformers.js" target="_blank" class="text-blue-600 hover:underline">🤗 |
|
|
Transformers.js</a> |
|
|
</h3> |
|
|
|
|
|
<div id="container" |
|
|
class="relative w-full max-w-2xl mx-auto border border-gray-200 rounded-lg overflow-hidden cursor-pointer bg-gray-50 shadow-sm transition-all aspect-w-3 aspect-h-2"> |
|
|
|
|
|
<label id="upload-area" for="upload" |
|
|
class="absolute inset-0 z-10 flex flex-col justify-center items-center p-10 transition-all cursor-pointer border-2 border-dashed border-gray-300 rounded-lg hover:bg-gray-50/50 hover:border-blue-500"> |
|
|
<div class="flex flex-col items-center justify-center p-6 transition-colors w-full max-w-sm"> |
|
|
<svg class="w-12 h-12 text-gray-400" fill="currentColor" viewBox="0 0 25 25" |
|
|
xmlns="http://www.w3.org/2000/svg"> |
|
|
<path |
|
|
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"> |
|
|
</path> |
|
|
</svg> |
|
|
<span class="text-lg font-medium text-gray-500 mt-2">Click to upload image</span> |
|
|
<span class="text-sm text-gray-400">or drag and drop</span> |
|
|
</div> |
|
|
|
|
|
<p class="text-gray-500 text-sm my-4">...or try an example:</p> |
|
|
|
|
|
<div id="example-gallery" class="flex gap-4"> |
|
|
<img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" |
|
|
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
|
|
alt="Example of a truck"> |
|
|
<img src="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg" |
|
|
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
|
|
alt="Example of a corgi"> |
|
|
<img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg" |
|
|
class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
|
|
alt="Example of groceries"> |
|
|
</div> |
|
|
</label> |
|
|
|
|
|
<img id="image-display" class="absolute inset-0 w-full h-full object-contain block hidden z-0" /> |
|
|
|
|
|
<canvas id="mask-output"></canvas> |
|
|
</div> |
|
|
|
|
|
<label id="status" class="text-base text-center text-gray-600 min-h-[1.5rem] mt-6 mb-4 block w-full">Loading |
|
|
model...</label> |
|
|
|
|
|
<div id="controls" class="flex flex-col sm:flex-row justify-center gap-3"> |
|
|
<button id="reset-image" disabled |
|
|
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed"> |
|
|
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
|
|
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
|
|
class="w-4 h-4 mr-2"> |
|
|
<path d="M3 12a9 9 0 1 0 9-9 9.75 9.75 0 0 0-6.74 2.74L3 8" /> |
|
|
<path d="M3 3v5h5" /> |
|
|
</svg> |
|
|
Reset image |
|
|
</button> |
|
|
<button id="clear-points" disabled |
|
|
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed"> |
|
|
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
|
|
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
|
|
class="w-4 h-4 mr-2"> |
|
|
<line x1="18" y1="6" x2="6" y2="18"></line> |
|
|
<line x1="6" y1="6" x2="18" y2="18"></line> |
|
|
</svg> |
|
|
Clear points |
|
|
</button> |
|
|
<button id="cut-mask" disabled |
|
|
class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-blue-600 text-white font-medium rounded-lg shadow-sm hover:bg-blue-700 transition-colors disabled:bg-gray-300 disabled:text-gray-500 disabled:cursor-not-allowed"> |
|
|
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
|
|
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
|
|
class="w-4 h-4 mr-2"> |
|
|
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" /> |
|
|
<polyline points="7 10 12 15 17 10" /> |
|
|
<line x1="12" y1="15" x2="12" y2="3" /> |
|
|
</svg> |
|
|
Cut & Download |
|
|
</button> |
|
|
</div> |
|
|
|
|
|
<p id="information" class="text-sm text-gray-500 mt-4 text-center"> |
|
|
Left click = positive (⭐), Right click = negative (❌). |
|
|
</p> |
|
|
</div> |
|
|
</div> |
|
|
|
|
|
<input id="upload" type="file" accept="image/*" disabled class="hidden" /> |
|
|
|
|
|
|
|
|
<script type="module"> |
|
|
import { |
|
|
Sam3TrackerModel, |
|
|
AutoProcessor, |
|
|
RawImage, |
|
|
Tensor, |
|
|
} from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]"; |
|
|
|
|
|
// --- DOM references -------------------------------------------------------
// Small lookup helper so every element reference below reads the same way.
const byId = (id) => document.getElementById(id);

const statusLabel = byId("status");
const fileUpload = byId("upload");
const imageContainer = byId("container");
const uploadArea = byId("upload-area");
const exampleImages = document.querySelectorAll(".example-image");
const resetButton = byId("reset-image");
const clearButton = byId("clear-points");
const cutButton = byId("cut-mask");
const imageDisplay = byId("image-display");
const maskCanvas = byId("mask-output");
const maskContext = maskCanvas.getContext("2d");

// --- Mutable application state --------------------------------------------
// Re-entrancy flags for the two async pipelines (encode / decode).
let isEncoding = false;
let isDecoding = false;
// Set when decode() is requested while another decode is still running.
let decodePending = false;

// Current prompt points ({ position: [x, y], label }) or null when there are none.
let lastPoints = null;
// false: hovering previews a single point; true: clicks accumulate points.
let isMultiMaskMode = false;

// The loaded image, the processor output for it, and its embeddings.
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Loads an image, shows it in the viewer and computes its embeddings.
 *
 * On success the module-level `imageInput`, `imageProcessed` and
 * `imageEmbeddings` are populated and the reset/clear buttons are enabled.
 * On failure the UI is reset and an error message is shown in the status label.
 * Calls made while another encode is running are ignored.
 *
 * @param {string} url - Image location passed to `RawImage.fromURL` and used
 *   as the `src` of the displayed image (a remote URL or a data URL).
 * @returns {Promise<void>}
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Show the image right away; the overlay canvas is aligned once it has loaded.
    imageDisplay.onload = updateCanvasGeometry;
    imageDisplay.src = url;
    imageDisplay.classList.remove("hidden");
    uploadArea.classList.add("hidden");
    cutButton.disabled = true;

    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);

    statusLabel.textContent = "Embedding extracted! Click on the image.";
    resetButton.disabled = false;
    clearButton.disabled = false;
  } catch (error) {
    console.error("Error during encoding:", error);
    // resetUI() writes "Ready" into the status label, so restore the UI first
    // and report the failure afterwards; otherwise the message is lost.
    resetUI();
    statusLabel.textContent = "Error loading image. Please try again.";
  } finally {
    isEncoding = false;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Runs the mask decoder for the current prompt points and draws the result.
 *
 * Does nothing when there are no embeddings or no points. If a decode is
 * already running, the request is remembered via `decodePending` and one more
 * decode is scheduled as soon as the running one finishes.
 *
 * The inputs (embeddings, processor output, points) are snapshotted before the
 * first `await`. If the user clears the points or resets the image while the
 * model is running, the stale result is discarded instead of being drawn.
 *
 * @returns {Promise<void>}
 */
async function decode() {
  if (isDecoding) {
    decodePending = true;
    return;
  }
  if (!imageEmbeddings || !lastPoints || lastPoints.length === 0) {
    return;
  }
  isDecoding = true;

  // Snapshot shared state: it may be replaced or nulled while we await below.
  const embeddings = imageEmbeddings;
  const processed = imageProcessed;
  const promptPoints = lastPoints;

  try {
    // reshaped[1] scales x and reshaped[0] scales y (normalized -> model input space).
    const reshaped = processed.reshaped_input_sizes[0];
    const points = promptPoints.flatMap(({ position }) => [
      position[0] * reshaped[1],
      position[1] * reshaped[0],
    ]);
    const labels = promptPoints.map(({ label }) => BigInt(label));

    const num_points = promptPoints.length;
    const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, num_points]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Only draw if the image is still the same and the points were not cleared meanwhile.
    const isStale = imageEmbeddings !== embeddings || lastPoints === null;
    if (!isStale) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } catch (error) {
    console.error("Error during decoding:", error);
    statusLabel.textContent = "Error generating mask.";
  } finally {
    isDecoding = false;
  }

  if (decodePending) {
    decodePending = false;
    setTimeout(decode, 0);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Positions and sizes the mask canvas (via CSS) so that it covers exactly the
 * area the displayed image occupies inside the container, i.e. the image's
 * aspect ratio fitted and centered in the container box.
 *
 * No-op while no image is shown.
 */
function updateCanvasGeometry() {
  if (!imageDisplay.src) return;
  if (imageDisplay.classList.contains("hidden")) return;

  const imageRatio = imageDisplay.naturalWidth / imageDisplay.naturalHeight;
  const box = imageContainer.getBoundingClientRect();
  const boxRatio = box.width / box.height;

  // Image is relatively wider than the container: it spans the full width and
  // is centered vertically. Otherwise it spans the full height and is centered
  // horizontally.
  const spansFullWidth = imageRatio > boxRatio;

  const width = spansFullWidth ? box.width : box.height * imageRatio;
  const height = spansFullWidth ? box.width / imageRatio : box.height;
  const top = spansFullWidth ? (box.height - height) / 2 : 0;
  const left = spansFullWidth ? 0 : (box.width - width) / 2;

  const { style } = maskCanvas;
  style.width = `${width}px`;
  style.height = `${height}px`;
  style.top = `${top}px`;
  style.left = `${left}px`;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Paints the highest-scoring mask onto the overlay canvas and reports its
 * score in the status label.
 *
 * @param {{width: number, height: number, data: ArrayLike<number>}} mask -
 *   Mask image whose `data` is read as interleaved per-pixel values, one per
 *   candidate mask (index `numMasks * pixel + maskIndex`); a value of 1 marks
 *   a pixel as part of that mask.
 * @param {ArrayLike<number>} scores - One score per candidate mask.
 */
function updateMaskOverlay(mask, scores) {
  // Resizing a canvas clears it, so only touch the dimensions when they change.
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Pick the candidate mask with the highest score.
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Iterate per pixel. `pixelData` holds 4 bytes (RGBA) per pixel, so its
  // length must not be used as the pixel count.
  const pixelData = imageData.data;
  const numPixels = maskCanvas.width * maskCanvas.height;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0;
      pixelData[offset + 1] = 114;
      pixelData[offset + 2] = 189;
      pixelData[offset + 3] = 255;
    }
  }

  maskContext.putImageData(imageData, 0, 0);
}
|
|
|
|
|
/**
 * Discards all prompt points and the rendered mask while keeping the current
 * image: leaves multi-point mode, removes the point markers from the page,
 * disables the cut button and wipes the overlay canvas.
 */
function clearPointsAndMask() {
  // Forget the prompts and go back to hover-preview mode.
  lastPoints = null;
  isMultiMaskMode = false;

  // Remove every point marker that was added to the page.
  for (const marker of document.querySelectorAll(".icon")) {
    marker.remove();
  }

  // Nothing left to cut out.
  cutButton.disabled = true;

  const { width, height } = maskCanvas;
  maskContext.clearRect(0, 0, width, height);

  statusLabel.textContent = "Points cleared. Click to add new points.";
}
|
|
|
|
|
/**
 * Returns the page to its initial "no image" state: drops the loaded image and
 * its embeddings, clears points and mask, disables the action buttons, hides
 * the image, shows the upload area again and re-enables the example images.
 */
function resetUI() {
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  decodePending = false;

  clearPointsAndMask();

  cutButton.disabled = true;
  resetButton.disabled = true;
  clearButton.disabled = true;
  imageDisplay.src = '';
  imageDisplay.classList.add('hidden');
  uploadArea.classList.remove("hidden");

  // Forget the previously chosen file. Without this, picking the same file
  // again would not fire a "change" event and the upload would do nothing.
  fileUpload.value = '';

  // Collapse the overlay so it does not cover the upload area.
  maskCanvas.style.width = '0px';
  maskCanvas.style.height = '0px';

  // Example thumbnails are made non-interactive once one is clicked; undo that.
  exampleImages.forEach((img) => {
    img.style.pointerEvents = "auto";
  });

  // Must come last: clearPointsAndMask() writes its own status message.
  statusLabel.textContent = "Ready";
}
|
|
/**
 * Restricts a number to the closed interval [min, max].
 *
 * @param {number} x - Value to restrict.
 * @param {number} [min=0] - Lower bound.
 * @param {number} [max=1] - Upper bound.
 * @returns {number} `x` limited to the interval; NaN stays NaN.
 */
function clamp(x, min = 0, max = 1) {
  const cappedAtMax = Math.min(x, max);
  return Math.max(cappedAtMax, min);
}
|
|
/**
 * Converts a mouse event into a prompt point.
 *
 * The position is expressed relative to the mask canvas bounding box and
 * clamped to [0, 1] on both axes, so pointer positions outside the canvas map
 * to its nearest edge.
 *
 * @param {MouseEvent} e - Mouse event; `clientX`, `clientY` and `button` are read.
 * @returns {{position: [number, number], label: number}} Normalized position
 *   and label: 0 for the right mouse button (negative point), 1 otherwise.
 */
function getPoint(e) {
  const { left, top, width, height } = maskCanvas.getBoundingClientRect();

  // A zero-sized canvas would turn the divisions below into NaN/Infinity,
  // which must never reach the model input; fall back to the origin instead.
  const mouseX = width > 0 ? clamp((e.clientX - left) / width) : 0;
  const mouseY = height > 0 ? clamp((e.clientY - top) / height) : 0;

  return {
    position: [mouseX, mouseY],
    label: e.button === 2 ? 0 : 1,
  };
}
|
|
|
|
|
// When the user picks a file, read it as a data URL and hand it to encode().
fileUpload.addEventListener("change", (event) => {
  const file = event.target.files[0];
  if (!file) return;

  const reader = new FileReader();
  reader.onload = () => encode(reader.result);
  // Surface read failures instead of leaving the page silently unchanged.
  reader.onerror = () => {
    console.error("Error reading file:", reader.error);
    statusLabel.textContent = "Error reading file. Please try again.";
  };
  reader.readAsDataURL(file);
});
|
|
|
|
|
// Clicking an example thumbnail loads that image.
exampleImages.forEach((img) => {
  img.addEventListener("click", (e) => {
    // The thumbnails sit inside the upload <label>; keep the click from
    // activating the file input as well.
    e.preventDefault();
    e.stopPropagation();

    // The file input is only enabled once the model has finished loading.
    // Before that, encoding cannot work, so ignore the click rather than
    // failing and flipping the status to a misleading "Ready".
    if (fileUpload.disabled) return;

    // Block further thumbnail clicks while this image is being processed.
    exampleImages.forEach((other) => {
      other.style.pointerEvents = "none";
    });
    encode(img.src);
  });
});
|
|
|
|
|
// Simple one-to-one event wiring: [target, event type, handler].
const staticListeners = [
  // Keep the overlay aligned with the image when the layout changes.
  [window, "resize", updateCanvasGeometry],
  [resetButton, "click", resetUI],
  [clearButton, "click", clearPointsAndMask],
];

for (const [target, type, handler] of staticListeners) {
  target.addEventListener(type, handler);
}
|
|
|
|
|
// Left/right click on the image adds a positive/negative prompt point,
// drops a marker icon at that spot and re-runs the decoder.
imageContainer.addEventListener("mousedown", (e) => {
  const uploadAreaVisible = !uploadArea.classList.contains("hidden");
  if (!imageEmbeddings || uploadAreaVisible) {
    return;
  }

  // Only the primary (0) and secondary (2) buttons place points.
  if (e.button !== 0 && e.button !== 2) return;

  // The first click switches from hover preview to accumulating points.
  if (!isMultiMaskMode) {
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  const icon = document.createElement("span");
  icon.className = "icon";
  icon.textContent = point.label === 1 ? '⭐' : '❌';

  // Marker position in pixels, measured from the container's padding box:
  // the origin for absolutely positioned children (clientLeft/clientTop
  // account for the container border).
  const canvasRect = maskCanvas.getBoundingClientRect();
  const containerRect = imageContainer.getBoundingClientRect();
  const left =
    canvasRect.left - containerRect.left - imageContainer.clientLeft +
    point.position[0] * canvasRect.width;
  const top =
    canvasRect.top - containerRect.top - imageContainer.clientTop +
    point.position[1] * canvasRect.height;

  // Store the position as percentages so the marker stays on its point when
  // the window (and with it the container) is resized. Fall back to pixels if
  // the container has no measurable size.
  const { clientWidth, clientHeight } = imageContainer;
  if (clientWidth > 0 && clientHeight > 0) {
    icon.style.left = `${(left / clientWidth) * 100}%`;
    icon.style.top = `${(top / clientHeight) * 100}%`;
  } else {
    icon.style.left = `${left}px`;
    icon.style.top = `${top}px`;
  }
  imageContainer.appendChild(icon);

  decode();
});
|
|
|
|
|
// Suppress the browser context menu over the image so that a right click can
// be used to place negative points.
imageContainer.addEventListener("contextmenu", (event) => {
  event.preventDefault();
});

// Hover preview: until the first click switches to multi-point mode, the
// pointer position is used as a single prompt point.
imageContainer.addEventListener("mousemove", (event) => {
  if (!imageEmbeddings) return;
  if (isMultiMaskMode) return;
  if (!uploadArea.classList.contains("hidden")) return;

  lastPoints = [getPoint(event)];
  decode();
});
|
|
|
|
|
// "Cut & Download": copies the image pixels covered by the current mask onto
// a transparent canvas and downloads the result as a PNG.
cutButton.addEventListener("click", async () => {
  if (!imageInput) return;

  const { width: w, height: h } = maskCanvas;
  if (w === 0 || h === 0) return;

  const maskPixelData = maskContext.getImageData(0, 0, w, h).data;

  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");
  const cutImageData = cutContext.createImageData(w, h);
  const cutPixelData = cutImageData.data;

  const imagePixelData = imageInput.data;
  // Use the image's actual channel count as the source stride instead of
  // assuming RGB. NOTE(review): assumes `imageInput.channels` is exposed by
  // RawImage; falls back to the previous hard-coded 3 if it is not.
  const channels = imageInput.channels ?? 3;
  // For 1- or 2-channel images, reuse the first channel for R, G and B.
  const greenShift = channels >= 3 ? 1 : 0;
  const blueShift = channels >= 3 ? 2 : 0;

  for (let i = 0; i < w * h; ++i) {
    const maskOffset = 4 * i; // RGBA
    const imageOffset = channels * i;

    // Only pixels painted by the mask (non-zero alpha) are copied.
    if (maskPixelData[maskOffset + 3] > 0) {
      cutPixelData[maskOffset] = imagePixelData[imageOffset];
      cutPixelData[maskOffset + 1] = imagePixelData[imageOffset + greenShift];
      cutPixelData[maskOffset + 2] = imagePixelData[imageOffset + blueShift];
      cutPixelData[maskOffset + 3] = 255;
    }
  }
  cutContext.putImageData(cutImageData, 0, 0);

  const objectUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "mask-cutout.png";
  link.href = objectUrl;
  link.click();

  // Release the blob once the download has had a chance to start; otherwise
  // every click keeps another copy of the PNG alive until the page is closed.
  setTimeout(() => URL.revokeObjectURL(objectUrl), 1000);
});
|
|
|
|
|
/**
 * Loads the segmentation model (on WebGPU) and its processor.
 *
 * On success the status label is set to "Ready" and the file input is enabled.
 * On failure the error is logged, the status label asks the user to refresh,
 * and both fields of the result are `null`, so callers can destructure the
 * result unconditionally.
 *
 * @returns {Promise<{model: object|null, processor: object|null}>}
 */
async function loadModel() {
  try {
    const model_id = "onnx-community/sam3-tracker-ONNX";

    // The two downloads are independent, so run them concurrently.
    const [model, processor] = await Promise.all([
      Sam3TrackerModel.from_pretrained(model_id, {
        dtype: {
          vision_encoder: "q4",
          prompt_encoder_mask_decoder: "fp32",
        },
        device: "webgpu",
      }),
      AutoProcessor.from_pretrained(model_id),
    ]);

    statusLabel.textContent = "Ready";
    fileUpload.disabled = false;

    return { model, processor };
  } catch (error) {
    console.error("Error loading model:", error);
    statusLabel.textContent = "Error loading model. Please refresh the page.";
    // Returning an object (instead of undefined) keeps the caller's
    // destructuring from throwing on top of the failure already reported.
    return { model: null, processor: null };
  }
}
|
|
|
|
|
// Kick off model loading. The listeners registered earlier in this module are
// already active while this top-level await is pending.
const loadedResources = await loadModel();
const { model, processor } = loadedResources;
|
|
|
|
|
</script> |
|
|
</body> |
|
|
|
|
|
</html> |