Update README.md
Browse files
README.md
CHANGED
@@ -1,10 +1,48 @@
|
|
1 |
-
---
|
2 |
-
tags:
|
3 |
-
-
|
4 |
-
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- curb ramp detection
|
4 |
+
- accessibility
|
5 |
+
license: mit
|
6 |
+
datasets:
|
7 |
+
- projectsidewalk/rampnet-dataset
|
8 |
+
base_model:
|
9 |
+
- timm/convnextv2_base.fcmae_ft_in22k_in1k_384
|
10 |
+
pipeline_tag: object-detection
|
11 |
+
---
|
12 |
+
|
13 |
+
This is the curb ramp detection model introduced in *RampNet*.
|
14 |
+
|
15 |
+
**Example usage:**
|
16 |
+
```py
|
17 |
+
import torch
|
18 |
+
from transformers import AutoModel
|
19 |
+
from PIL import Image
|
20 |
+
import numpy as np
|
21 |
+
from torchvision import transforms
|
22 |
+
from skimage.feature import peak_local_max
|
23 |
+
|
24 |
+
IMAGE_PATH = "example.jpg"
|
25 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
26 |
+
|
27 |
+
model = AutoModel.from_pretrained("projectsidewalk/rampnet-model", trust_remote_code=True).to(DEVICE).eval()
|
28 |
+
|
29 |
+
preprocess = transforms.Compose([
|
30 |
+
transforms.Resize((2048, 4096), interpolation=transforms.InterpolationMode.BILINEAR),
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
33 |
+
])
|
34 |
+
|
35 |
+
img = Image.open(IMAGE_PATH).convert("RGB")
|
36 |
+
img_tensor = preprocess(img).unsqueeze(0).to(DEVICE)
|
37 |
+
|
38 |
+
with torch.no_grad():
|
39 |
+
heatmap = model(img_tensor).squeeze().cpu().numpy()
|
40 |
+
|
41 |
+
peaks = peak_local_max(np.clip(heatmap, 0, 1), min_distance=10, threshold_abs=0.5)
|
42 |
+
scale_w = img.width / heatmap.shape[1]
|
43 |
+
scale_h = img.height / heatmap.shape[0]
|
44 |
+
coordinates = [(int(c * scale_w), int(r * scale_h)) for r, c in peaks]
|
45 |
+
|
46 |
+
# Coordinates of detected curb ramps
|
47 |
+
print(coordinates)
|
48 |
+
```
|