"""
Search-TTA demo
"""
# ────────────────────────── imports ───────────────────────────────────
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
import torchaudio
import spaces  # ZeroGPU integration on Hugging Face Spaces
from torchvision import transforms
import open_clip
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
from transformers import ClapAudioModelWithProjection
from transformers import ClapProcessor
# ────────────────────────── global config & models ────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BioCLIP (ground-image & text encoder)
bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
bio_model = bio_model.to(device).eval()
bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
# Satellite patch encoder (CLIP-L-336, per-patch)
sat_model: CLIPVisionPerPatchModel = (
CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
.to(device)
.eval()
)
# Sound CLAP model
sound_model: ClapAudioModelWithProjection = (
ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
.to(device)
.eval()
)
sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
SAMPLE_RATE = 48000
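# CLIP-style temperature: exp(log(1/0.07)) ≈ 14.29, used to scale similarities before the sigmoid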
logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logit_scale = logit_scale.exp()
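# Heat-maps are smoothed with this Gaussian kernel before rendering; set to (0, 0) to disable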
blur_kernel = (5,5)
# ────────────────────────── transforms (exact spec) ───────────────────
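# Ground-level images: resize to 256, then center-crop to 224 (BioCLIP input size).
# Satellite images: resize to 336 x 336 for the per-patch CLIP-L-336 encoder.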
img_transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
imo_transform = transforms.Compose(
[
transforms.Resize((336, 336)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
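# Load an audio file, downmix to mono, resample to 48 kHz, and run the CLAP
# processor (10 s max length, "repeatpad" padding, "fusion" truncation).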
def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
    track, sr = torchaudio.load(path_to_audio, format=format)
    track = track.mean(axis=0)
    track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
    output = sound_processor(
        audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10,
        return_tensors="pt", padding=padding, truncation=truncation,
    )
    return output
# ────────────────────────── helpers ───────────────────────────────────
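# Each helper below returns an embedding for one query modality (ground image, text,
# satellite patches, sound); query embeddings are compared against the satellite
# patch embeddings to produce the heat-maps.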
@torch.no_grad()
def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
img = img_transform(img_pil).unsqueeze(0).to(device)
img_embeds, *_ = bio_model(img)
return img_embeds
@torch.no_grad()
def _encode_text(text: str) -> torch.Tensor:
toks = bio_tokenizer(text).to(device)
_, txt_embeds, _ = bio_model(text=toks)
return txt_embeds
@torch.no_grad()
def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
imo = imo_transform(img_pil).unsqueeze(0).to(device)
imo_embeds = sat_model(imo)
return imo_embeds
@torch.no_grad()
def _encode_sound(sound) -> torch.Tensor:
processed_sound = get_audio_clap(sound)
for k in processed_sound.keys():
processed_sound[k] = processed_sound[k].to(device)
unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
return sound_embeds
def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
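    """
    Scaled query-patch dot products -> sigmoid -> drop the CLS token ->
    reshape the remaining per-patch scores into a square (side x side) grid.
    """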
sims = torch.matmul(query, patches.t()) * logit_scale
sims = sims.t().sigmoid()
sims = sims[1:].squeeze() # drop CLS token
side = int(np.sqrt(len(sims)))
sims = sims.reshape(side, side)
return sims.cpu().detach().numpy()
def _array_to_pil(arr: np.ndarray) -> Image.Image:
"""
Render arr with viridis, automatically stretching its own min→max to 0→1
so that the most-similar patches appear yellow.
"""
    # Gaussian smoothing
if blur_kernel != (0,0):
arr = cv2.GaussianBlur(arr, blur_kernel, 0)
# --- contrast-stretch to local 0-1 range --------------------------
arr_min, arr_max = float(arr.min()), float(arr.max())
if arr_max - arr_min < 1e-6: # avoid /0 when the heat-map is flat
arr_scaled = np.zeros_like(arr)
else:
arr_scaled = (arr - arr_min) / (arr_max - arr_min)
# ------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
ax.axis("off")
buf = io.BytesIO()
plt.tight_layout(pad=0)
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
plt.close(fig)
buf.seek(0)
return Image.open(buf)
# ────────────────────────── main inference ────────────────────────────
# ZeroGPU integration on Hugging Face Spaces
@spaces.GPU(duration=5)
def process(
sat_img: Image.Image,
taxonomy: str,
ground_img: Image.Image | None,
sound: torch.Tensor | None,
):
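    """
    Encode the satellite image into per-patch embeddings, then build one
    heat-map per provided query (ground image, taxonomy text, sound).
    Returns PIL heat-maps (ground, text, sound); missing queries yield None.
    """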
    if sat_img is None:
        return None, None, None
patches = _encode_sat(sat_img)
heat_ground, heat_text, heat_sound = None, None, None
if ground_img is not None:
q_img = _encode_ground(ground_img)
heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
if taxonomy.strip():
q_txt = _encode_text(taxonomy.strip())
heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
if sound is not None:
q_sound = _encode_sound(sound)
heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
return heat_ground, heat_text, heat_sound
# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
gr.Markdown(
"""
    # Search-TTA Demo: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild
    Click on any of the <b>examples below</b> to run the <b>multimodal inference demo</b>. To try the <b>test-time adaptation feature</b>, switch to the previous tab above. <br>
    If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
<a href="https://search-tta.github.io">Project Website</a>
"""
)
# with gr.Row():
# gr.Markdown(
# """
# <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
# <div>
# <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
# <span></span>
# <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
# <a href="https://search-tta.github.io">Project Website</a>
# </h2>
# <span></span>
# <h2 style='font-weight: 450; font-size: 0.5rem; margin: 0rem'>[Work in Progress]</h2>
# </div>
# </div>
# """
# <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
# <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
# <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
# <a href="https://chinchinati.github.io/">Shailesh</a>,
# <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
# <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
# <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
# <a href="https://weihengdai.top">Weiheng Dai</a>,
# <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
# <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
# <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
# <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
# <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
# </h2>
# <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
# )
with gr.Row(variant="panel"):
# LEFT COLUMN (satellite, taxonomy, run)
with gr.Column():
sat_input = gr.Image(
label="Satellite Image",
sources=["upload"],
type="pil",
height=320,
)
taxonomy_input = gr.Textbox(
label="Full Taxonomy Name (optional)",
placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
)
# ─── NEW: sound input ───────────────────────────
sound_input = gr.Audio(
label="Sound Input (optional)",
sources=["upload"], # or "microphone" / "url" as you prefer
type="filepath", # or "numpy" if you want raw arrays
)
run_btn = gr.Button("Run", variant="primary")
# RIGHT COLUMN (ground image + two heat-maps)
with gr.Column():
ground_input = gr.Image(
label="Ground-level Image (optional)",
sources=["upload"],
type="pil",
height=320,
)
gr.Markdown("### Heat-map Results")
with gr.Row():
# Separate label and image to avoid overlap
with gr.Column(scale=1, min_width=100):
gr.Markdown("**Ground Image Query**", elem_id="label-ground")
heat_ground_out = gr.Image(
show_label=False,
height=160,
# width=160,
)
with gr.Column(scale=1, min_width=100):
gr.Markdown("**Text Query**", elem_id="label-text")
heat_text_out = gr.Image(
show_label=False,
height=160,
# width=160,
)
with gr.Column(scale=1, min_width=100):
gr.Markdown("**Sound Query**", elem_id="label-sound")
heat_sound_out = gr.Image(
show_label=False,
height=160,
# width=160,
)
# ─── NEW: sound output ─────────────────────────
# sound_output = gr.Audio(
# label="Playback",
# )
# EXAMPLES
with gr.Row():
gr.Markdown("### In-Domain Taxonomy")
with gr.Row():
gr.Examples(
examples=[
[
"examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
"examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
"Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
"examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
],
[
"examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
"examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
"Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
"examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
],
[
"examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
"examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
"Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
"examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
],
[
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
"Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
"examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
],
[
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
"examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
"Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
None
],
],
inputs=[sat_input, ground_input, taxonomy_input, sound_input],
outputs=[heat_ground_out, heat_text_out, heat_sound_out],
fn=process,
cache_examples=False,
)
# EXAMPLES
with gr.Row():
gr.Markdown("### Out-Domain Taxonomy")
with gr.Row():
gr.Examples(
examples=[
[
"examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
"examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
"Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
"examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
],
[
"examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
"examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
"Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
"examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
],
[
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
"examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
"Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
None
],
[
"examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
"examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
"Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
None
],
],
inputs=[sat_input, ground_input, taxonomy_input, sound_input],
outputs=[heat_ground_out, heat_text_out, heat_sound_out],
fn=process,
cache_examples=False,
)
# CALLBACK
run_btn.click(
fn=process,
inputs=[sat_input, taxonomy_input, ground_input, sound_input],
outputs=[heat_ground_out, heat_text_out, heat_sound_out],
)
# Footer crediting the models and data used in this demo.
gr.Markdown(
"""
    The satellite image CLIP encoder is fine-tuned on [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images paired with taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned on a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [TaxaBind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above produce poor probability distributions, these are improved by our test-time adaptation framework during the search process.
"""
)
# LAUNCH
if __name__ == "__main__":
demo.queue(max_size=15)
demo.launch(share=True)