derektan committed
Commit dd3c1c5 · 1 Parent(s): 56e7382

First commit. Using Git LFS for binaries

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
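The three new patterns are what `git lfs track` normally writes. As a minimal sketch (not part of this commit, and assuming git-lfs is installed and initialised with `git lfs install` in the repo), the same lines could be generated programmatically:

# Sketch: reproduce the LFS patterns added in this commit via the git-lfs CLI.
# Each `git lfs track <pattern>` call appends a "filter=lfs diff=lfs merge=lfs -text"
# line for that pattern to .gitattributes.
import subprocess

for pattern in ("*.png", "*.jpg", "*.jpeg"):
    subprocess.run(["git", "lfs", "track", pattern], check=True)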
.gitignore ADDED
@@ -0,0 +1 @@
+ **/__pycache__/
app.py ADDED
@@ -0,0 +1,271 @@
+ """
+ EcoMonitor • multimodal heat-map demo (with custom preprocessing)
+ """
+
+ # ────────────────────────── imports ───────────────────────────────────
+ import cv2
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import io
+
+ from torchvision import transforms
+ import open_clip
+ from clip_vision_per_patch_model import CLIPVisionPerPatchModel
+
+ # ────────────────────────── global config & models ────────────────────
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # 1️⃣ BioCLIP (ground-image & text encoder)
+ bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
+ bio_model = bio_model.to(device).eval()
+ bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+
+ # 2️⃣ Satellite patch encoder (CLIP-L-336 per-patch)
+ sat_model: CLIPVisionPerPatchModel = (
+     CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
+     .to(device)
+     .eval()
+ )
+
+ logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ logit_scale = logit_scale.exp()
+ blur_kernel = (5, 5)
+
+ # ────────────────────────── transforms (exact spec) ───────────────────
+ img_transform = transforms.Compose(
+     [
+         transforms.Resize((256, 256)),
+         transforms.CenterCrop((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ),
+     ]
+ )
+
+ imo_transform = transforms.Compose(
+     [
+         transforms.Resize((336, 336)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ),
+     ]
+ )
+
+ # ────────────────────────── helpers ───────────────────────────────────
+ # def _tensor_ground(img_pil: Image.Image) -> torch.Tensor:
+ #     return img_transform(img_pil).unsqueeze(0).to(device)
+
+
+ # def _tensor_sat(img_pil: Image.Image) -> torch.Tensor:
+ #     return imo_transform(img_pil).unsqueeze(0).to(device)
+
+
+ @torch.no_grad()
+ def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
+     img = img_transform(img_pil).unsqueeze(0).to(device)
+     img_embeds, *_ = bio_model(img)
+     return img_embeds
+     # feats = bio_model.encode_image(_tensor_ground(img_pil))
+     # return torch.nn.functional.normalize(feats, dim=-1)
+
+
+ @torch.no_grad()
+ def _encode_text(text: str) -> torch.Tensor:
+     toks = bio_tokenizer(text).to(device)
+     _, txt_embeds, _ = bio_model(text=toks)
+     return txt_embeds
+     # return torch.nn.functional.normalize(feats, dim=-1)
+
+
+ @torch.no_grad()
+ def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
+     imo = imo_transform(img_pil).unsqueeze(0).to(device)
+     imo_embeds = sat_model(imo)
+     return imo_embeds
+     # out = sat_model(_tensor_sat(img_pil))
+     # if hasattr(out, "last_hidden_state"):
+     #     out = out.last_hidden_state
+     # return torch.nn.functional.normalize(out.squeeze(0), dim=-1)  # (P, D)
+     # return out
+
+
+ def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
+     sims = torch.matmul(query, patches.t()) * logit_scale
+     sims = sims.t().sigmoid()
+     # sims = torch.sigmoid(patches @ query.squeeze(0))  # (P,)
+     sims = sims[1:].squeeze()  # drop CLS token
+     side = int(np.sqrt(len(sims)))
+     sims = sims.reshape(side, side)
+     return sims.cpu().detach().numpy()
+     # return sims[: side * side].view(side, side).cpu().numpy()
+
+
+ def _array_to_pil(arr: np.ndarray) -> Image.Image:
+     """
+     Render arr with viridis, automatically stretching its own min→max to 0→1
+     so that the most-similar patches appear yellow.
+     """
+
+     # Gaussian smoothing
+     if blur_kernel != (0, 0):
+         arr = cv2.GaussianBlur(arr, blur_kernel, 0)
+
+     # --- contrast-stretch to local 0-1 range --------------------------
+     arr_min, arr_max = float(arr.min()), float(arr.max())
+     if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
+         arr_scaled = np.zeros_like(arr)
+     else:
+         arr_scaled = (arr - arr_min) / (arr_max - arr_min)
+     # ------------------------------------------------------------------
+     fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
+     ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
+     ax.axis("off")
+     buf = io.BytesIO()
+     plt.tight_layout(pad=0)
+     fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+     plt.close(fig)
+     buf.seek(0)
+     return Image.open(buf)
+
+ # ────────────────────────── main inference ────────────────────────────
+ def process(
+     sat_img: Image.Image,
+     taxonomy: str,
+     ground_img: Image.Image | None,
+ ):
+     if sat_img is None:
+         return None, None
+
+     patches = _encode_sat(sat_img)
+
+     heat_ground, heat_text = None, None
+
+     if ground_img is not None:
+         q_img = _encode_ground(ground_img)
+         heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
+
+     if taxonomy.strip():
+         q_txt = _encode_text(taxonomy.strip())
+         heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
+
+     return heat_ground, heat_text
+
+
+ # ────────────────────────── Gradio UI ─────────────────────────────────
+ with gr.Blocks(title="EcoMonitor", theme=gr.themes.Base()) as demo:
+
+     with gr.Row():
+         gr.Markdown(
+             """
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+             <div>
+             <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
+             <span></span>
+             <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
+             <a href="https://search-tta.github.io">Project Website</a>
+             </h2>
+             </div>
+             </div>
+             """
+             # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
+
+             # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
+             # <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
+             # <a href="https://chinchinati.github.io/">Shailesh</a>,
+             # <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
+             # <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
+             # <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
+             # <a href="https://weihengdai.top">Weiheng Dai</a>,
+             # <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
+             # <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
+             # <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
+             # <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
+             # <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
+             # </h2>
+             # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
+         )
+
+     with gr.Row(variant="panel"):
+
+         # LEFT COLUMN (satellite, taxonomy, run)
+         with gr.Column():
+             sat_input = gr.Image(
+                 label="Satellite Image",
+                 sources=["upload"],
+                 type="pil",
+                 height=320,
+             )
+             taxonomy_input = gr.Textbox(
+                 label="Full Taxonomy Name (optional)",
+                 placeholder="e.g. Animalia Chordata Mammalia Carnivora Ursidae Ursus arctos",
+             )
+             run_btn = gr.Button("Run", variant="primary")
+
+         # RIGHT COLUMN (ground image + two heat-maps)
+         with gr.Column():
+             ground_input = gr.Image(
+                 label="Ground-level Image (optional)",
+                 sources=["upload"],
+                 type="pil",
+                 height=320,
+             )
+             heat_ground_out = gr.Image(
+                 label="Heat-map (Ground query)",
+                 height=160,
+             )
+             heat_text_out = gr.Image(
+                 label="Heat-map (Text query)",
+                 height=160,
+             )
+
+     # EXAMPLES
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     "examples/NAIP_yosemite_v3_resized.png",
+                     "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
+                     "examples/american_black_bear_inat_248820933.jpeg",
+                 ],
+                 # [
+                 #     "examples/satellite_coast.png",
+                 #     "",
+                 #     "examples/ground_gull.jpg",
+                 # ],
+                 # [
+                 #     "examples/satellite_coast.png",
+                 #     "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
+                 #     None,
+                 # ],
+             ],
+             inputs=[sat_input, taxonomy_input, ground_input],
+             outputs=[heat_ground_out, heat_text_out],
+             fn=process,
+             cache_examples=False,
+         )
+
+     # CALLBACK
+     run_btn.click(
+         fn=process,
+         inputs=[sat_input, taxonomy_input, ground_input],
+         outputs=[heat_ground_out, heat_text_out],
+     )
+
+     # Footer pointing to the model and data sources used for fine-tuning.
+     gr.Markdown(
+         """
+         This model is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images and locations from [iNaturalist](https://inaturalist.org/).
+         """
+     )
+
+ # LAUNCH
+ if __name__ == "__main__":
+     demo.queue(max_size=15)
+     demo.launch(share=True)
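For a quick sanity check outside the Gradio UI, `process` can be called directly. The snippet below is a hypothetical smoke test (not part of this commit); it assumes the example assets added in this commit are present and that both encoders can be downloaded on first run.

# Hypothetical smoke test for app.process (not part of this commit).
from PIL import Image

import app  # builds the Blocks UI on import; launch() only runs under __main__

sat = Image.open("examples/NAIP_yosemite_v3_resized.png").convert("RGB")
ground = Image.open("examples/american_black_bear_inat_248820933.jpeg").convert("RGB")
taxonomy = "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus"

heat_ground, heat_text = app.process(sat, taxonomy, ground)
heat_ground.save("heatmap_ground.png")
heat_text.save("heatmap_text.png")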
clip_vision_per_patch_model.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig
+
+ class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection):
+     """
+     Like CLIPVisionModelWithProjection but returns
+     per-patch embeddings instead of pooled CLS tokens.
+     """
+     def __init__(self, config: CLIPVisionConfig):
+         super().__init__(config)
+         # everything else (self.vision_model, self.visual_projection)
+         # is set up for you by the parent class
+
+     def forward(self, pixel_values, **kwargs):
+         # 1) run the ViT backbone → last_hidden_state [B, 1 + n_patches, hidden_size] (CLS + patch tokens)
+         outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
+         hidden_states = outputs.last_hidden_state
+
+         # 2) project every token (CLS included) → [B, 1 + n_patches, projection_dim]
+         patch_embeds = self.visual_projection(hidden_states)
+
+         # 3) post-process: L2-normalize and drop the batch dim
+         patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1)
+         patch_embeds = patch_embeds.squeeze()  # (Patches, proj_dim)
+
+         return patch_embeds
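For context on the shapes: assuming the `derektan95/search-tta` checkpoint is a standard CLIP ViT-L/14 vision tower at 336 px (which is what the 336×336 `imo_transform` in app.py suggests), the forward pass returns one CLS token plus a 24×24 grid of patch embeddings; this is why `_similarity_heatmap` in app.py drops index 0 and reshapes the remaining 576 scores into a square. A hedged shape check:

# Hypothetical shape check (assumes a CLIP ViT-L/14-336 backbone:
# patch size 14, 336 / 14 = 24 patches per side, projection_dim = 768).
import torch
from clip_vision_per_patch_model import CLIPVisionPerPatchModel

model = CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta").eval()
with torch.no_grad():
    embeds = model(torch.randn(1, 3, 336, 336))
print(embeds.shape)  # expected: torch.Size([577, 768]) = [CLS] + 24 * 24 patch tokens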
examples/NAIP_yosemite_v3_resized.png ADDED

Git LFS Details

  • SHA256: 58ffc2032198596f427e5c1bb176bd694f427a810d0302b49260c81aea69fe87
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
examples/american_black_bear_inat_248820933.jpeg ADDED

Git LFS Details

  • SHA256: d65b3c63c5c6cd98d8a324cd46a9b1216eacd1ce245a357f04409083cd28d944
  • Pointer size: 131 Bytes
  • Size of remote file: 266 kB
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # python 3.10.14
+
+ numpy==1.26.3
+ torch==2.4.1
+ torchvision==0.19.1
+ pytorch-lightning==2.2.1
+ open_clip_torch==2.30.0
+ transformers==4.45.1
+ tokenizers==0.20.3
+ opencv-python==4.10.0.84
+ gradio==3.39.0