Spaces · Running on Zero

derektan committed
Commit · dd3c1c5
Parent(s): 56e7382

First commit. Using Git LFS for binaries
Files changed:
- .gitattributes  +3 -0
- .gitignore  +1 -0
- app.py  +271 -0
- clip_vision_per_patch_model.py  +26 -0
- examples/NAIP_yosemite_v3_resized.png  +3 -0
- examples/american_black_bear_inat_248820933.jpeg  +3 -0
- requirements.txt  +11 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
**/__pycache__/
app.py
ADDED
@@ -0,0 +1,271 @@
"""
EcoMonitor • multimodal heat-map demo (with custom preprocessing)
"""

# ────────────────────────── imports ───────────────────────────────────
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io

from torchvision import transforms
import open_clip
from clip_vision_per_patch_model import CLIPVisionPerPatchModel

# ────────────────────────── global config & models ────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1️⃣ BioCLIP (ground-image & text encoder)
bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
bio_model = bio_model.to(device).eval()
bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")

# 2️⃣ Satellite patch encoder (CLIP-L-336 per-patch)
sat_model: CLIPVisionPerPatchModel = (
    CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
    .to(device)
    .eval()
)

logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logit_scale = logit_scale.exp()
blur_kernel = (5, 5)

# ────────────────────────── transforms (exact spec) ───────────────────
img_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

imo_transform = transforms.Compose(
    [
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ────────────────────────── helpers ───────────────────────────────────
# def _tensor_ground(img_pil: Image.Image) -> torch.Tensor:
#     return img_transform(img_pil).unsqueeze(0).to(device)


# def _tensor_sat(img_pil: Image.Image) -> torch.Tensor:
#     return imo_transform(img_pil).unsqueeze(0).to(device)


@torch.no_grad()
def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
    img = img_transform(img_pil).unsqueeze(0).to(device)
    img_embeds, *_ = bio_model(img)
    return img_embeds
    # feats = bio_model.encode_image(_tensor_ground(img_pil))
    # return torch.nn.functional.normalize(feats, dim=-1)


@torch.no_grad()
def _encode_text(text: str) -> torch.Tensor:
    toks = bio_tokenizer(text).to(device)
    _, txt_embeds, _ = bio_model(text=toks)
    return txt_embeds
    # return torch.nn.functional.normalize(feats, dim=-1)


@torch.no_grad()
def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
    imo = imo_transform(img_pil).unsqueeze(0).to(device)
    imo_embeds = sat_model(imo)
    return imo_embeds
    # out = sat_model(_tensor_sat(img_pil))
    # if hasattr(out, "last_hidden_state"):
    #     out = out.last_hidden_state
    # return torch.nn.functional.normalize(out.squeeze(0), dim=-1)  # (P, D)
    # return out


def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
    sims = torch.matmul(query, patches.t()) * logit_scale
    sims = sims.t().sigmoid()
    # sims = torch.sigmoid(patches @ query.squeeze(0))  # (P,)
    sims = sims[1:].squeeze()  # drop CLS token
    side = int(np.sqrt(len(sims)))
    sims = sims.reshape(side, side)
    return sims.cpu().detach().numpy()
    # return sims[: side * side].view(side, side).cpu().numpy()


def _array_to_pil(arr: np.ndarray) -> Image.Image:
    """
    Render arr with viridis, automatically stretching its own min→max to 0→1
    so that the most-similar patches appear yellow.
    """

    # Gaussian smoothing
    if blur_kernel != (0, 0):
        arr = cv2.GaussianBlur(arr, blur_kernel, 0)

    # --- contrast-stretch to local 0-1 range --------------------------
    arr_min, arr_max = float(arr.min()), float(arr.max())
    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
        arr_scaled = np.zeros_like(arr)
    else:
        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.axis("off")
    buf = io.BytesIO()
    plt.tight_layout(pad=0)
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


# ────────────────────────── main inference ────────────────────────────
def process(
    sat_img: Image.Image,
    taxonomy: str,
    ground_img: Image.Image | None,
):
    if sat_img is None:
        return None, None

    patches = _encode_sat(sat_img)

    heat_ground, heat_text = None, None

    if ground_img is not None:
        q_img = _encode_ground(ground_img)
        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))

    if taxonomy.strip():
        q_txt = _encode_text(taxonomy.strip())
        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))

    return heat_ground, heat_text


# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="EcoMonitor", theme=gr.themes.Base()) as demo:

    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
              <div>
                <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
                <span></span>
                <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                  <a href="https://search-tta.github.io">Project Website</a>
                </h2>
              </div>
            </div>
            """
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            # <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
            # <a href="https://chinchinati.github.io/">Shailesh</a>,
            # <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
            # <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
            # <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
            # <a href="https://weihengdai.top">Weiheng Dai</a>,
            # <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
            # <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
            # <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
            # <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
            # <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
            # </h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
        )

    with gr.Row(variant="panel"):

        # LEFT COLUMN (satellite, taxonomy, run)
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Carnivora Ursidae Ursus arctos",
            )
            run_btn = gr.Button("Run", variant="primary")

        # RIGHT COLUMN (ground image + two heat-maps)
        with gr.Column():
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            heat_ground_out = gr.Image(
                label="Heat-map (Ground query)",
                height=160,
            )
            heat_text_out = gr.Image(
                label="Heat-map (Text query)",
                height=160,
            )

    # EXAMPLES
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/NAIP_yosemite_v3_resized.png",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                    "examples/american_black_bear_inat_248820933.jpeg",
                ],
                # [
                #     "examples/satellite_coast.png",
                #     "",
                #     "examples/ground_gull.jpg",
                # ],
                # [
                #     "examples/satellite_coast.png",
                #     "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
                #     None,
                # ],
            ],
            inputs=[sat_input, taxonomy_input, ground_input],
            outputs=[heat_ground_out, heat_text_out],
            fn=process,
            cache_examples=False,
        )

    # CALLBACK
    run_btn.click(
        fn=process,
        inputs=[sat_input, taxonomy_input, ground_input],
        outputs=[heat_ground_out, heat_text_out],
    )

    # Footer pointing to the model and data sources used by this Space.
    gr.Markdown(
        """
        This model is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images and locations from [iNaturalist](https://inaturalist.org/).
        """
    )

# LAUNCH
if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch(share=True)
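The heat-map construction in _similarity_heatmap above relies on some implicit shape bookkeeping: the satellite encoder returns one embedding per ViT patch plus a leading CLS token, and the function assumes the remaining patch count is a perfect square. Below is a minimal sketch of that arithmetic using dummy tensors only (no model weights are loaded); the 24×24 grid follows from the CLIP-L-336 geometry named in the code (336-pixel input, 14-pixel patches), while proj_dim = 512 is a placeholder width chosen for illustration, not the real projection dimension.

import numpy as np
import torch

# Dummy stand-ins for the real embeddings (assumed shapes, random values).
proj_dim = 512                       # placeholder projection width
side = 336 // 14                     # CLIP-L-336: 24 x 24 patch grid
n_tokens = side * side + 1           # 576 patches + 1 CLS token = 577

query = torch.nn.functional.normalize(torch.randn(1, proj_dim), dim=-1)           # text or ground-image embedding
patches = torch.nn.functional.normalize(torch.randn(n_tokens, proj_dim), dim=-1)  # per-patch satellite embeddings

logit_scale = torch.exp(torch.ones([]) * np.log(1 / 0.07))  # same initialization as app.py

# Mirror _similarity_heatmap: scaled similarities -> sigmoid -> drop CLS -> square grid.
sims = (query @ patches.t() * logit_scale).t().sigmoid()  # (577, 1)
sims = sims[1:].squeeze()                                 # (576,) after dropping the CLS token
grid = sims.reshape(int(np.sqrt(len(sims))), -1)          # (24, 24)
print(grid.shape)                                         # torch.Size([24, 24])

Nothing here touches the real checkpoints; it only illustrates why side = int(np.sqrt(len(sims))) recovers a square heat-map for this patch layout.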
clip_vision_per_patch_model.py
ADDED
@@ -0,0 +1,26 @@
import torch
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig

class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection):
    """
    Like CLIPVisionModelWithProjection but returns
    per-patch embeddings instead of pooled CLS tokens.
    """
    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        # everything else (self.vision_model, self.visual_projection)
        # is set up for you by the parent class

    def forward(self, pixel_values, **kwargs):
        # 1) run the ViT backbone → last_hidden_state [B, n_patches, hidden_size]
        outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
        hidden_states = outputs.last_hidden_state

        # 2) project every patch token → [B, n_patches, projection_dim]
        patch_embeds = self.visual_projection(hidden_states)

        # 3) Postprocessing embeds
        patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1)
        patch_embeds = patch_embeds.squeeze()  # (Patches, proj_dim)

        return patch_embeds
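As a quick sanity check of the class above, the sketch below instantiates it from a randomly initialized CLIPVisionConfig rather than the derektan95/search-tta checkpoint, so it runs offline. image_size=336 and patch_size=14 are assumptions chosen to match the CLIP-L-336 geometry referenced in app.py; the hidden and projection widths are left at the config defaults, so the output width shown is illustrative only.

import torch
from transformers import CLIPVisionConfig
from clip_vision_per_patch_model import CLIPVisionPerPatchModel

# Randomly initialized config (offline sketch) matching the 336/14 patch layout.
config = CLIPVisionConfig(image_size=336, patch_size=14)
model = CLIPVisionPerPatchModel(config).eval()

pixel_values = torch.randn(1, 3, 336, 336)   # one fake preprocessed satellite image
with torch.no_grad():
    patch_embeds = model(pixel_values)

# 1 CLS token + (336 / 14)**2 patches = 577 rows, one L2-normalized embedding per token.
print(patch_embeds.shape)             # e.g. torch.Size([577, 512]) with default widths
print(patch_embeds.norm(dim=-1)[:3])  # ≈ 1.0 for each row

The squeeze() at the end of forward drops the batch dimension for batch size 1, which is why app.py can index the CLS token directly with sims[1:].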
examples/NAIP_yosemite_v3_resized.png
ADDED
(binary image, stored via Git LFS)

examples/american_black_bear_inat_248820933.jpeg
ADDED
(binary image, stored via Git LFS)
requirements.txt
ADDED
@@ -0,0 +1,11 @@
# python 3.10.14

numpy==1.26.3
torch==2.4.1
torchvision==0.19.1
pytorch-lightning==2.2.1
open_clip_torch==2.30.0
transformers==4.45.1
tokenizers==0.20.3
opencv-python==4.10.0.84
gradio==3.39.0