application3
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +72 -0
- LICENSE.txt +12 -0
- README.md +232 -8
- configs/gs/base.yaml +51 -0
- configs/train.yaml +38 -0
- extra/archive/rasterizer_impl.h +75 -0
- extra/archive/simple-knn.patch.txt +13 -0
- full_eval.py +102 -0
- gradio_demo.py +424 -0
- install.sh +35 -0
- main.py +268 -0
- metrics.py +115 -0
- requirements.txt +1 -0
- script.bash +0 -0
- source/EDGS.code-workspace +11 -0
- source/__init__.py +0 -0
- source/corr_init.py +907 -0
- source/data_utils.py +28 -0
- source/losses.py +100 -0
- source/networks.py +48 -0
- source/timer.py +24 -0
- source/trainer.py +265 -0
- source/utils_aux.py +92 -0
- source/utils_preprocess.py +334 -0
- source/visualization.py +1072 -0
- submodules/RoMa/.gitignore +11 -0
- submodules/RoMa/LICENSE +21 -0
- submodules/RoMa/README.md +123 -0
- submodules/RoMa/data/.gitignore +2 -0
- submodules/RoMa/demo/demo_3D_effect.py +47 -0
- submodules/RoMa/demo/demo_fundamental.py +34 -0
- submodules/RoMa/demo/demo_match.py +50 -0
- submodules/RoMa/demo/demo_match_opencv_sift.py +43 -0
- submodules/RoMa/demo/demo_match_tiny.py +77 -0
- submodules/RoMa/demo/gif/.gitignore +2 -0
- submodules/RoMa/experiments/eval_roma_outdoor.py +57 -0
- submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
- submodules/RoMa/experiments/roma_indoor.py +320 -0
- submodules/RoMa/experiments/train_roma_outdoor.py +307 -0
- submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
- submodules/RoMa/requirements.txt +14 -0
- submodules/RoMa/romatch/__init__.py +8 -0
- submodules/RoMa/romatch/benchmarks/__init__.py +6 -0
- submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
- submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
- submodules/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
- submodules/RoMa/romatch/checkpointing/__init__.py +1 -0
- submodules/RoMa/romatch/checkpointing/checkpoint.py +60 -0
Dockerfile
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 AS builder
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
COPY . /app/
|
6 |
+
|
7 |
+
ENV CUDA_HOME=/usr/local/cuda-12.1
|
8 |
+
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
9 |
+
ENV PATH=$CUDA_HOME/bin:$PATH
|
10 |
+
|
11 |
+
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
12 |
+
|
13 |
+
RUN apt-get update && \
|
14 |
+
DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
15 |
+
build-essential wget curl nano ninja-build unzip libgl-dev ffmpeg && \
|
16 |
+
apt-get clean && \
|
17 |
+
rm -rf /var/lib/apt/lists/*
|
18 |
+
|
19 |
+
ENV CONDA_DIR=/opt/conda
|
20 |
+
|
21 |
+
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
|
22 |
+
/bin/bash ~/miniconda.sh -b -p /opt/conda && \
|
23 |
+
rm ~/miniconda.sh
|
24 |
+
|
25 |
+
ENV PATH=$CONDA_DIR/bin:$PATH
|
26 |
+
|
27 |
+
RUN conda update -n base conda -y && \
|
28 |
+
conda install -n base conda-libmamba-solver -y && \
|
29 |
+
conda config --set solver libmamba
|
30 |
+
|
31 |
+
RUN conda create -y -n edgs python=3.10 pip
|
32 |
+
|
33 |
+
SHELL ["conda", "run", "-n", "edgs", "/bin/bash", "-c"]
|
34 |
+
|
35 |
+
|
36 |
+
RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
37 |
+
|
38 |
+
RUN pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization --no-build-isolation && \
|
39 |
+
pip install -e submodules/gaussian-splatting/submodules/simple-knn --no-build-isolation
|
40 |
+
|
41 |
+
RUN pip install pycolmap
|
42 |
+
|
43 |
+
RUN pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg && \
|
44 |
+
conda install numpy=1.26.4 -y -c conda-forge --override-channels
|
45 |
+
|
46 |
+
RUN pip install -e submodules/RoMa
|
47 |
+
|
48 |
+
RUN pip install plotly scikit-learn moviepy==2.1.1 ffmpeg fastapi[standard]
|
49 |
+
|
50 |
+
# Imagen final
|
51 |
+
|
52 |
+
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04 AS final
|
53 |
+
|
54 |
+
WORKDIR /app
|
55 |
+
|
56 |
+
RUN apt-get update && \
|
57 |
+
DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
58 |
+
libgl1-mesa-glx libsm6 libxext6 ffmpeg && \
|
59 |
+
apt-get clean && \
|
60 |
+
rm -rf /var/lib/apt/lists/*
|
61 |
+
|
62 |
+
COPY --from=builder /opt/conda /opt/conda
|
63 |
+
|
64 |
+
COPY --from=builder /app /app
|
65 |
+
|
66 |
+
ENV PATH="/opt/conda/bin:/opt/conda/envs/edgs/bin:$PATH"
|
67 |
+
|
68 |
+
ENV CUDA_HOME=/usr/local/cuda-12.1
|
69 |
+
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
70 |
+
|
71 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
72 |
+
|
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2025, Dmytro Kotovenko, Olga Grebenkova, Björn Ommer
|
2 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted for non-commercial academic research and/or non-commercial personal use only provided that the following conditions are met:
|
3 |
+
|
4 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
5 |
+
|
6 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
7 |
+
|
8 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
9 |
+
|
10 |
+
Any use of this software beyond the above specified conditions requires a separate license. Please contact the copyright holders to discuss license terms.
|
11 |
+
|
12 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
CHANGED
@@ -1,10 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
title: Gs Final
|
3 |
-
emoji: 🐢
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: green
|
6 |
-
sdk: docker
|
7 |
-
pinned: false
|
8 |
-
---
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">EDGS: Eliminating Densification for Efficient Convergence of 3DGS</h2>
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<a href="https://www.linkedin.com/in/dmitry-kotovenko-dl/">Dmytro Kotovenko</a><sup>*</sup> ·
|
5 |
+
<a href="https://www.linkedin.com/in/grebenkovao/">Olga Grebenkova</a><sup>*</sup> ·
|
6 |
+
<a href="https://ommer-lab.com/people/ommer/">Björn Ommer</a>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
<p align="center">CompVis @ LMU Munich · Munich Center for Machine Learning (MCML) </p>
|
10 |
+
<p align="center">* equal contribution </p>
|
11 |
+
|
12 |
+
<p align="center">
|
13 |
+
<a href="https://compvis.github.io/EDGS/"><img src="https://img.shields.io/badge/Project-Page-blue" alt="Project Page"></a>
|
14 |
+
<a href="https://arxiv.org/pdf/2504.13204"><img src="https://img.shields.io/badge/arXiv-PDF-b31b1b" alt="Paper"></a>
|
15 |
+
<a href="https://colab.research.google.com/github/CompVis/EDGS/blob/main/notebooks/fit_model_to_scene_full.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
16 |
+
<a href="https://huggingface.co/spaces/CompVis/EDGS"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue" alt="Hugging Face"></a>
|
17 |
+
|
18 |
+
</p>
|
19 |
+
|
20 |
+
<p align="center">
|
21 |
+
<img src="./assets/Teaser2.png" width="99%">
|
22 |
+
</p>
|
23 |
+
|
24 |
+
<p>
|
25 |
+
<strong>3DGS</strong> initializes with a sparse set of Gaussians and progressively adds more in under-reconstructed regions. In contrast, <strong>EDGS</strong> starts with
|
26 |
+
a dense initialization from triangulated 2D correspondences across training image pairs,
|
27 |
+
requiring only minimal refinement. This leads to <strong>faster convergence</strong> and <strong>higher rendering quality</strong>. Our method reaches the original 3DGS <strong>LPIPS score in just 25% of the training time</strong> and uses only <strong>60% of the splats</strong>.
|
28 |
+
Renderings become <strong>nearly indistinguishable from ground truth after only 3,000 steps — without any densification</strong>.
|
29 |
+
</p>
|
30 |
+
|
31 |
+
<h3 align="center">3D scene reconstruction using our method in 11 seconds.</h3>
|
32 |
+
<p align="center">
|
33 |
+
<img src="assets/video_fruits_our_optimization.gif" width="480" alt="3D Reconstruction Demo">
|
34 |
+
</p>
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
## 📚 Table of Contents
|
39 |
+
- [🚀 Quickstart](#sec-quickstart)
|
40 |
+
- [🛠️ Installation](#sec-install)
|
41 |
+
- [📦 Data](#sec-data)
|
42 |
+
|
43 |
+
- [🏋️ Training](#sec-training)
|
44 |
+
- [🏗️ Reusing Our Model](#sec-reuse)
|
45 |
+
- [📄 Citation](#sec-citation)
|
46 |
+
|
47 |
+
<a id="sec-quickstart"></a>
|
48 |
+
## 🚀 Quickstart
|
49 |
+
The fastest way to try our model is through the [Hugging Face demo](https://huggingface.co/spaces/magistrkoljan/EDGS), which lets you upload images or a video and interactively rotate the resulting 3D scene. For broad accessibility, we currently support only **forward-facing scenes**.
|
50 |
+
#### Steps:
|
51 |
+
1. Upload a list of photos or a single video.
|
52 |
+
2. Click **📸 Preprocess Input** to estimate 3D positions using COLMAP.
|
53 |
+
3. Click **🚀 Start Reconstruction** to run the model.
|
54 |
+
|
55 |
+
You can also **explore the reconstructed scene in 3D** directly in the browser.
|
56 |
+
|
57 |
+
> ⚡ Runtime: EDGS typically takes just **10–20 seconds**, plus **5–10 seconds** for COLMAP processing. Additional time may be needed to save outputs (model, video, 3D preview).
|
58 |
+
|
59 |
+
You can also run the same app locally on your machine with command:
|
60 |
+
```CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --port 7862 --no_share```
|
61 |
+
Without `--no_share` flag you will get the adress for gradio app that you can share with the others allowing others to process their data on your server.
|
62 |
+
|
63 |
+
Alternatively, check our [Colab notebook](https://colab.research.google.com/github/CompVis/EDGS/blob/main/notebooks/fit_model_to_scene_full.ipynb).
|
64 |
+
|
65 |
+
###
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
<a id="sec-install"></a>
|
70 |
+
## 🛠️ Installation
|
71 |
+
|
72 |
+
You can either run `install.sh` or manually install using the following:
|
73 |
+
|
74 |
+
```bash
|
75 |
+
git clone [email protected]:CompVis/EDGS.git --recursive
|
76 |
+
cd EDGS
|
77 |
+
git submodule update --init --recursive
|
78 |
+
|
79 |
+
conda create -y -n edgs python=3.10 pip
|
80 |
+
conda activate edgs
|
81 |
+
|
82 |
+
# Set up path to your CUDA. In our experience similar versions like 12.2 also work well
|
83 |
+
export CUDA_HOME=/usr/local/cuda-12.1
|
84 |
+
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
85 |
+
export PATH=$CUDA_HOME/bin:$PATH
|
86 |
+
|
87 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
88 |
+
conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y
|
89 |
+
|
90 |
+
pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization
|
91 |
+
pip install -e submodules/gaussian-splatting/submodules/simple-knn
|
92 |
+
|
93 |
+
# For COLMAP and pycolmap
|
94 |
+
# Optionally install original colmap but probably pycolmap suffices
|
95 |
+
# conda install conda-forge/label/colmap_dev::colmap
|
96 |
+
pip install pycolmap
|
97 |
+
|
98 |
+
|
99 |
+
pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg
|
100 |
+
conda install numpy=1.26.4 -y -c conda-forge --override-channels
|
101 |
+
|
102 |
+
pip install -e submodules/RoMa
|
103 |
+
conda install anaconda::jupyter --yes
|
104 |
+
|
105 |
+
# Stuff necessary for gradio and visualizations
|
106 |
+
pip install gradio
|
107 |
+
pip install plotly scikit-learn moviepy==2.1.1 ffmpeg
|
108 |
+
pip install open3d
|
109 |
+
```
|
110 |
+
|
111 |
+
<a id="sec-data"></a>
|
112 |
+
## 📦 Data
|
113 |
+
|
114 |
+
We evaluated on the following datasets:
|
115 |
+
|
116 |
+
- **MipNeRF360** — download [here](https://jonbarron.info/mipnerf360/). Unzip "Dataset Pt. 1" and "Dataset Pt. 2", then merge scenes.
|
117 |
+
- **Tanks & Temples + Deep Blending** — from the [original 3DGS repo](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip).
|
118 |
+
|
119 |
+
### Using Your Own Dataset
|
120 |
+
|
121 |
+
You can use the same data format as the [3DGS project](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes). Please follow their guide to prepare your scene.
|
122 |
+
|
123 |
+
Expected folder structure:
|
124 |
+
```
|
125 |
+
scene_folder
|
126 |
+
|---images
|
127 |
+
| |---<image 0>
|
128 |
+
| |---<image 1>
|
129 |
+
| |---...
|
130 |
+
|---sparse
|
131 |
+
|---0
|
132 |
+
|---cameras.bin
|
133 |
+
|---images.bin
|
134 |
+
|---points3D.bin
|
135 |
+
```
|
136 |
+
|
137 |
+
Nerf synthetic format is also acceptable.
|
138 |
+
|
139 |
+
You can also use functions provided in our code to convert a collection of images or a sinlge video into a desired format. However, this may requre tweaking and processing time can be large for large collection of images with little overlap.
|
140 |
+
|
141 |
+
<a id="sec-training"></a>
|
142 |
+
## 🏋️ Training
|
143 |
+
|
144 |
+
|
145 |
+
To optimize on a single scene in COLMAP format use this code.
|
146 |
+
```bash
|
147 |
+
python train.py \
|
148 |
+
train.gs_epochs=30000 \
|
149 |
+
train.no_densify=True \
|
150 |
+
gs.dataset.source_path=<scene folder> \
|
151 |
+
gs.dataset.model_path=<output folder> \
|
152 |
+
init_wC.matches_per_ref=20000 \
|
153 |
+
init_wC.nns_per_ref=3 \
|
154 |
+
init_wC.num_refs=180
|
155 |
+
```
|
156 |
+
<details>
|
157 |
+
<summary><span style="font-weight: bold;">Command Line Arguments for train.py</span></summary>
|
158 |
+
|
159 |
+
* `train.gs_epochs`
|
160 |
+
Number of training iterations (steps) for Gaussian Splatting.
|
161 |
+
* `train.no_densify`
|
162 |
+
Disables densification. True by default.
|
163 |
+
* `gs.dataset.source_path`
|
164 |
+
Path to your input dataset directory. This should follow the same format as the original 3DGS dataset structure.
|
165 |
+
* `gs.dataset.model_path`
|
166 |
+
Output directory where the trained model, logs, and renderings will be saved.
|
167 |
+
* `init_wC.matches_per_ref`
|
168 |
+
Number of 2D feature correspondences to extract per reference view for initialization. More matches leads to more gaussians.
|
169 |
+
* `init_wC.nns_per_ref`
|
170 |
+
Number of nearest neighbor images used per reference during matching.
|
171 |
+
* `init_wC.num_refs`
|
172 |
+
Total number of reference views sampled.
|
173 |
+
* `wandb.mode`
|
174 |
+
Specifies how Weights & Biases (W&B) logging is handled.
|
175 |
+
|
176 |
+
- Default: `"disabled"`
|
177 |
+
- Options:
|
178 |
+
- `"online"` — log to the W&B server in real-time
|
179 |
+
- `"offline"` — save logs locally to sync later
|
180 |
+
- `"disabled"` — turn off W&B logging entirely
|
181 |
+
|
182 |
+
If you want to enable W&B logging, make sure to also configure:
|
183 |
+
|
184 |
+
- `wandb.project` — the name of your W&B project
|
185 |
+
- `wandb.entity` — your W&B username or team name
|
186 |
+
|
187 |
+
Example override:
|
188 |
+
```bash
|
189 |
+
wandb.mode=online wandb.project=EDGS wandb.entity=your_username train.gs_epochs=15_000 init_wC.matches_per_ref=15_000
|
190 |
+
```
|
191 |
+
</details>
|
192 |
+
<br>
|
193 |
+
|
194 |
+
To run full evaluation on all datasets:
|
195 |
+
|
196 |
+
```bash
|
197 |
+
python full_eval.py -m360 <mipnerf360 folder> -tat <tanks and temples folder> -db <deep blending folder>
|
198 |
+
```
|
199 |
+
<a id="sec-reuse"></a>
|
200 |
+
## 🏗️ Reusing Our Model
|
201 |
+
|
202 |
+
Our model is essentially a better **initialization module** for Gaussian Splatting. You can integrate it into your pipeline by calling:
|
203 |
+
|
204 |
+
```python
|
205 |
+
source.corr_init.init_gaussians_with_corr(...)
|
206 |
+
```
|
207 |
+
### Input arguments:
|
208 |
+
- A GaussianModel and Scene instance
|
209 |
+
- A configuration namespace `cfg.init_wC` to specify parameters like the number of matches, neighbors, and reference views
|
210 |
+
- A RoMA model (automatically instantiated if not provided)
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
<a id="sec-citation"></a>
|
215 |
+
## 📄 Citation
|
216 |
+
```bibtex
|
217 |
+
@misc{kotovenko2025edgseliminatingdensificationefficient,
|
218 |
+
title={EDGS: Eliminating Densification for Efficient Convergence of 3DGS},
|
219 |
+
author={Dmytro Kotovenko and Olga Grebenkova and Björn Ommer},
|
220 |
+
year={2025},
|
221 |
+
eprint={2504.13204},
|
222 |
+
archivePrefix={arXiv},
|
223 |
+
primaryClass={cs.GR},
|
224 |
+
url={https://arxiv.org/abs/2504.13204},
|
225 |
+
}
|
226 |
+
```
|
227 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
+
# TODO:
|
230 |
+
- [ ] Code for training and processing forward-facing scenes.
|
231 |
+
- [ ] More data examples
|
232 |
+
|
233 |
+
|
234 |
+
|
configs/gs/base.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: source.networks.Warper3DGS
|
2 |
+
|
3 |
+
verbose: True
|
4 |
+
viewpoint_stack: !!null
|
5 |
+
sh_degree: 3
|
6 |
+
|
7 |
+
opt:
|
8 |
+
iterations: 30000
|
9 |
+
position_lr_init: 0.00016
|
10 |
+
position_lr_final: 1.6e-06
|
11 |
+
position_lr_delay_mult: 0.01
|
12 |
+
position_lr_max_steps: 30000
|
13 |
+
feature_lr: 0.0025
|
14 |
+
opacity_lr: 0.025
|
15 |
+
scaling_lr: 0.005
|
16 |
+
rotation_lr: 0.001
|
17 |
+
percent_dense: 0.01
|
18 |
+
lambda_dssim: 0.2
|
19 |
+
densification_interval: 100
|
20 |
+
opacity_reset_interval: 30000
|
21 |
+
densify_from_iter: 500
|
22 |
+
densify_until_iter: 15000
|
23 |
+
densify_grad_threshold: 0.0002
|
24 |
+
random_background: false
|
25 |
+
save_iterations: [3000, 7000, 15000, 30000]
|
26 |
+
batch_size: 64
|
27 |
+
exposure_lr_init: 0.01
|
28 |
+
exposure_lr_final: 0.0001
|
29 |
+
exposure_lr_delay_steps: 0
|
30 |
+
exposure_lr_delay_mult: 0.0
|
31 |
+
|
32 |
+
TRAIN_CAM_IDX_TO_LOG: 50
|
33 |
+
TEST_CAM_IDX_TO_LOG: 10
|
34 |
+
|
35 |
+
pipe:
|
36 |
+
convert_SHs_python: False
|
37 |
+
compute_cov3D_python: False
|
38 |
+
debug: False
|
39 |
+
antialiasing: False
|
40 |
+
|
41 |
+
dataset:
|
42 |
+
densify_until_iter: 15000
|
43 |
+
source_path: '' #path to dataset
|
44 |
+
model_path: '' #path to logs
|
45 |
+
images: images
|
46 |
+
resolution: -1
|
47 |
+
white_background: false
|
48 |
+
data_device: cuda
|
49 |
+
eval: false
|
50 |
+
depths: ""
|
51 |
+
train_test_exp: False
|
configs/train.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- gs: base
|
3 |
+
- _self_
|
4 |
+
|
5 |
+
seed: 228
|
6 |
+
|
7 |
+
wandb:
|
8 |
+
mode: "online" # "disabled" for no logging
|
9 |
+
entity: "3dcorrespondence"
|
10 |
+
project: "Adv3DGS"
|
11 |
+
group: null
|
12 |
+
name: null
|
13 |
+
tag: "debug"
|
14 |
+
|
15 |
+
train:
|
16 |
+
gs_epochs: 0 # number of 3dgs iterations
|
17 |
+
reduce_opacity: True
|
18 |
+
no_densify: False # if True, the model will not be densified
|
19 |
+
max_lr: True
|
20 |
+
|
21 |
+
load:
|
22 |
+
gs: null #path to 3dgs checkpoint
|
23 |
+
gs_step: null #number of iterations, e.g. 7000
|
24 |
+
|
25 |
+
device: "cuda:0"
|
26 |
+
verbose: true
|
27 |
+
|
28 |
+
init_wC:
|
29 |
+
use: True # use EDGS
|
30 |
+
matches_per_ref: 15_000 # number of matches per reference
|
31 |
+
num_refs: 180 # number of reference images
|
32 |
+
nns_per_ref: 3 # number of nearest neighbors per reference
|
33 |
+
scaling_factor: 0.001
|
34 |
+
proj_err_tolerance: 0.01
|
35 |
+
roma_model: "outdoors" # you can change this to "indoors" or "outdoors"
|
36 |
+
add_SfM_init : False
|
37 |
+
|
38 |
+
|
extra/archive/rasterizer_impl.h
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include <cstdint>
|
15 |
+
#include <iostream>
|
16 |
+
#include <vector>
|
17 |
+
#include "rasterizer.h"
|
18 |
+
#include <cuda_runtime_api.h>
|
19 |
+
|
20 |
+
namespace CudaRasterizer
|
21 |
+
{
|
22 |
+
template <typename T>
|
23 |
+
static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
|
24 |
+
{
|
25 |
+
std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
|
26 |
+
ptr = reinterpret_cast<T*>(offset);
|
27 |
+
chunk = reinterpret_cast<char*>(ptr + count);
|
28 |
+
}
|
29 |
+
|
30 |
+
struct GeometryState
|
31 |
+
{
|
32 |
+
size_t scan_size;
|
33 |
+
float* depths;
|
34 |
+
char* scanning_space;
|
35 |
+
bool* clamped;
|
36 |
+
int* internal_radii;
|
37 |
+
float2* means2D;
|
38 |
+
float* cov3D;
|
39 |
+
float4* conic_opacity;
|
40 |
+
float* rgb;
|
41 |
+
uint32_t* point_offsets;
|
42 |
+
uint32_t* tiles_touched;
|
43 |
+
|
44 |
+
static GeometryState fromChunk(char*& chunk, size_t P);
|
45 |
+
};
|
46 |
+
|
47 |
+
struct ImageState
|
48 |
+
{
|
49 |
+
uint2* ranges;
|
50 |
+
uint32_t* n_contrib;
|
51 |
+
float* accum_alpha;
|
52 |
+
|
53 |
+
static ImageState fromChunk(char*& chunk, size_t N);
|
54 |
+
};
|
55 |
+
|
56 |
+
struct BinningState
|
57 |
+
{
|
58 |
+
size_t sorting_size;
|
59 |
+
uint64_t* point_list_keys_unsorted;
|
60 |
+
uint64_t* point_list_keys;
|
61 |
+
uint32_t* point_list_unsorted;
|
62 |
+
uint32_t* point_list;
|
63 |
+
char* list_sorting_space;
|
64 |
+
|
65 |
+
static BinningState fromChunk(char*& chunk, size_t P);
|
66 |
+
};
|
67 |
+
|
68 |
+
template<typename T>
|
69 |
+
size_t required(size_t P)
|
70 |
+
{
|
71 |
+
char* size = nullptr;
|
72 |
+
T::fromChunk(size, P);
|
73 |
+
return ((size_t)size) + 128;
|
74 |
+
}
|
75 |
+
};
|
extra/archive/simple-knn.patch.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/simple_knn.cu b/simple_knn.cu
|
2 |
+
index e72e4c9..b2deb1b 100644
|
3 |
+
--- a/simple_knn.cu
|
4 |
+
+++ b/simple_knn.cu
|
5 |
+
@@ -11,6 +11,8 @@
|
6 |
+
|
7 |
+
#define BOX_SIZE 1024
|
8 |
+
|
9 |
+
+#include <float.h>
|
10 |
+
+
|
11 |
+
#include "cuda_runtime.h"
|
12 |
+
#include "device_launch_parameters.h"
|
13 |
+
#include "simple_knn.h"
|
full_eval.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import os
|
13 |
+
from argparse import ArgumentParser
|
14 |
+
|
15 |
+
mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
|
16 |
+
mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
|
17 |
+
tanks_and_temples_scenes = ["truck", "train"]
|
18 |
+
deep_blending_scenes = ["drjohnson", "playroom"]
|
19 |
+
|
20 |
+
parser = ArgumentParser(description="Full evaluation script parameters")
|
21 |
+
parser.add_argument("--skip_training", action="store_true")
|
22 |
+
parser.add_argument("--skip_rendering", action="store_true")
|
23 |
+
parser.add_argument("--skip_metrics", action="store_true")
|
24 |
+
parser.add_argument("--output_path", default="./eval")
|
25 |
+
args, _ = parser.parse_known_args()
|
26 |
+
|
27 |
+
all_scenes = []
|
28 |
+
all_scenes.extend(mipnerf360_outdoor_scenes)
|
29 |
+
all_scenes.extend(mipnerf360_indoor_scenes)
|
30 |
+
all_scenes.extend(tanks_and_temples_scenes)
|
31 |
+
all_scenes.extend(deep_blending_scenes)
|
32 |
+
|
33 |
+
if not args.skip_training or not args.skip_rendering:
|
34 |
+
parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
|
35 |
+
parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
|
36 |
+
parser.add_argument("--deepblending", "-db", required=True, type=str)
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
if not args.skip_training:
|
40 |
+
name = "EDGS_"
|
41 |
+
common_args = " --quiet --eval --test_iterations -1 "
|
42 |
+
for scene in mipnerf360_outdoor_scenes:
|
43 |
+
source = args.mipnerf360 + "/" + scene
|
44 |
+
experiment = name + scene
|
45 |
+
os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_4 init_wC.num_refs=180 train.no_densify=True")
|
46 |
+
for scene in mipnerf360_indoor_scenes:
|
47 |
+
source = args.mipnerf360 + "/" + scene
|
48 |
+
experiment = name + scene
|
49 |
+
os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_2 init_wC.num_refs=180 train.no_densify=True")
|
50 |
+
for scene in tanks_and_temples_scenes:
|
51 |
+
source = args.tanksandtemples + "/" + scene
|
52 |
+
experiment = name + scene +"_tandt"
|
53 |
+
os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True")
|
54 |
+
for scene in deep_blending_scenes:
|
55 |
+
source = args.deepblending + "/" + scene
|
56 |
+
experiment = name + scene + "_db"
|
57 |
+
os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True")
|
58 |
+
|
59 |
+
|
60 |
+
if not args.skip_rendering:
|
61 |
+
all_sources = []
|
62 |
+
for scene in mipnerf360_outdoor_scenes:
|
63 |
+
all_sources.append(args.mipnerf360 + "/" + scene)
|
64 |
+
for scene in mipnerf360_indoor_scenes:
|
65 |
+
all_sources.append(args.mipnerf360 + "/" + scene)
|
66 |
+
for scene in tanks_and_temples_scenes:
|
67 |
+
all_sources.append(args.tanksandtemples + "/" + scene )
|
68 |
+
for scene in deep_blending_scenes:
|
69 |
+
all_sources.append(args.deepblending + "/" + scene)
|
70 |
+
|
71 |
+
all_outputs = []
|
72 |
+
for scene in mipnerf360_outdoor_scenes:
|
73 |
+
all_outputs.append(args.output_path + "/mipnerf/" + scene)
|
74 |
+
for scene in mipnerf360_indoor_scenes:
|
75 |
+
all_outputs.append(args.output_path + "/mipnerf/" + scene)
|
76 |
+
for scene in tanks_and_temples_scenes:
|
77 |
+
all_outputs.append(args.output_path + "/tandt/" + scene)
|
78 |
+
for scene in deep_blending_scenes:
|
79 |
+
all_outputs.append(args.output_path + "/db/" + scene)
|
80 |
+
|
81 |
+
|
82 |
+
common_args = " --quiet --eval --skip_train"
|
83 |
+
for scene, source, output in zip(all_scenes, all_sources, all_outputs):
|
84 |
+
os.system("python ./submodules/gaussian-splatting/render.py --iteration 7000 -s " + source + " -m " + output + common_args)
|
85 |
+
os.system("python ./submodules/gaussian-splatting/render.py --iteration 30000 -s " + source + " -m " + output + common_args)
|
86 |
+
|
87 |
+
if not args.skip_metrics:
|
88 |
+
all_outputs = []
|
89 |
+
for scene in mipnerf360_outdoor_scenes:
|
90 |
+
all_outputs.append(args.output_path + "/mipnerf/" + scene)
|
91 |
+
for scene in mipnerf360_indoor_scenes:
|
92 |
+
all_outputs.append(args.output_path + "/mipnerf/" + scene)
|
93 |
+
for scene in tanks_and_temples_scenes:
|
94 |
+
all_outputs.append(args.output_path + "/tandt/" + scene)
|
95 |
+
for scene in deep_blending_scenes:
|
96 |
+
all_outputs.append(args.output_path + "/db/" + scene)
|
97 |
+
|
98 |
+
scenes_string = ""
|
99 |
+
for scene, output in zip(all_scenes, all_outputs):
|
100 |
+
scenes_string += "\"" + output + "\" "
|
101 |
+
|
102 |
+
os.system("python metrics.py -m " + scenes_string)
|
gradio_demo.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
import argparse
|
6 |
+
import gradio as gr
|
7 |
+
import sys
|
8 |
+
import io
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
from source.utils_aux import set_seed
|
12 |
+
from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
|
13 |
+
from source.trainer import EDGSTrainer
|
14 |
+
from hydra import initialize, compose
|
15 |
+
import hydra
|
16 |
+
import time
|
17 |
+
from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
|
18 |
+
import contextlib
|
19 |
+
import base64
|
20 |
+
|
21 |
+
|
22 |
+
# Init RoMA model:
|
23 |
+
sys.path.append('../submodules/RoMa')
|
24 |
+
from romatch import roma_outdoor, roma_indoor
|
25 |
+
|
26 |
+
roma_model = roma_indoor(device="cuda:0")
|
27 |
+
roma_model.upsample_preds = False
|
28 |
+
roma_model.symmetric = False
|
29 |
+
|
30 |
+
STATIC_FILE_SERVING_FOLDER = "./served_files"
|
31 |
+
MODEL_PATH = None
|
32 |
+
os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)
|
33 |
+
|
34 |
+
trainer = None
|
35 |
+
|
36 |
+
class Tee(io.TextIOBase):
|
37 |
+
def __init__(self, *streams):
|
38 |
+
self.streams = streams
|
39 |
+
|
40 |
+
def write(self, data):
|
41 |
+
for stream in self.streams:
|
42 |
+
stream.write(data)
|
43 |
+
return len(data)
|
44 |
+
|
45 |
+
def flush(self):
|
46 |
+
for stream in self.streams:
|
47 |
+
stream.flush()
|
48 |
+
|
49 |
+
def capture_logs(func, *args, **kwargs):
|
50 |
+
log_capture_string = io.StringIO()
|
51 |
+
tee = Tee(sys.__stdout__, log_capture_string)
|
52 |
+
with contextlib.redirect_stdout(tee):
|
53 |
+
result = func(*args, **kwargs)
|
54 |
+
return result, log_capture_string.getvalue()
|
55 |
+
|
56 |
+
# Training Pipeline
|
57 |
+
def run_training_pipeline(scene_dir,
|
58 |
+
num_ref_views=16,
|
59 |
+
num_corrs_per_view=20000,
|
60 |
+
num_steps=1_000,
|
61 |
+
mode_toggle="Ours (EDGS)"):
|
62 |
+
with initialize(config_path="./configs", version_base="1.1"):
|
63 |
+
cfg = compose(config_name="train")
|
64 |
+
|
65 |
+
scene_name = os.path.basename(scene_dir)
|
66 |
+
model_output_dir = f"./outputs/{scene_name}_trained"
|
67 |
+
|
68 |
+
cfg.wandb.mode = "disabled"
|
69 |
+
cfg.gs.dataset.model_path = model_output_dir
|
70 |
+
cfg.gs.dataset.source_path = scene_dir
|
71 |
+
cfg.gs.dataset.images = "images"
|
72 |
+
|
73 |
+
cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
|
74 |
+
cfg.train.gs_epochs = 30000
|
75 |
+
|
76 |
+
if mode_toggle=="Ours (EDGS)":
|
77 |
+
cfg.gs.opt.opacity_reset_interval = 1_000_000
|
78 |
+
cfg.train.reduce_opacity = True
|
79 |
+
cfg.train.no_densify = True
|
80 |
+
cfg.train.max_lr = True
|
81 |
+
|
82 |
+
cfg.init_wC.use = True
|
83 |
+
cfg.init_wC.matches_per_ref = num_corrs_per_view
|
84 |
+
cfg.init_wC.nns_per_ref = 1
|
85 |
+
cfg.init_wC.num_refs = num_ref_views
|
86 |
+
cfg.init_wC.add_SfM_init = False
|
87 |
+
cfg.init_wC.scaling_factor = 0.00077 * 2.
|
88 |
+
|
89 |
+
set_seed(cfg.seed)
|
90 |
+
os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
|
91 |
+
|
92 |
+
global trainer
|
93 |
+
global MODEL_PATH
|
94 |
+
generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
|
95 |
+
trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled')
|
96 |
+
|
97 |
+
# Disable evaluation and saving
|
98 |
+
trainer.saving_iterations = []
|
99 |
+
trainer.evaluate_iterations = []
|
100 |
+
|
101 |
+
# Initialize
|
102 |
+
trainer.timer.start()
|
103 |
+
start_time = time.time()
|
104 |
+
trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
|
105 |
+
time_for_init = time.time()-start_time
|
106 |
+
|
107 |
+
viewpoint_cams = trainer.GS.scene.getTrainCameras()
|
108 |
+
path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
|
109 |
+
n_selected=6, # 8
|
110 |
+
n_points_per_segment=30, # 30
|
111 |
+
closed=False)
|
112 |
+
path_cameras = path_cameras + path_cameras[::-1]
|
113 |
+
|
114 |
+
path_renderings = []
|
115 |
+
idx = 0
|
116 |
+
# Visualize after init
|
117 |
+
for _ in range(120):
|
118 |
+
with torch.no_grad():
|
119 |
+
viewpoint_cam = path_cameras[idx]
|
120 |
+
idx = (idx + 1) % len(path_cameras)
|
121 |
+
render_pkg = trainer.GS(viewpoint_cam)
|
122 |
+
image = render_pkg["render"]
|
123 |
+
image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
|
124 |
+
image_np = (image_np * 255).astype(np.uint8)
|
125 |
+
path_renderings.append(put_text_on_image(img=image_np,
|
126 |
+
text=f"Init stage.\nTime:{time_for_init:.3f}s. "))
|
127 |
+
path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")]*30
|
128 |
+
|
129 |
+
# Train and save visualizations during training.
|
130 |
+
start_time = time.time()
|
131 |
+
for _ in range(int(num_steps//10)):
|
132 |
+
with torch.no_grad():
|
133 |
+
viewpoint_cam = path_cameras[idx]
|
134 |
+
idx = (idx + 1) % len(path_cameras)
|
135 |
+
render_pkg = trainer.GS(viewpoint_cam)
|
136 |
+
image = render_pkg["render"]
|
137 |
+
image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
|
138 |
+
image_np = (image_np * 255).astype(np.uint8)
|
139 |
+
path_renderings.append(put_text_on_image(
|
140 |
+
img=image_np,
|
141 |
+
text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. "))
|
142 |
+
|
143 |
+
cfg.train.gs_epochs = 10
|
144 |
+
trainer.train(cfg.train)
|
145 |
+
print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.")
|
146 |
+
# if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60:
|
147 |
+
# break
|
148 |
+
final_time = time.time()
|
149 |
+
|
150 |
+
# Add static frame. To highlight we're done
|
151 |
+
path_renderings += [put_text_on_image(
|
152 |
+
img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s. ")]*30
|
153 |
+
# Final rendering at the end.
|
154 |
+
for _ in range(len(path_cameras)):
|
155 |
+
with torch.no_grad():
|
156 |
+
viewpoint_cam = path_cameras[idx]
|
157 |
+
idx = (idx + 1) % len(path_cameras)
|
158 |
+
render_pkg = trainer.GS(viewpoint_cam)
|
159 |
+
image = render_pkg["render"]
|
160 |
+
image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
|
161 |
+
image_np = (image_np * 255).astype(np.uint8)
|
162 |
+
path_renderings.append(put_text_on_image(img=image_np,
|
163 |
+
text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s. "))
|
164 |
+
|
165 |
+
trainer.save_model()
|
166 |
+
final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
|
167 |
+
save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
|
168 |
+
MODEL_PATH = cfg.gs.dataset.model_path
|
169 |
+
ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
|
170 |
+
shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))
|
171 |
+
|
172 |
+
return final_video_path, ply_path
|
173 |
+
|
174 |
+
# Gradio Interface
|
175 |
+
def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
|
176 |
+
images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
|
177 |
+
shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped', dirs_exist_ok=True)
|
178 |
+
(final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
|
179 |
+
scene_dir,
|
180 |
+
num_ref_views,
|
181 |
+
num_corrs,
|
182 |
+
num_steps)
|
183 |
+
images_rgb = [img[:, :, ::-1] for img in images]
|
184 |
+
return images_rgb, final_video_path, scene_dir, ply_path, log_output
|
185 |
+
|
186 |
+
# Dummy Render Functions
|
187 |
+
def render_all_views(scene_dir):
|
188 |
+
viewpoint_cams = trainer.GS.scene.getTrainCameras()
|
189 |
+
path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
|
190 |
+
n_selected=8,
|
191 |
+
n_points_per_segment=60,
|
192 |
+
closed=False)
|
193 |
+
path_cameras = path_cameras + path_cameras[::-1]
|
194 |
+
|
195 |
+
path_renderings = []
|
196 |
+
with torch.no_grad():
|
197 |
+
for viewpoint_cam in path_cameras:
|
198 |
+
render_pkg = trainer.GS(viewpoint_cam)
|
199 |
+
image = render_pkg["render"]
|
200 |
+
image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
|
201 |
+
image_np = (image_np * 255).astype(np.uint8)
|
202 |
+
path_renderings.append(image_np)
|
203 |
+
save_numpy_frames_as_mp4(frames=path_renderings,
|
204 |
+
output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
|
205 |
+
fps=30,
|
206 |
+
center_crop=0.85)
|
207 |
+
|
208 |
+
return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")
|
209 |
+
|
210 |
+
def render_circular_path(scene_dir):
|
211 |
+
viewpoint_cams = trainer.GS.scene.getTrainCameras()
|
212 |
+
path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
|
213 |
+
N=240,
|
214 |
+
radius_scale=0.65,
|
215 |
+
d=0)
|
216 |
+
|
217 |
+
path_renderings = []
|
218 |
+
with torch.no_grad():
|
219 |
+
for viewpoint_cam in path_cameras:
|
220 |
+
render_pkg = trainer.GS(viewpoint_cam)
|
221 |
+
image = render_pkg["render"]
|
222 |
+
image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
|
223 |
+
image_np = (image_np * 255).astype(np.uint8)
|
224 |
+
path_renderings.append(image_np)
|
225 |
+
save_numpy_frames_as_mp4(frames=path_renderings,
|
226 |
+
output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
|
227 |
+
fps=30,
|
228 |
+
center_crop=0.85)
|
229 |
+
|
230 |
+
return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")
|
231 |
+
|
232 |
+
# Download Functions
|
233 |
+
def download_cameras():
|
234 |
+
path = os.path.join(MODEL_PATH, "cameras.json")
|
235 |
+
return f"[📥 Download Cameras.json](file={path})"
|
236 |
+
|
237 |
+
def download_model():
|
238 |
+
path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
|
239 |
+
return f"[📥 Download Pretrained Model (.ply)](file={path})"
|
240 |
+
|
241 |
+
# Full pipeline helpers
|
242 |
+
def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
|
243 |
+
tmpdirname = tempfile.mkdtemp()
|
244 |
+
scene_dir = os.path.join(tmpdirname, "scene")
|
245 |
+
os.makedirs(scene_dir, exist_ok=True)
|
246 |
+
|
247 |
+
selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
|
248 |
+
run_colmap_on_scene(scene_dir)
|
249 |
+
|
250 |
+
return selected_frames, scene_dir
|
251 |
+
|
252 |
+
# Preprocess Input
|
253 |
+
def process_input(input_path, num_ref_views, output_dir, max_size=1024):
|
254 |
+
if isinstance(input_path, (str, os.PathLike)):
|
255 |
+
if os.path.isdir(input_path):
|
256 |
+
frames = []
|
257 |
+
for img_file in sorted(os.listdir(input_path)):
|
258 |
+
if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
|
259 |
+
img = Image.open(os.path.join(output_dir, img_file)).convert('RGB')
|
260 |
+
img.thumbnail((1024, 1024))
|
261 |
+
frames.append(np.array(img))
|
262 |
+
else:
|
263 |
+
frames = read_video_frames(video_input=input_path, max_size=max_size)
|
264 |
+
else:
|
265 |
+
frames = read_video_frames(video_input=input_path, max_size=max_size)
|
266 |
+
|
267 |
+
frames_scores = preprocess_frames(frames)
|
268 |
+
selected_frames_indices = select_optimal_frames(scores=frames_scores,
|
269 |
+
k=min(num_ref_views, len(frames)))
|
270 |
+
selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
|
271 |
+
|
272 |
+
save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
|
273 |
+
return selected_frames
|
274 |
+
|
275 |
+
def preprocess_input(input_path, num_ref_views, max_size=1024):
|
276 |
+
tmpdirname = tempfile.mkdtemp()
|
277 |
+
scene_dir = os.path.join(tmpdirname, "scene")
|
278 |
+
os.makedirs(scene_dir, exist_ok=True)
|
279 |
+
selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
|
280 |
+
run_colmap_on_scene(scene_dir)
|
281 |
+
return selected_frames, scene_dir
|
282 |
+
|
283 |
+
def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
|
284 |
+
return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)
|
285 |
+
|
286 |
+
|
287 |
+
# Gradio App
|
288 |
+
with gr.Blocks() as demo:
|
289 |
+
with gr.Row():
|
290 |
+
with gr.Column(scale=6):
|
291 |
+
gr.Markdown("""
|
292 |
+
## <span style='font-size: 20px;'>📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
|
293 |
+
🔗 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
|
294 |
+
""", elem_id="header")
|
295 |
+
|
296 |
+
gr.Markdown("""
|
297 |
+
### <span style='font-size: 22px;'>🛠️ How to Use This Demo</span>
|
298 |
+
|
299 |
+
1. Upload a **front-facing video** or **a folder of images** of a **static** scene.
|
300 |
+
2. Use the sliders to configure the number of reference views, correspondences, and optimization steps.
|
301 |
+
3. First press on preprocess Input to extract frames from video(for videos) and COLMAP frames.
|
302 |
+
4. Then click **🚀 Start Reconstruction** to actually launch the reconstruction pipeline.
|
303 |
+
5. Watch the training visualization and explore the 3D model.
|
304 |
+
‼️ **If you see nothing in the 3D model viewer**, try rotating or zooming — sometimes the initial camera orientation is off.
|
305 |
+
|
306 |
+
|
307 |
+
✅ Best for scenes with small camera motion.
|
308 |
+
❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page).
|
309 |
+
""", elem_id="quickstart")
|
310 |
+
|
311 |
+
|
312 |
+
scene_dir_state = gr.State()
|
313 |
+
ply_model_state = gr.State()
|
314 |
+
|
315 |
+
with gr.Row():
|
316 |
+
with gr.Column(scale=2):
|
317 |
+
input_file = gr.File(label="Upload Video or Images",
|
318 |
+
file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
|
319 |
+
file_count="multiple")
|
320 |
+
gr.Examples(
|
321 |
+
examples = [
|
322 |
+
[["assets/examples/video_bakery.mp4"]],
|
323 |
+
[["assets/examples/video_flowers.mp4"]],
|
324 |
+
[["assets/examples/video_fruits.mp4"]],
|
325 |
+
[["assets/examples/video_plant.mp4"]],
|
326 |
+
[["assets/examples/video_salad.mp4"]],
|
327 |
+
[["assets/examples/video_tram.mp4"]],
|
328 |
+
[["assets/examples/video_tulips.mp4"]]
|
329 |
+
],
|
330 |
+
inputs=[input_file],
|
331 |
+
label="🎞️ ALternatively, try an Example Video",
|
332 |
+
examples_per_page=4
|
333 |
+
)
|
334 |
+
ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
|
335 |
+
corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
|
336 |
+
fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
|
337 |
+
preprocess_button = gr.Button("📸 Preprocess Input")
|
338 |
+
start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
|
339 |
+
gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)
|
340 |
+
|
341 |
+
with gr.Column(scale=3):
|
342 |
+
gr.Markdown("### 🏋️ Training Visualization")
|
343 |
+
video_output = gr.Video(label="Training Video", autoplay=True)
|
344 |
+
render_all_views_button = gr.Button("🎥 Render All-Views Path")
|
345 |
+
render_circular_path_button = gr.Button("🎥 Render Circular Path")
|
346 |
+
rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)
|
347 |
+
with gr.Column(scale=5):
|
348 |
+
gr.Markdown("### 🌐 Final 3D Model")
|
349 |
+
model3d_viewer = gr.Model3D(label="3D Model Viewer")
|
350 |
+
|
351 |
+
gr.Markdown("### 📦 Output Files")
|
352 |
+
with gr.Row(height=50):
|
353 |
+
with gr.Column():
|
354 |
+
#gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
|
355 |
+
download_cameras_button = gr.Button("📥 Download Cameras.json")
|
356 |
+
download_cameras_file = gr.File(label="📄 Cameras.json")
|
357 |
+
with gr.Column():
|
358 |
+
download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
|
359 |
+
download_model_file = gr.File(label="📄 Pretrained Model (.ply)")
|
360 |
+
|
361 |
+
log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)
|
362 |
+
|
363 |
+
def on_preprocess_click(input_file, num_ref_views):
|
364 |
+
images, scene_dir = preprocess_input(input_file, num_ref_views)
|
365 |
+
return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True)
|
366 |
+
|
367 |
+
def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
|
368 |
+
(video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
|
369 |
+
return video_path, ply_path, logs
|
370 |
+
|
371 |
+
preprocess_button.click(
|
372 |
+
fn=on_preprocess_click,
|
373 |
+
inputs=[input_file, ref_slider],
|
374 |
+
outputs=[gallery, scene_dir_state, start_button]
|
375 |
+
)
|
376 |
+
|
377 |
+
start_button.click(
|
378 |
+
fn=on_start_click,
|
379 |
+
inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
|
380 |
+
outputs=[video_output, model3d_viewer, log_output_box]
|
381 |
+
)
|
382 |
+
|
383 |
+
render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
|
384 |
+
render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])
|
385 |
+
|
386 |
+
download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
|
387 |
+
download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])
|
388 |
+
|
389 |
+
|
390 |
+
gr.Markdown("""
|
391 |
+
---
|
392 |
+
### <span style='font-size: 20px;'>📖 Detailed Overview</span>
|
393 |
+
|
394 |
+
If you uploaded a video, it will be automatically cut into a smaller number of frames (default: 16).
|
395 |
+
|
396 |
+
The model pipeline:
|
397 |
+
1. 🧠 Runs PyCOLMAP to estimate camera intrinsics & poses (~3–7 seconds for <16 images).
|
398 |
+
2. 🔁 Computes 2D-2D correspondences between views. More correspondences generally improve quality.
|
399 |
+
3. 🔧 Optimizes a 3D Gaussian Splatting model for several steps.
|
400 |
+
|
401 |
+
### 🎥 Training Visualization
|
402 |
+
You will see a visualization of the entire training process in the "Training Video" pane.
|
403 |
+
|
404 |
+
### 🌀 Rendering & 3D Model
|
405 |
+
- Render the scene from a circular path of novel views.
|
406 |
+
- Or from camera views close to the original input.
|
407 |
+
|
408 |
+
The 3D model is shown in the right viewer. You can explore it interactively:
|
409 |
+
- On PC: WASD keys, arrow keys, and mouse clicks
|
410 |
+
- On mobile: pan and pinch to zoom
|
411 |
+
|
412 |
+
🕒 Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends.
|
413 |
+
|
414 |
+
---
|
415 |
+
Preloaded models coming soon. (TODO)
|
416 |
+
""", elem_id="details")
|
417 |
+
|
418 |
+
if __name__ == "__main__":
|
419 |
+
parser = argparse.ArgumentParser(description="Launch Gradio demo for EDGS preprocessing and 3D viewing.")
|
420 |
+
parser.add_argument("--port", type=int, default=7860, help="Port to launch the Gradio app on.")
|
421 |
+
parser.add_argument("--no_share", action='store_true', help="Disable Gradio sharing and assume local access (default: share=True)")
|
422 |
+
args = parser.parse_args()
|
423 |
+
|
424 |
+
demo.launch(server_name="0.0.0.0", server_port=args.port, share=not args.no_share)
|
install.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
git clone [email protected]:CompVis/EDGS.git --recursive
|
3 |
+
cd EDGS
|
4 |
+
git submodule update --init --recursive
|
5 |
+
|
6 |
+
conda create -y -n edgs python=3.10 pip
|
7 |
+
conda activate edgs
|
8 |
+
|
9 |
+
# Optionally set path to CUDA
|
10 |
+
export CUDA_HOME=/usr/local/cuda-12.1
|
11 |
+
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
12 |
+
export PATH=$CUDA_HOME/bin:$PATH
|
13 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
14 |
+
conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y
|
15 |
+
|
16 |
+
pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization
|
17 |
+
pip install -e submodules/gaussian-splatting/submodules/simple-knn
|
18 |
+
|
19 |
+
# For COLMAP and pycolmap
|
20 |
+
# Optionally install original colmap but probably pycolmap suffices
|
21 |
+
# conda install conda-forge/label/colmap_dev::colmap
|
22 |
+
pip install pycolmap
|
23 |
+
|
24 |
+
|
25 |
+
pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg
|
26 |
+
conda install numpy=1.26.4 -y -c conda-forge --override-channels
|
27 |
+
|
28 |
+
pip install -e submodules/RoMa
|
29 |
+
conda install anaconda::jupyter --yes
|
30 |
+
|
31 |
+
# Stuff necessary for gradio and visualizations
|
32 |
+
pip install gradio
|
33 |
+
pip install plotly scikit-learn moviepy==2.1.1 ffmpeg
|
34 |
+
pip install open3d
|
35 |
+
|
main.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
import uuid
|
6 |
+
import asyncio
|
7 |
+
import io
|
8 |
+
import time
|
9 |
+
import contextlib
|
10 |
+
import base64
|
11 |
+
from PIL import Image
|
12 |
+
import numpy as np
|
13 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, Body
|
14 |
+
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
15 |
+
from pydantic import BaseModel, Field
|
16 |
+
|
17 |
+
try:
|
18 |
+
from source.utils_aux import set_seed
|
19 |
+
from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
|
20 |
+
from source.trainer import EDGSTrainer
|
21 |
+
from hydra import initialize, compose
|
22 |
+
import hydra
|
23 |
+
from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image
|
24 |
+
import sys
|
25 |
+
sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
|
26 |
+
from romatch import roma_indoor
|
27 |
+
except ImportError as e:
|
28 |
+
print(f"Error: No se pudieron importar los módulos del proyecto EDGS. Asegúrate de que las rutas y la instalación son correctas. {e}")
|
29 |
+
sys.exit(1)
|
30 |
+
|
31 |
+
# --- Configuración Inicial ---
|
32 |
+
# 1. Inicialización de la App FastAPI
|
33 |
+
app = FastAPI(
|
34 |
+
title="EDGS Training API",
|
35 |
+
description="Una API para preprocesar videos y entrenar modelos 3DGS con EDGS.",
|
36 |
+
version="1.0.0"
|
37 |
+
)
|
38 |
+
|
39 |
+
# 2. Variables Globales y Almacenamiento de Estado
|
40 |
+
# El modelo se cargará en el evento 'startup'
|
41 |
+
roma_model = None
|
42 |
+
|
43 |
+
# Base de datos en memoria para gestionar el estado de las tareas entre endpoints
|
44 |
+
tasks_db = {}
|
45 |
+
|
46 |
+
# 3. Modelos Pydantic para la validación de datos
|
47 |
+
class TrainParams(BaseModel):
|
48 |
+
num_corrs_per_view: int = Field(20000, gt=0, description="Correspondencias por vista de referencia.")
|
49 |
+
num_steps: int = Field(1000, gt=0, description="Número de pasos de optimización.")
|
50 |
+
|
51 |
+
class PreprocessResponse(BaseModel):
|
52 |
+
task_id: str
|
53 |
+
message: str
|
54 |
+
selected_frames_count: int
|
55 |
+
# Opcional: podrías devolver las imágenes en base64 si el cliente las necesita visualizar
|
56 |
+
# frames: list[str]
|
57 |
+
|
58 |
+
# --- Lógica de Negocio (Adaptada del script de Gradio) ---
|
59 |
+
|
60 |
+
# Esta función se ejecutará en un hilo separado para no bloquear el servidor
|
61 |
+
def run_preprocessing_sync(input_path: str, num_ref_views: int):
|
62 |
+
"""
|
63 |
+
Ejecuta el preprocesamiento: selección de frames y ejecución de COLMAP.
|
64 |
+
"""
|
65 |
+
tmpdirname = tempfile.mkdtemp()
|
66 |
+
scene_dir = os.path.join(tmpdirname, "scene")
|
67 |
+
os.makedirs(scene_dir, exist_ok=True)
|
68 |
+
|
69 |
+
# 1. Lee y selecciona los mejores frames
|
70 |
+
frames = read_video_frames(video_input=input_path, max_size=1024)
|
71 |
+
frames_scores = preprocess_frames(frames)
|
72 |
+
selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
|
73 |
+
selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
|
74 |
+
|
75 |
+
# 2. Guarda los frames y ejecuta COLMAP
|
76 |
+
save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
|
77 |
+
run_colmap_on_scene(scene_dir)
|
78 |
+
|
79 |
+
return scene_dir, selected_frames
|
80 |
+
|
81 |
+
async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
|
82 |
+
"""
|
83 |
+
Un generador asíncrono que ejecuta el entrenamiento. Los logs detallados se muestran
|
84 |
+
en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
|
85 |
+
"""
|
86 |
+
def training_pipeline():
|
87 |
+
try:
|
88 |
+
# La inicialización y configuración de Hydra se mantienen igual
|
89 |
+
with initialize(config_path="./configs", version_base="1.1"):
|
90 |
+
cfg = compose(config_name="train")
|
91 |
+
|
92 |
+
# --- CONFIGURACIÓN COMPLETA ---
|
93 |
+
scene_name = os.path.basename(scene_dir)
|
94 |
+
model_output_dir = f"./outputs/{scene_name}_trained"
|
95 |
+
cfg.wandb.mode = "disabled"
|
96 |
+
cfg.gs.dataset.model_path = model_output_dir
|
97 |
+
cfg.gs.dataset.source_path = scene_dir
|
98 |
+
cfg.gs.dataset.images = "images"
|
99 |
+
cfg.train.gs_epochs = 30000
|
100 |
+
cfg.gs.opt.opacity_reset_interval = 1_000_000
|
101 |
+
cfg.train.reduce_opacity = True
|
102 |
+
cfg.train.no_densify = True
|
103 |
+
cfg.train.max_lr = True
|
104 |
+
cfg.init_wC.use = True
|
105 |
+
cfg.init_wC.matches_per_ref = params.num_corrs_per_view
|
106 |
+
cfg.init_wC.nns_per_ref = 1
|
107 |
+
cfg.init_wC.num_refs = num_ref_views
|
108 |
+
cfg.init_wC.add_SfM_init = False
|
109 |
+
cfg.init_wC.scaling_factor = 0.00077 * 2.
|
110 |
+
|
111 |
+
set_seed(cfg.seed)
|
112 |
+
os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
|
113 |
+
|
114 |
+
device = cfg.device
|
115 |
+
generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
|
116 |
+
trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
|
117 |
+
trainer.saving_iterations = []
|
118 |
+
trainer.evaluate_iterations = []
|
119 |
+
trainer.timer.start()
|
120 |
+
|
121 |
+
# Mensaje de progreso para el cliente antes de la inicialización
|
122 |
+
yield "data: Inicializando modelo...\n\n"
|
123 |
+
trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
|
124 |
+
|
125 |
+
# El bucle de entrenamiento principal
|
126 |
+
for step in range(int(params.num_steps // 10)):
|
127 |
+
cfg.train.gs_epochs = 10
|
128 |
+
# trainer.train() ahora imprimirá sus logs detallados directamente en la terminal
|
129 |
+
trainer.train(cfg.train)
|
130 |
+
|
131 |
+
# --- CAMBIO CLAVE ---
|
132 |
+
# Envía un mensaje de progreso simple al cliente en lugar de los logs capturados.
|
133 |
+
yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
|
134 |
+
|
135 |
+
trainer.save_model()
|
136 |
+
ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
|
137 |
+
|
138 |
+
tasks_db[task_id]['result_ply_path'] = ply_path
|
139 |
+
|
140 |
+
final_message = "Entrenamiento completado. El modelo está listo para descargar."
|
141 |
+
yield f"data: {final_message}\n\n"
|
142 |
+
|
143 |
+
except Exception as e:
|
144 |
+
yield f"data: ERROR: {repr(e)}\n\n"
|
145 |
+
|
146 |
+
# El bucle que llama a la pipeline se mantiene igual
|
147 |
+
training_gen = training_pipeline()
|
148 |
+
for log_message in training_gen:
|
149 |
+
yield log_message
|
150 |
+
await asyncio.sleep(0.1)
|
151 |
+
|
152 |
+
# --- Eventos de Ciclo de Vida de la App ---
|
153 |
+
|
154 |
+
@app.on_event("startup")
|
155 |
+
async def startup_event():
|
156 |
+
"""
|
157 |
+
Carga el modelo RoMa cuando el servidor se inicia.
|
158 |
+
"""
|
159 |
+
global roma_model
|
160 |
+
print("🚀 Iniciando servidor FastAPI...")
|
161 |
+
if torch.cuda.is_available():
|
162 |
+
device = "cuda:0"
|
163 |
+
print("✅ GPU detectada. Usando CUDA.")
|
164 |
+
else:
|
165 |
+
device = "cpu"
|
166 |
+
print("⚠️ No se detectó GPU. Usando CPU (puede ser muy lento).")
|
167 |
+
|
168 |
+
roma_model = roma_indoor(device=device)
|
169 |
+
roma_model.upsample_preds = False
|
170 |
+
roma_model.symmetric = False
|
171 |
+
print("🤖 Modelo RoMa cargado y listo.")
|
172 |
+
|
173 |
+
# --- Endpoints de la API ---
|
174 |
+
|
175 |
+
@app.post("/preprocess", response_model=PreprocessResponse)
|
176 |
+
async def preprocess_video(
|
177 |
+
num_ref_views: int = Body(16, embed=True, description="Número de vistas de referencia a extraer del video."),
|
178 |
+
video: UploadFile = File(..., description="Archivo de video a procesar (.mp4, .mov).")
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Recibe un video, lo preprocesa (extrae frames + COLMAP) y prepara para el entrenamiento.
|
182 |
+
"""
|
183 |
+
if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
|
184 |
+
raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")
|
185 |
+
|
186 |
+
# Guarda el video temporalmente para que la librería pueda procesarlo
|
187 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
|
188 |
+
shutil.copyfileobj(video.file, tmp_video)
|
189 |
+
tmp_video_path = tmp_video.name
|
190 |
+
|
191 |
+
try:
|
192 |
+
loop = asyncio.get_running_loop()
|
193 |
+
# Ejecuta la función síncrona y bloqueante en un executor para no bloquear el servidor
|
194 |
+
scene_dir, selected_frames = await loop.run_in_executor(
|
195 |
+
None, run_preprocessing_sync, tmp_video_path, num_ref_views
|
196 |
+
)
|
197 |
+
|
198 |
+
# Genera un ID único para esta tarea y guarda la ruta
|
199 |
+
task_id = str(uuid.uuid4())
|
200 |
+
tasks_db[task_id] = {
|
201 |
+
"scene_dir": scene_dir,
|
202 |
+
"num_ref_views": len(selected_frames),
|
203 |
+
"result_ply_path": None
|
204 |
+
}
|
205 |
+
|
206 |
+
return JSONResponse(
|
207 |
+
status_code=200,
|
208 |
+
content={
|
209 |
+
"task_id": task_id,
|
210 |
+
"message": f"Preprocesamiento completado. Se generó el directorio de la escena. Listo para entrenar.",
|
211 |
+
"selected_frames_count": len(selected_frames)
|
212 |
+
}
|
213 |
+
)
|
214 |
+
except Exception as e:
|
215 |
+
raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
|
216 |
+
finally:
|
217 |
+
os.unlink(tmp_video_path) # Limpia el archivo de video temporal
|
218 |
+
|
219 |
+
|
220 |
+
@app.post("/train/{task_id}")
|
221 |
+
async def train_model(task_id: str, params: TrainParams):
|
222 |
+
"""
|
223 |
+
Inicia el entrenamiento para una tarea preprocesada.
|
224 |
+
Devuelve un stream de logs en tiempo real.
|
225 |
+
"""
|
226 |
+
if task_id not in tasks_db:
|
227 |
+
raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
|
228 |
+
|
229 |
+
task_info = tasks_db[task_id]
|
230 |
+
scene_dir = task_info["scene_dir"]
|
231 |
+
num_ref_views = task_info["num_ref_views"]
|
232 |
+
|
233 |
+
return StreamingResponse(
|
234 |
+
training_log_generator(scene_dir, num_ref_views, params, task_id),
|
235 |
+
media_type="text/event-stream"
|
236 |
+
)
|
237 |
+
|
238 |
+
@app.get("/download/{task_id}")
|
239 |
+
async def download_ply_file(task_id: str):
|
240 |
+
"""
|
241 |
+
Permite descargar el archivo .ply resultante de un entrenamiento completado.
|
242 |
+
"""
|
243 |
+
if task_id not in tasks_db:
|
244 |
+
raise HTTPException(status_code=404, detail="Task ID no encontrado.")
|
245 |
+
|
246 |
+
task_info = tasks_db[task_id]
|
247 |
+
ply_path = task_info.get("result_ply_path")
|
248 |
+
|
249 |
+
if not ply_path:
|
250 |
+
raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo aún no está disponible.")
|
251 |
+
|
252 |
+
if not os.path.exists(ply_path):
|
253 |
+
raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")
|
254 |
+
|
255 |
+
# Generamos un nombre de archivo amigable para el usuario
|
256 |
+
file_name = f"model_{task_id[:8]}.ply"
|
257 |
+
|
258 |
+
return FileResponse(
|
259 |
+
path=ply_path,
|
260 |
+
media_type='application/octet-stream',
|
261 |
+
filename=file_name
|
262 |
+
)
|
263 |
+
|
264 |
+
if __name__ == "__main__":
|
265 |
+
import uvicorn
|
266 |
+
# Para ejecutar: uvicorn main:app --reload
|
267 |
+
# El flag --reload es para desarrollo. Quítalo en producción.
|
268 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
|
metrics.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from pathlib import Path
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
from PIL import Image
|
16 |
+
import torch
|
17 |
+
import torchvision.transforms.functional as tf
|
18 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
19 |
+
from utils.loss_utils import ssim
|
20 |
+
from lpipsPyTorch import lpips as lpips_3dgs
|
21 |
+
import json
|
22 |
+
from tqdm import tqdm
|
23 |
+
from utils.image_utils import psnr
|
24 |
+
from argparse import ArgumentParser
|
25 |
+
|
26 |
+
import lpips
|
27 |
+
|
28 |
+
def readImages(renders_dir, gt_dir):
|
29 |
+
renders = []
|
30 |
+
gts = []
|
31 |
+
image_names = []
|
32 |
+
for fname in os.listdir(renders_dir):
|
33 |
+
render = Image.open(renders_dir / fname)
|
34 |
+
gt = Image.open(gt_dir / fname)
|
35 |
+
renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
|
36 |
+
gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
|
37 |
+
image_names.append(fname)
|
38 |
+
return renders, gts, image_names
|
39 |
+
|
40 |
+
def evaluate(model_paths):
|
41 |
+
|
42 |
+
full_dict = {}
|
43 |
+
per_view_dict = {}
|
44 |
+
full_dict_polytopeonly = {}
|
45 |
+
per_view_dict_polytopeonly = {}
|
46 |
+
print("")
|
47 |
+
|
48 |
+
for scene_dir in model_paths:
|
49 |
+
#try:
|
50 |
+
print("Scene:", scene_dir)
|
51 |
+
full_dict[scene_dir] = {}
|
52 |
+
per_view_dict[scene_dir] = {}
|
53 |
+
full_dict_polytopeonly[scene_dir] = {}
|
54 |
+
per_view_dict_polytopeonly[scene_dir] = {}
|
55 |
+
|
56 |
+
test_dir = Path(scene_dir) / "test"
|
57 |
+
|
58 |
+
for method in os.listdir(test_dir):
|
59 |
+
print("Method:", method)
|
60 |
+
|
61 |
+
full_dict[scene_dir][method] = {}
|
62 |
+
per_view_dict[scene_dir][method] = {}
|
63 |
+
full_dict_polytopeonly[scene_dir][method] = {}
|
64 |
+
per_view_dict_polytopeonly[scene_dir][method] = {}
|
65 |
+
|
66 |
+
method_dir = test_dir / method
|
67 |
+
gt_dir = method_dir/ "gt"
|
68 |
+
renders_dir = method_dir / "renders"
|
69 |
+
renders, gts, image_names = readImages(renders_dir, gt_dir)
|
70 |
+
|
71 |
+
ssims = []
|
72 |
+
psnrs = []
|
73 |
+
lpipss = []
|
74 |
+
lpipss_3dgs = []
|
75 |
+
with torch.no_grad():
|
76 |
+
for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
|
77 |
+
ssims.append(ssim(renders[idx], gts[idx]))
|
78 |
+
psnrs.append(psnr(renders[idx], gts[idx]))
|
79 |
+
lpipss.append(lpips_fn(renders[idx], gts[idx]))
|
80 |
+
lpipss_3dgs.append(lpips_3dgs(renders[idx], gts[idx], net_type='vgg'))
|
81 |
+
torch.cuda.empty_cache()
|
82 |
+
|
83 |
+
print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
|
84 |
+
print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
|
85 |
+
print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
|
86 |
+
print(" LPIPS_3dgs: {:>12.7f}".format(torch.tensor(lpipss_3dgs).mean(), ".5"))
|
87 |
+
print("")
|
88 |
+
|
89 |
+
full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
|
90 |
+
"PSNR": torch.tensor(psnrs).mean().item(),
|
91 |
+
"LPIPS": torch.tensor(lpipss).mean().item(),
|
92 |
+
"LPIPS_3dgs": torch.tensor(lpipss_3dgs).mean().item(),
|
93 |
+
})
|
94 |
+
per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
|
95 |
+
"PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
|
96 |
+
"LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)},
|
97 |
+
"LPIPS_3dgs": {name: lp for lp, name in zip(torch.tensor(lpipss_3dgs).tolist(), image_names)},
|
98 |
+
})
|
99 |
+
|
100 |
+
with open(scene_dir + "/results.json", 'w') as fp:
|
101 |
+
json.dump(full_dict[scene_dir], fp, indent=True)
|
102 |
+
with open(scene_dir + "/per_view.json", 'w') as fp:
|
103 |
+
json.dump(per_view_dict[scene_dir], fp, indent=True)
|
104 |
+
#except:
|
105 |
+
# print("Unable to compute metrics for model", scene_dir)
|
106 |
+
|
107 |
+
if __name__ == "__main__":
|
108 |
+
device = torch.device("cuda:0")
|
109 |
+
torch.cuda.set_device(device)
|
110 |
+
lpips_fn = lpips.LPIPS(net='vgg').to(device)
|
111 |
+
# Set up command line argument parser
|
112 |
+
parser = ArgumentParser(description="Training script parameters")
|
113 |
+
parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
|
114 |
+
args = parser.parse_args()
|
115 |
+
evaluate(args.model_paths)
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
fastapi
|
script.bash
ADDED
File without changes
|
source/EDGS.code-workspace
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"folders": [
|
3 |
+
{
|
4 |
+
"path": ".."
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"path": "../../../../.."
|
8 |
+
}
|
9 |
+
],
|
10 |
+
"settings": {}
|
11 |
+
}
|
source/__init__.py
ADDED
File without changes
|
source/corr_init.py
ADDED
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('../')
|
3 |
+
sys.path.append("../submodules")
|
4 |
+
sys.path.append('../submodules/RoMa')
|
5 |
+
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
#from tqdm import tqdm_notebook as tqdm
|
12 |
+
from tqdm import tqdm
|
13 |
+
from scipy.cluster.vq import kmeans, vq
|
14 |
+
from scipy.spatial.distance import cdist
|
15 |
+
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from romatch import roma_outdoor, roma_indoor
|
18 |
+
from utils.sh_utils import RGB2SH
|
19 |
+
from romatch.utils import get_tuple_transform_ops
|
20 |
+
|
21 |
+
import time
|
22 |
+
from collections import defaultdict
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
|
26 |
+
def pairwise_distances(matrix):
|
27 |
+
"""
|
28 |
+
Computes the pairwise Euclidean distances between all vectors in the input matrix.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
torch.Tensor: Pairwise distance matrix of shape [N, N].
|
35 |
+
"""
|
36 |
+
# Compute squared pairwise distances
|
37 |
+
squared_diff = torch.cdist(matrix, matrix, p=2)
|
38 |
+
return squared_diff
|
39 |
+
|
40 |
+
|
41 |
+
def k_closest_vectors(matrix, k):
|
42 |
+
"""
|
43 |
+
Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
47 |
+
k (int): Number of closest vectors to return for each vector.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
|
51 |
+
"""
|
52 |
+
# Compute pairwise distances
|
53 |
+
distances = pairwise_distances(matrix)
|
54 |
+
|
55 |
+
# For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
|
56 |
+
# Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
|
57 |
+
distances.fill_diagonal_(float('inf'))
|
58 |
+
|
59 |
+
# Get the indices of the k smallest distances (k-closest vectors)
|
60 |
+
_, indices = torch.topk(distances, k, largest=False, dim=1)
|
61 |
+
|
62 |
+
return indices
|
63 |
+
|
64 |
+
|
65 |
+
def select_cameras_kmeans(cameras, K):
|
66 |
+
"""
|
67 |
+
Selects K cameras from a set using K-means clustering.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
|
71 |
+
K: Number of clusters (cameras to select).
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
selected_indices: List of indices of the cameras closest to the cluster centers.
|
75 |
+
"""
|
76 |
+
# Ensure input is a NumPy array
|
77 |
+
if not isinstance(cameras, np.ndarray):
|
78 |
+
cameras = np.asarray(cameras)
|
79 |
+
|
80 |
+
if cameras.shape[1] != 16:
|
81 |
+
raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
|
82 |
+
|
83 |
+
# Perform K-means clustering
|
84 |
+
cluster_centers, _ = kmeans(cameras, K)
|
85 |
+
|
86 |
+
# Assign each camera to a cluster and find distances to cluster centers
|
87 |
+
cluster_assignments, _ = vq(cameras, cluster_centers)
|
88 |
+
|
89 |
+
# Find the camera nearest to each cluster center
|
90 |
+
selected_indices = []
|
91 |
+
for k in range(K):
|
92 |
+
cluster_members = cameras[cluster_assignments == k]
|
93 |
+
distances = cdist([cluster_centers[k]], cluster_members)[0]
|
94 |
+
nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
|
95 |
+
selected_indices.append(nearest_camera_idx)
|
96 |
+
|
97 |
+
return selected_indices
|
98 |
+
|
99 |
+
|
100 |
+
def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
|
101 |
+
"""
|
102 |
+
Computes the warp and confidence between two viewpoint cameras using the roma_model.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
viewpoint_cam1: Source viewpoint camera.
|
106 |
+
viewpoint_cam2: Target viewpoint camera.
|
107 |
+
roma_model: Pre-trained Roma model for correspondence matching.
|
108 |
+
device: Device to run the computation on.
|
109 |
+
verbose: If True, displays the images.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
certainty: Confidence tensor.
|
113 |
+
warp: Warp tensor.
|
114 |
+
imB: Processed image B as numpy array.
|
115 |
+
"""
|
116 |
+
# Prepare images
|
117 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
118 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
119 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
120 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
121 |
+
|
122 |
+
if verbose:
|
123 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
|
124 |
+
cax1 = ax[0].imshow(imA)
|
125 |
+
ax[0].set_title("Image 1")
|
126 |
+
cax2 = ax[1].imshow(imB)
|
127 |
+
ax[1].set_title("Image 2")
|
128 |
+
fig.colorbar(cax1, ax=ax[0])
|
129 |
+
fig.colorbar(cax2, ax=ax[1])
|
130 |
+
|
131 |
+
for axis in ax:
|
132 |
+
axis.axis('off')
|
133 |
+
# Save the figure into the dictionary
|
134 |
+
output_dict[f'image_pair'] = fig
|
135 |
+
|
136 |
+
# Transform images
|
137 |
+
ws, hs = roma_model.w_resized, roma_model.h_resized
|
138 |
+
test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
|
139 |
+
im_A, im_B = test_transform((imA, imB))
|
140 |
+
batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
141 |
+
|
142 |
+
# Forward pass through Roma model
|
143 |
+
corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
|
144 |
+
finest_scale = 1
|
145 |
+
hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
|
146 |
+
|
147 |
+
# Process certainty and warp
|
148 |
+
certainty = corresps[finest_scale]["certainty"]
|
149 |
+
im_A_to_im_B = corresps[finest_scale]["flow"]
|
150 |
+
if roma_model.attenuate_cert:
|
151 |
+
low_res_certainty = F.interpolate(
|
152 |
+
corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
153 |
+
)
|
154 |
+
certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
|
155 |
+
|
156 |
+
# Upsample predictions if needed
|
157 |
+
if roma_model.upsample_preds:
|
158 |
+
im_A_to_im_B = F.interpolate(
|
159 |
+
im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
160 |
+
)
|
161 |
+
certainty = F.interpolate(
|
162 |
+
certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
163 |
+
)
|
164 |
+
|
165 |
+
# Convert predictions to final format
|
166 |
+
im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
|
167 |
+
im_A_coords = torch.stack(torch.meshgrid(
|
168 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
169 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
170 |
+
indexing='ij'
|
171 |
+
), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
|
172 |
+
|
173 |
+
warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|
174 |
+
certainty = certainty.sigmoid()
|
175 |
+
|
176 |
+
return certainty[0, 0], warp[0], np.array(imB)
|
177 |
+
|
178 |
+
|
179 |
+
def resize_batch(tensors_3d, tensors_4d, target_shape):
|
180 |
+
"""
|
181 |
+
Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
tensors_3d: Tensor of shape [B, H, W].
|
185 |
+
tensors_4d: Tensor of shape [B, H, W, 4].
|
186 |
+
target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
resized_tensors_3d: Tensor of shape [B, target_H, target_W].
|
190 |
+
resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
|
191 |
+
"""
|
192 |
+
target_H, target_W = target_shape
|
193 |
+
|
194 |
+
# Resize [B, H, W] tensor
|
195 |
+
resized_tensors_3d = F.interpolate(
|
196 |
+
tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
|
197 |
+
).squeeze(1)
|
198 |
+
|
199 |
+
# Resize [B, H, W, 4] tensor
|
200 |
+
B, _, _, C = tensors_4d.shape
|
201 |
+
resized_tensors_4d = F.interpolate(
|
202 |
+
tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
|
203 |
+
).permute(0, 2, 3, 1)
|
204 |
+
|
205 |
+
return resized_tensors_3d, resized_tensors_4d
|
206 |
+
|
207 |
+
|
208 |
+
def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
|
209 |
+
"""
|
210 |
+
Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
viewpoint_stack: Stack of viewpoint cameras.
|
214 |
+
closest_indices: Indices of the nearest neighbors for each viewpoint.
|
215 |
+
roma_model: Pre-trained Roma model.
|
216 |
+
source_idx: Index of the source viewpoint.
|
217 |
+
verbose: If True, displays intermediate results.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
certainties_max: Aggregated maximum confidences.
|
221 |
+
warps_max: Aggregated warps corresponding to maximum confidences.
|
222 |
+
certainties_max_idcs: Pixel-wise index of the image from which we taken the best matching.
|
223 |
+
imB_compound: List of the neighboring images.
|
224 |
+
"""
|
225 |
+
certainties_all, warps_all, imB_compound = [], [], []
|
226 |
+
|
227 |
+
for nn in tqdm(closest_indices[source_idx]):
|
228 |
+
|
229 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
230 |
+
viewpoint_cam2 = viewpoint_stack[nn]
|
231 |
+
|
232 |
+
certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
|
233 |
+
certainties_all.append(certainty)
|
234 |
+
warps_all.append(warp)
|
235 |
+
imB_compound.append(imB)
|
236 |
+
|
237 |
+
certainties_all = torch.stack(certainties_all, dim=0)
|
238 |
+
target_shape = imB_compound[0].shape[:2]
|
239 |
+
if verbose:
|
240 |
+
print("certainties_all.shape:", certainties_all.shape)
|
241 |
+
print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
|
242 |
+
print("target_shape:", target_shape)
|
243 |
+
|
244 |
+
certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
|
245 |
+
torch.stack(warps_all, dim=0),
|
246 |
+
target_shape
|
247 |
+
)
|
248 |
+
|
249 |
+
if verbose:
|
250 |
+
print("warps_all_resized.shape:", warps_all_resized.shape)
|
251 |
+
for n, cert in enumerate(certainties_all):
|
252 |
+
fig, ax = plt.subplots()
|
253 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
254 |
+
fig.colorbar(cax, ax=ax)
|
255 |
+
ax.set_title("Pixel-wise Confidence")
|
256 |
+
output_dict[f'certainty_{n}'] = fig
|
257 |
+
|
258 |
+
for n, warp in enumerate(warps_all):
|
259 |
+
fig, ax = plt.subplots()
|
260 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
261 |
+
fig.colorbar(cax, ax=ax)
|
262 |
+
ax.set_title("Pixel-wise warp")
|
263 |
+
output_dict[f'warp_resized_{n}'] = fig
|
264 |
+
|
265 |
+
for n, cert in enumerate(certainties_all_resized):
|
266 |
+
fig, ax = plt.subplots()
|
267 |
+
cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
|
268 |
+
fig.colorbar(cax, ax=ax)
|
269 |
+
ax.set_title("Pixel-wise Confidence resized")
|
270 |
+
output_dict[f'certainty_resized_{n}'] = fig
|
271 |
+
|
272 |
+
for n, warp in enumerate(warps_all_resized):
|
273 |
+
fig, ax = plt.subplots()
|
274 |
+
cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
|
275 |
+
fig.colorbar(cax, ax=ax)
|
276 |
+
ax.set_title("Pixel-wise warp resized")
|
277 |
+
output_dict[f'warp_resized_{n}'] = fig
|
278 |
+
|
279 |
+
certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
|
280 |
+
H, W = certainties_max.shape
|
281 |
+
|
282 |
+
warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
|
283 |
+
|
284 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
285 |
+
imA = np.clip(imA * 255, 0, 255).astype(np.uint8)
|
286 |
+
|
287 |
+
return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
|
292 |
+
verbose=False, output_dict={}):
|
293 |
+
"""
|
294 |
+
Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
|
295 |
+
|
296 |
+
Args:
|
297 |
+
imA: Source image as a NumPy array (H_A, W_A, C).
|
298 |
+
imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
|
299 |
+
certainties_max: Tensor of pixel-wise maximum confidences.
|
300 |
+
certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
|
301 |
+
matches: Matches in normalized coordinates.
|
302 |
+
roma_model: Roma model instance for keypoint operations.
|
303 |
+
verbose: if to show intermediate outputs and visualize results
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
kptsA_np: Keypoints in imA in normalized coordinates.
|
307 |
+
kptsB_np: Keypoints in imB in normalized coordinates.
|
308 |
+
kptsA_color: Colors of keypoints in imA.
|
309 |
+
kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
|
310 |
+
"""
|
311 |
+
H_A, W_A, _ = imA.shape
|
312 |
+
H, W = certainties_max.shape
|
313 |
+
|
314 |
+
# Convert matches to pixel coordinates
|
315 |
+
kptsA, kptsB = roma_model.to_pixel_coordinates(
|
316 |
+
matches, W_A, H_A, H, W # W, H
|
317 |
+
)
|
318 |
+
|
319 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
320 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
321 |
+
kptsA_np = kptsA_np[:, [1, 0]]
|
322 |
+
|
323 |
+
if verbose:
|
324 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
325 |
+
cax = ax.imshow(imA)
|
326 |
+
ax.set_title("Reference image, imA")
|
327 |
+
output_dict[f'reference_image'] = fig
|
328 |
+
|
329 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
330 |
+
cax = ax.imshow(imB_compound[0])
|
331 |
+
ax.set_title("Image to compare to image, imB_compound")
|
332 |
+
output_dict[f'imB_compound'] = fig
|
333 |
+
|
334 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
335 |
+
cax = ax.imshow(np.flipud(imA))
|
336 |
+
cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
|
337 |
+
ax.set_title("Keypoints in imA")
|
338 |
+
ax.set_xlim(0, W_A)
|
339 |
+
ax.set_ylim(0, H_A)
|
340 |
+
output_dict[f'kptsA'] = fig
|
341 |
+
|
342 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
343 |
+
cax = ax.imshow(np.flipud(imB_compound[0]))
|
344 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
345 |
+
ax.set_title("Keypoints in imB")
|
346 |
+
ax.set_xlim(0, W_A)
|
347 |
+
ax.set_ylim(0, H_A)
|
348 |
+
output_dict[f'kptsB'] = fig
|
349 |
+
|
350 |
+
# Keypoints are in format (row, column) so the first value is alwain in range [0;height] and second is in range[0;width]
|
351 |
+
|
352 |
+
kptsA_np = kptsA.detach().cpu().numpy()
|
353 |
+
kptsB_np = kptsB.detach().cpu().numpy()
|
354 |
+
|
355 |
+
# Extract colors for keypoints in imA (vectorized)
|
356 |
+
# New experimental version
|
357 |
+
kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
|
358 |
+
kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
|
359 |
+
kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
360 |
+
|
361 |
+
# Create a composite image from imB_compound
|
362 |
+
imB_compound_np = np.stack(imB_compound, axis=0)
|
363 |
+
H_B, W_B, _ = imB_compound[0].shape
|
364 |
+
|
365 |
+
# Extract colors for keypoints in imB using certainties_max_idcs
|
366 |
+
imB_np = imB_compound_np[
|
367 |
+
certainties_max_idcs.detach().cpu().numpy(),
|
368 |
+
np.arange(H).reshape(-1, 1),
|
369 |
+
np.arange(W)
|
370 |
+
]
|
371 |
+
|
372 |
+
if verbose:
|
373 |
+
print("imB_np.shape:", imB_np.shape)
|
374 |
+
print("imB_np:", imB_np)
|
375 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
376 |
+
cax = ax.imshow(np.flipud(imB_np))
|
377 |
+
cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
|
378 |
+
ax.set_title("np.flipud(imB_np[0]")
|
379 |
+
ax.set_xlim(0, W_A)
|
380 |
+
ax.set_ylim(0, H_A)
|
381 |
+
output_dict[f'np.flipud(imB_np[0]'] = fig
|
382 |
+
|
383 |
+
|
384 |
+
kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
|
385 |
+
kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
|
386 |
+
|
387 |
+
certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
|
388 |
+
kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
|
389 |
+
kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
|
390 |
+
|
391 |
+
# Normalize keypoints in both images
|
392 |
+
kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
|
393 |
+
kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
|
394 |
+
kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
|
395 |
+
kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
|
396 |
+
|
397 |
+
return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
|
398 |
+
|
399 |
+
def prepare_tensor(input_array, device):
|
400 |
+
"""
|
401 |
+
Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
|
402 |
+
Args:
|
403 |
+
input_array (array-like): The input array to convert.
|
404 |
+
device (str or torch.device): The device to move the tensor to.
|
405 |
+
Returns:
|
406 |
+
torch.Tensor: A detached tensor clone of the input array on the specified device.
|
407 |
+
"""
|
408 |
+
if not isinstance(input_array, torch.Tensor):
|
409 |
+
return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
|
410 |
+
return input_array.clone().detach().to(device).to(torch.float32)
|
411 |
+
|
412 |
+
def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
|
413 |
+
"""
|
414 |
+
Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
|
415 |
+
|
416 |
+
Parameters:
|
417 |
+
- P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
|
418 |
+
- k1_x, k1_y: Tensors of shape (batch_size,)
|
419 |
+
- k2_x, k2_y: Tensors of shape (batch_size,)
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
- X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
|
423 |
+
"""
|
424 |
+
EPS = 1e-4
|
425 |
+
# Ensure inputs are tensors
|
426 |
+
|
427 |
+
P1 = prepare_tensor(P1, device)
|
428 |
+
P2 = prepare_tensor(P2, device)
|
429 |
+
k1_x = prepare_tensor(k1_x, device)
|
430 |
+
k1_y = prepare_tensor(k1_y, device)
|
431 |
+
k2_x = prepare_tensor(k2_x, device)
|
432 |
+
k2_y = prepare_tensor(k2_y, device)
|
433 |
+
batch_size = k1_x.shape[0]
|
434 |
+
|
435 |
+
# Expand P1 and P2 if they are not batched
|
436 |
+
if P1.ndim == 2:
|
437 |
+
P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
|
438 |
+
if P2.ndim == 2:
|
439 |
+
P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
|
440 |
+
|
441 |
+
# Extract columns from P1 and P2
|
442 |
+
P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
|
443 |
+
P1_1 = P1[:, :, 1]
|
444 |
+
P1_2 = P1[:, :, 2]
|
445 |
+
|
446 |
+
P2_0 = P2[:, :, 0]
|
447 |
+
P2_1 = P2[:, :, 1]
|
448 |
+
P2_2 = P2[:, :, 2]
|
449 |
+
|
450 |
+
# Reshape kx and ky to (batch_size, 1)
|
451 |
+
k1_x = k1_x.view(-1, 1)
|
452 |
+
k1_y = k1_y.view(-1, 1)
|
453 |
+
k2_x = k2_x.view(-1, 1)
|
454 |
+
k2_y = k2_y.view(-1, 1)
|
455 |
+
|
456 |
+
# Construct the equations for each batch
|
457 |
+
# For camera 1
|
458 |
+
A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
|
459 |
+
A2 = P1_1 - k1_y * P1_2
|
460 |
+
# For camera 2
|
461 |
+
A3 = P2_0 - k2_x * P2_2
|
462 |
+
A4 = P2_1 - k2_y * P2_2
|
463 |
+
|
464 |
+
# Stack the equations
|
465 |
+
A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
|
466 |
+
|
467 |
+
# Right-hand side (constants)
|
468 |
+
b = -A[:, :, 3] # Shape: (batch_size, 4)
|
469 |
+
A_reduced = A[:, :, :3] # Coefficients of x, y, z
|
470 |
+
|
471 |
+
# Solve using torch.linalg.lstsq (supports batching)
|
472 |
+
X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
|
473 |
+
|
474 |
+
# Append 1 to get homogeneous coordinates
|
475 |
+
ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
|
476 |
+
X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
|
477 |
+
|
478 |
+
# Now compute the errors of projections.
|
479 |
+
seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
|
480 |
+
seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
|
481 |
+
seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
|
482 |
+
seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
|
483 |
+
proj1_target = torch.concat([k1_x, k1_y], dim=1)
|
484 |
+
proj2_target = torch.concat([k2_x, k2_y], dim=1)
|
485 |
+
errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
|
486 |
+
errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
|
487 |
+
|
488 |
+
return X, errors_proj1, errors_proj2
|
489 |
+
|
490 |
+
|
491 |
+
|
492 |
+
def select_best_keypoints(
|
493 |
+
NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
|
494 |
+
"""
|
495 |
+
From all the points fitted to keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
|
496 |
+
|
497 |
+
Args:
|
498 |
+
NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
|
499 |
+
usually 3 or 4(for homogeneous representation).
|
500 |
+
NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
|
501 |
+
NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
|
502 |
+
Returns:
|
503 |
+
selected_keypoints: keypoints with the best score.
|
504 |
+
"""
|
505 |
+
|
506 |
+
NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
|
507 |
+
|
508 |
+
# Convert indices to PyTorch tensor
|
509 |
+
indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
|
510 |
+
|
511 |
+
# Create index tensor for the second dimension
|
512 |
+
n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
|
513 |
+
|
514 |
+
# Use advanced indexing to select elements
|
515 |
+
NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
|
516 |
+
|
517 |
+
return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
|
518 |
+
|
519 |
+
|
520 |
+
|
521 |
+
def init_gaussians_with_corr(gaussians, scene, cfg, device, verbose = False, roma_model=None):
|
522 |
+
"""
|
523 |
+
For a given input gaussians and a scene we instantiate a RoMa model(change to indoors if necessary) and process scene
|
524 |
+
training frames to extract correspondences. Those are used to initialize gaussians
|
525 |
+
Args:
|
526 |
+
gaussians: object gaussians of the class GaussianModel that we need to enrich with gaussians.
|
527 |
+
scene: object of the Scene class.
|
528 |
+
cfg: configuration. Use init_wC
|
529 |
+
Returns:
|
530 |
+
gaussians: inplace transforms object gaussians of the class GaussianModel.
|
531 |
+
|
532 |
+
"""
|
533 |
+
if roma_model is None:
|
534 |
+
if cfg.roma_model == "indoors":
|
535 |
+
roma_model = roma_indoor(device=device)
|
536 |
+
else:
|
537 |
+
roma_model = roma_outdoor(device=device)
|
538 |
+
roma_model.upsample_preds = False
|
539 |
+
roma_model.symmetric = False
|
540 |
+
M = cfg.matches_per_ref
|
541 |
+
upper_thresh = roma_model.sample_thresh
|
542 |
+
scaling_factor = cfg.scaling_factor
|
543 |
+
expansion_factor = 1
|
544 |
+
keypoint_fit_error_tolerance = cfg.proj_err_tolerance
|
545 |
+
visualizations = {}
|
546 |
+
viewpoint_stack = scene.getTrainCameras().copy()
|
547 |
+
NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
|
548 |
+
NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref , len(viewpoint_stack))
|
549 |
+
# Select cameras using K-means
|
550 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
551 |
+
|
552 |
+
selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
|
553 |
+
selected_indices = sorted(selected_indices)
|
554 |
+
|
555 |
+
|
556 |
+
# Find the k-closest vectors for each vector
|
557 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
558 |
+
closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
|
559 |
+
if verbose: print("Indices of k-closest vectors for each vector:\n", closest_indices)
|
560 |
+
|
561 |
+
closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
|
562 |
+
|
563 |
+
all_new_xyz = []
|
564 |
+
all_new_features_dc = []
|
565 |
+
all_new_features_rest = []
|
566 |
+
all_new_opacities = []
|
567 |
+
all_new_scaling = []
|
568 |
+
all_new_rotation = []
|
569 |
+
|
570 |
+
# Run roma_model.match once to kinda initialize the model
|
571 |
+
with torch.no_grad():
|
572 |
+
viewpoint_cam1 = viewpoint_stack[0]
|
573 |
+
viewpoint_cam2 = viewpoint_stack[1]
|
574 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
575 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
576 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
577 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
578 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
579 |
+
print("Once run full roma_model.match warp.shape:", warp.shape)
|
580 |
+
print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
|
581 |
+
del warp, certainty_warp
|
582 |
+
torch.cuda.empty_cache()
|
583 |
+
|
584 |
+
for source_idx in tqdm(sorted(selected_indices)):
|
585 |
+
# 1. Compute keypoints and warping for all the neigboring views
|
586 |
+
with torch.no_grad():
|
587 |
+
# Call the aggregation function to get imA and imB_compound
|
588 |
+
certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
|
589 |
+
viewpoint_stack=viewpoint_stack,
|
590 |
+
closest_indices=closest_indices_selected,
|
591 |
+
roma_model=roma_model,
|
592 |
+
source_idx=source_idx,
|
593 |
+
verbose=verbose, output_dict=visualizations
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
# Triangulate keypoints
|
598 |
+
with torch.no_grad():
|
599 |
+
matches = warps_max
|
600 |
+
certainty = certainties_max
|
601 |
+
certainty = certainty.clone()
|
602 |
+
certainty[certainty > upper_thresh] = 1
|
603 |
+
matches, certainty = (
|
604 |
+
matches.reshape(-1, 4),
|
605 |
+
certainty.reshape(-1),
|
606 |
+
)
|
607 |
+
|
608 |
+
# Select based on certainty elements with high confidence. These are basically all of
|
609 |
+
# kptsA_np.
|
610 |
+
good_samples = torch.multinomial(certainty,
|
611 |
+
num_samples=min(expansion_factor * M, len(certainty)),
|
612 |
+
replacement=False)
|
613 |
+
|
614 |
+
certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all
|
615 |
+
reference_image_dict = {
|
616 |
+
"ref_image": imA,
|
617 |
+
"NNs_images": imB_compound,
|
618 |
+
"certainties_all": certainties_all,
|
619 |
+
"warps_all": warps_all,
|
620 |
+
"triangulated_points": [],
|
621 |
+
"triangulated_points_errors_proj1": [],
|
622 |
+
"triangulated_points_errors_proj2": []
|
623 |
+
|
624 |
+
}
|
625 |
+
with torch.no_grad():
|
626 |
+
for NN_idx in tqdm(range(len(warps_all))):
|
627 |
+
matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
|
628 |
+
|
629 |
+
# Extract keypoints and colors
|
630 |
+
kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
|
631 |
+
imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
|
632 |
+
)
|
633 |
+
|
634 |
+
proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
|
635 |
+
proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
|
636 |
+
triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
|
637 |
+
P1=torch.stack([proj_matrices_A] * M, axis=0),
|
638 |
+
P2=torch.stack([proj_matrices_B] * M, axis=0),
|
639 |
+
k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
|
640 |
+
k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
|
641 |
+
|
642 |
+
reference_image_dict["triangulated_points"].append(triangulated_points)
|
643 |
+
reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
|
644 |
+
reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
|
645 |
+
|
646 |
+
with torch.no_grad():
|
647 |
+
NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
|
648 |
+
NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
|
649 |
+
NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
|
650 |
+
NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
|
651 |
+
|
652 |
+
# 4. Save as gaussians
|
653 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
654 |
+
N = len(NNs_triangulated_points_selected)
|
655 |
+
with torch.no_grad():
|
656 |
+
new_xyz = NNs_triangulated_points_selected[:, :-1]
|
657 |
+
all_new_xyz.append(new_xyz) # seeked_splats
|
658 |
+
all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
|
659 |
+
all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
|
660 |
+
# new version that sets points with large error invisible
|
661 |
+
# TODO: remove those points instead. However it doesn't affect the performance.
|
662 |
+
mask_bad_points = torch.tensor(
|
663 |
+
NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
|
664 |
+
dtype=torch.float32).unsqueeze(1).to(device)
|
665 |
+
all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
|
666 |
+
|
667 |
+
dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz,
|
668 |
+
dim=1, ord=2)
|
669 |
+
#all_new_scaling.append(torch.log(((dist_points_to_cam1) / 1. * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
670 |
+
all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
671 |
+
all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
|
672 |
+
|
673 |
+
all_new_xyz = torch.cat(all_new_xyz, dim=0)
|
674 |
+
all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
|
675 |
+
new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
|
676 |
+
prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
|
677 |
+
|
678 |
+
gaussians.densification_postfix(all_new_xyz[prune_mask].to(device),
|
679 |
+
all_new_features_dc[prune_mask].to(device),
|
680 |
+
torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
|
681 |
+
torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
|
682 |
+
torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
|
683 |
+
torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
|
684 |
+
new_tmp_radii[prune_mask].to(device))
|
685 |
+
|
686 |
+
return viewpoint_stack, closest_indices_selected, visualizations
|
687 |
+
|
688 |
+
|
689 |
+
|
690 |
+
def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict={}):
|
691 |
+
"""
|
692 |
+
Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).
|
693 |
+
|
694 |
+
Args:
|
695 |
+
imA: Source image as a NumPy array (H_A, W_A, C).
|
696 |
+
imB: Target image as a NumPy array (H_B, W_B, C).
|
697 |
+
matches: Matches in normalized coordinates (torch.Tensor).
|
698 |
+
roma_model: Roma model instance for keypoint operations.
|
699 |
+
verbose: If True, outputs intermediate visualizations.
|
700 |
+
Returns:
|
701 |
+
kptsA_np: Keypoints in imA (normalized).
|
702 |
+
kptsB_np: Keypoints in imB (normalized).
|
703 |
+
kptsA_color: Colors of keypoints in imA.
|
704 |
+
kptsB_color: Colors of keypoints in imB.
|
705 |
+
"""
|
706 |
+
H_A, W_A, _ = imA.shape
|
707 |
+
H_B, W_B, _ = imB.shape
|
708 |
+
|
709 |
+
# Convert matches to pixel coordinates
|
710 |
+
# Matches format: (B, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
|
711 |
+
kptsA = matches[:, :2] # [N, 2]
|
712 |
+
kptsB = matches[:, 2:] # [N, 2]
|
713 |
+
|
714 |
+
# Scale normalized coordinates [-1,1] to pixel coordinates
|
715 |
+
kptsA_pix = torch.zeros_like(kptsA)
|
716 |
+
kptsB_pix = torch.zeros_like(kptsB)
|
717 |
+
|
718 |
+
# Important! [Normalized to pixel space]
|
719 |
+
kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
|
720 |
+
kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2
|
721 |
+
|
722 |
+
kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
|
723 |
+
kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2
|
724 |
+
|
725 |
+
kptsA_np = kptsA_pix.detach().cpu().numpy()
|
726 |
+
kptsB_np = kptsB_pix.detach().cpu().numpy()
|
727 |
+
|
728 |
+
# Extract colors
|
729 |
+
kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
|
730 |
+
kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
|
731 |
+
kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
|
732 |
+
kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
|
733 |
+
|
734 |
+
kptsA_color = imA[np.clip(kptsA_y, 0, H_A-1), np.clip(kptsA_x, 0, W_A-1)]
|
735 |
+
kptsB_color = imB[np.clip(kptsB_y, 0, H_B-1), np.clip(kptsB_x, 0, W_B-1)]
|
736 |
+
|
737 |
+
# Normalize keypoints into [-1, 1] for downstream triangulation
|
738 |
+
kptsA_np_norm = np.zeros_like(kptsA_np)
|
739 |
+
kptsB_np_norm = np.zeros_like(kptsB_np)
|
740 |
+
|
741 |
+
kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
|
742 |
+
kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0
|
743 |
+
|
744 |
+
kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
|
745 |
+
kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0
|
746 |
+
|
747 |
+
return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color
|
748 |
+
|
749 |
+
|
750 |
+
|
751 |
+
def init_gaussians_with_corr_fast(gaussians, scene, cfg, device, verbose=False, roma_model=None):
|
752 |
+
timings = defaultdict(list)
|
753 |
+
|
754 |
+
if roma_model is None:
|
755 |
+
if cfg.roma_model == "indoors":
|
756 |
+
roma_model = roma_indoor(device=device)
|
757 |
+
else:
|
758 |
+
roma_model = roma_outdoor(device=device)
|
759 |
+
roma_model.upsample_preds = False
|
760 |
+
roma_model.symmetric = False
|
761 |
+
|
762 |
+
M = cfg.matches_per_ref
|
763 |
+
upper_thresh = roma_model.sample_thresh
|
764 |
+
scaling_factor = cfg.scaling_factor
|
765 |
+
expansion_factor = 1
|
766 |
+
keypoint_fit_error_tolerance = cfg.proj_err_tolerance
|
767 |
+
visualizations = {}
|
768 |
+
viewpoint_stack = scene.getTrainCameras().copy()
|
769 |
+
NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
|
770 |
+
NUM_NNS_PER_REFERENCE = 1 # Only ONE neighbor now!
|
771 |
+
|
772 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
773 |
+
|
774 |
+
selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
|
775 |
+
selected_indices = sorted(selected_indices)
|
776 |
+
|
777 |
+
viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
|
778 |
+
closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
|
779 |
+
closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
|
780 |
+
|
781 |
+
all_new_xyz = []
|
782 |
+
all_new_features_dc = []
|
783 |
+
all_new_features_rest = []
|
784 |
+
all_new_opacities = []
|
785 |
+
all_new_scaling = []
|
786 |
+
all_new_rotation = []
|
787 |
+
|
788 |
+
# Dummy first pass to initialize model
|
789 |
+
with torch.no_grad():
|
790 |
+
viewpoint_cam1 = viewpoint_stack[0]
|
791 |
+
viewpoint_cam2 = viewpoint_stack[1]
|
792 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
793 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
794 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
795 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
796 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
797 |
+
del warp, certainty_warp
|
798 |
+
torch.cuda.empty_cache()
|
799 |
+
|
800 |
+
# Main Loop over source_idx
|
801 |
+
for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
|
802 |
+
|
803 |
+
# =================== Step 1: Compute Warp and Certainty ===================
|
804 |
+
start = time.time()
|
805 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
806 |
+
NNs=closest_indices_selected.shape[1]
|
807 |
+
viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
|
808 |
+
imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
809 |
+
imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
|
810 |
+
imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
|
811 |
+
imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
|
812 |
+
warp, certainty_warp = roma_model.match(imA, imB, device=device)
|
813 |
+
|
814 |
+
certainties_max = certainty_warp # New manual sampling
|
815 |
+
timings['aggregation_warp_certainty'].append(time.time() - start)
|
816 |
+
|
817 |
+
# =================== Step 2: Good Samples Selection ===================
|
818 |
+
start = time.time()
|
819 |
+
certainty = certainties_max.reshape(-1).clone()
|
820 |
+
certainty[certainty > upper_thresh] = 1
|
821 |
+
good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
|
822 |
+
timings['good_samples_selection'].append(time.time() - start)
|
823 |
+
|
824 |
+
# =================== Step 3: Triangulate Keypoints ===================
|
825 |
+
reference_image_dict = {
|
826 |
+
"triangulated_points": [],
|
827 |
+
"triangulated_points_errors_proj1": [],
|
828 |
+
"triangulated_points_errors_proj2": []
|
829 |
+
}
|
830 |
+
|
831 |
+
start = time.time()
|
832 |
+
matches_NN = warp.reshape(-1, 4)[good_samples]
|
833 |
+
|
834 |
+
# Convert matches to pixel coordinates
|
835 |
+
kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
|
836 |
+
np.array(imA).astype(np.uint8),
|
837 |
+
np.array(imB).astype(np.uint8),
|
838 |
+
matches_NN,
|
839 |
+
roma_model
|
840 |
+
)
|
841 |
+
|
842 |
+
proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
|
843 |
+
proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform
|
844 |
+
|
845 |
+
triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
|
846 |
+
P1=torch.stack([proj_matrices_A] * M, axis=0),
|
847 |
+
P2=torch.stack([proj_matrices_B] * M, axis=0),
|
848 |
+
k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
|
849 |
+
k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
|
850 |
+
|
851 |
+
reference_image_dict["triangulated_points"].append(triangulated_points)
|
852 |
+
reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
|
853 |
+
reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
|
854 |
+
timings['triangulation_per_NN'].append(time.time() - start)
|
855 |
+
|
856 |
+
# =================== Step 4: Select Best Triangulated Points ===================
|
857 |
+
start = time.time()
|
858 |
+
NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
|
859 |
+
NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
|
860 |
+
NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
|
861 |
+
NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
|
862 |
+
timings['select_best_keypoints'].append(time.time() - start)
|
863 |
+
|
864 |
+
# =================== Step 5: Create New Gaussians ===================
|
865 |
+
start = time.time()
|
866 |
+
viewpoint_cam1 = viewpoint_stack[source_idx]
|
867 |
+
N = len(NNs_triangulated_points_selected)
|
868 |
+
new_xyz = NNs_triangulated_points_selected[:, :-1]
|
869 |
+
all_new_xyz.append(new_xyz)
|
870 |
+
all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
|
871 |
+
all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
|
872 |
+
|
873 |
+
mask_bad_points = torch.tensor(
|
874 |
+
NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
|
875 |
+
dtype=torch.float32).unsqueeze(1).to(device)
|
876 |
+
|
877 |
+
all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
|
878 |
+
|
879 |
+
dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
|
880 |
+
all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
|
881 |
+
all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
|
882 |
+
timings['save_gaussians'].append(time.time() - start)
|
883 |
+
|
884 |
+
# =================== Final Densification Postfix ===================
|
885 |
+
start = time.time()
|
886 |
+
all_new_xyz = torch.cat(all_new_xyz, dim=0)
|
887 |
+
all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
|
888 |
+
new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
|
889 |
+
prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
|
890 |
+
|
891 |
+
gaussians.densification_postfix(
|
892 |
+
all_new_xyz[prune_mask].to(device),
|
893 |
+
all_new_features_dc[prune_mask].to(device),
|
894 |
+
torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
|
895 |
+
torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
|
896 |
+
torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
|
897 |
+
torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
|
898 |
+
new_tmp_radii[prune_mask].to(device)
|
899 |
+
)
|
900 |
+
timings['final_densification_postfix'].append(time.time() - start)
|
901 |
+
|
902 |
+
# =================== Print Profiling Results ===================
|
903 |
+
print("\n=== Profiling Summary (average per frame) ===")
|
904 |
+
for key, times in timings.items():
|
905 |
+
print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
|
906 |
+
|
907 |
+
return viewpoint_stack, closest_indices_selected, visualizations
|
source/data_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def scene_cameras_train_test_split(scene, verbose=False):
|
2 |
+
"""
|
3 |
+
Iterate over resolutions in the scene. For each resolution check if this resolution has test_cameras
|
4 |
+
if it doesn't then extract every 8th camera from the train and put it to the test set. This follows the
|
5 |
+
evaluation protocol suggested by Kerbl et al. in the seminal work on 3DGS. All changes to the input
|
6 |
+
object scene are inplace changes.
|
7 |
+
:param scene: Scene Class object from the gaussian-splatting.scene module
|
8 |
+
:param verbose: Print initial and final stage of the function
|
9 |
+
:return: None
|
10 |
+
|
11 |
+
"""
|
12 |
+
if verbose: print("Preparing train and test sets split...")
|
13 |
+
for resolution in scene.train_cameras.keys():
|
14 |
+
if len(scene.test_cameras[resolution]) == 0:
|
15 |
+
if verbose:
|
16 |
+
print(f"Found no test_cameras for resolution {resolution}. Move every 8th camera out ouf total "+\
|
17 |
+
f"{len(scene.train_cameras[resolution])} train cameras to the test set now")
|
18 |
+
N = len(scene.train_cameras[resolution])
|
19 |
+
scene.test_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
|
20 |
+
if idx % 8 == 0]
|
21 |
+
scene.train_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
|
22 |
+
if idx % 8 != 0]
|
23 |
+
if verbose:
|
24 |
+
print(f"Done. Now train and test sets contain each {len(scene.train_cameras[resolution])} and " + \
|
25 |
+
f"{len(scene.test_cameras[resolution])} cameras respectively.")
|
26 |
+
|
27 |
+
|
28 |
+
return
|
source/losses.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code is copied from the gaussian-splatting/utils/loss_utils.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
from math import exp
|
7 |
+
|
8 |
+
def l1_loss(network_output, gt, mean=True):
|
9 |
+
return torch.abs((network_output - gt)).mean() if mean else torch.abs((network_output - gt))
|
10 |
+
|
11 |
+
def l2_loss(network_output, gt):
|
12 |
+
return ((network_output - gt) ** 2).mean()
|
13 |
+
|
14 |
+
def gaussian(window_size, sigma):
|
15 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
16 |
+
return gauss / gauss.sum()
|
17 |
+
|
18 |
+
def create_window(window_size, channel):
|
19 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
20 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
21 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
22 |
+
return window
|
23 |
+
|
24 |
+
def ssim(img1, img2, window_size=11, size_average=True, mask = None):
|
25 |
+
channel = img1.size(-3)
|
26 |
+
window = create_window(window_size, channel)
|
27 |
+
|
28 |
+
if img1.is_cuda:
|
29 |
+
window = window.cuda(img1.get_device())
|
30 |
+
window = window.type_as(img1)
|
31 |
+
|
32 |
+
return _ssim(img1, img2, window, window_size, channel, size_average, mask)
|
33 |
+
|
34 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True, mask = None):
|
35 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
36 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
37 |
+
|
38 |
+
mu1_sq = mu1.pow(2)
|
39 |
+
mu2_sq = mu2.pow(2)
|
40 |
+
mu1_mu2 = mu1 * mu2
|
41 |
+
|
42 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
43 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
44 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
45 |
+
|
46 |
+
C1 = 0.01 ** 2
|
47 |
+
C2 = 0.03 ** 2
|
48 |
+
|
49 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
50 |
+
|
51 |
+
if mask is not None:
|
52 |
+
ssim_map = ssim_map * mask
|
53 |
+
|
54 |
+
if size_average:
|
55 |
+
return ssim_map.mean()
|
56 |
+
else:
|
57 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
58 |
+
|
59 |
+
|
60 |
+
def mse(img1, img2):
|
61 |
+
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
62 |
+
|
63 |
+
def psnr(img1, img2):
|
64 |
+
"""
|
65 |
+
Computes the Peak Signal-to-Noise Ratio (PSNR) between two single images. NOT BATCHED!
|
66 |
+
Args:
|
67 |
+
img1 (torch.Tensor): The first image tensor, with pixel values scaled between 0 and 1.
|
68 |
+
Shape should be (channels, height, width).
|
69 |
+
img2 (torch.Tensor): The second image tensor with the same shape as img1, used for comparison.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
torch.Tensor: A scalar tensor containing the PSNR value in decibels (dB).
|
73 |
+
"""
|
74 |
+
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
75 |
+
return 20 * torch.log10(1.0 / torch.sqrt(mse))
|
76 |
+
|
77 |
+
|
78 |
+
def tv_loss(image):
|
79 |
+
"""
|
80 |
+
Computes the total variation (TV) loss for an image of shape [3, H, W].
|
81 |
+
|
82 |
+
Args:
|
83 |
+
image (torch.Tensor): Input image of shape [3, H, W]
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Scalar value representing the total variation loss.
|
87 |
+
"""
|
88 |
+
# Ensure the image has the correct dimensions
|
89 |
+
assert image.ndim == 3 and image.shape[0] == 3, "Input must be of shape [3, H, W]"
|
90 |
+
|
91 |
+
# Compute the difference between adjacent pixels in the x-direction (width)
|
92 |
+
diff_x = torch.abs(image[:, :, 1:] - image[:, :, :-1])
|
93 |
+
|
94 |
+
# Compute the difference between adjacent pixels in the y-direction (height)
|
95 |
+
diff_y = torch.abs(image[:, 1:, :] - image[:, :-1, :])
|
96 |
+
|
97 |
+
# Sum the total variation in both directions
|
98 |
+
tv_loss_value = torch.mean(diff_x) + torch.mean(diff_y)
|
99 |
+
|
100 |
+
return tv_loss_value
|
source/networks.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import sys
|
4 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
5 |
+
|
6 |
+
from random import randint
|
7 |
+
from scene import Scene, GaussianModel
|
8 |
+
from gaussian_renderer import render
|
9 |
+
from source.data_utils import scene_cameras_train_test_split
|
10 |
+
|
11 |
+
class Warper3DGS(torch.nn.Module):
|
12 |
+
def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
|
13 |
+
do_train_test_split=True):
|
14 |
+
super(Warper3DGS, self).__init__()
|
15 |
+
"""
|
16 |
+
Init Warper using all the objects necessary for rendering gaussian splats.
|
17 |
+
Here we merely link class objects to the objects instantiated outsided the class.
|
18 |
+
"""
|
19 |
+
self.gaussians = GaussianModel(sh_degree)
|
20 |
+
self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
|
21 |
+
self.render = render
|
22 |
+
self.gs_config_opt = opt
|
23 |
+
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
24 |
+
self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
25 |
+
self.pipe = pipe
|
26 |
+
self.scene = Scene(dataset, self.gaussians, shuffle=False)
|
27 |
+
if do_train_test_split:
|
28 |
+
scene_cameras_train_test_split(self.scene, verbose=verbose)
|
29 |
+
|
30 |
+
self.gaussians.training_setup(opt)
|
31 |
+
self.viewpoint_stack = viewpoint_stack
|
32 |
+
if not self.viewpoint_stack:
|
33 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
34 |
+
|
35 |
+
def forward(self, viewpoint_cam=None):
|
36 |
+
"""
|
37 |
+
For a provided camera viewpoint_cam we render gaussians from this viewpoint.
|
38 |
+
If no camera provided then we use the self.viewpoint_stack (list of cameras).
|
39 |
+
If the latter is empty we reinitialize it using the self.scene object.
|
40 |
+
"""
|
41 |
+
if not viewpoint_cam:
|
42 |
+
if not self.viewpoint_stack:
|
43 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
44 |
+
viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]
|
45 |
+
|
46 |
+
render_pkg = self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)
|
47 |
+
return render_pkg
|
48 |
+
|
source/timer.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
class Timer:
|
3 |
+
def __init__(self):
|
4 |
+
self.start_time = None
|
5 |
+
self.elapsed = 0
|
6 |
+
self.paused = False
|
7 |
+
|
8 |
+
def start(self):
|
9 |
+
if self.start_time is None:
|
10 |
+
self.start_time = time.time()
|
11 |
+
elif self.paused:
|
12 |
+
self.start_time = time.time() - self.elapsed
|
13 |
+
self.paused = False
|
14 |
+
|
15 |
+
def pause(self):
|
16 |
+
if not self.paused:
|
17 |
+
self.elapsed = time.time() - self.start_time
|
18 |
+
self.paused = True
|
19 |
+
|
20 |
+
def get_elapsed_time(self):
|
21 |
+
if self.paused:
|
22 |
+
return self.elapsed
|
23 |
+
else:
|
24 |
+
return time.time() - self.start_time
|
source/trainer.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from random import randint
|
3 |
+
from tqdm.rich import trange
|
4 |
+
from tqdm import tqdm as tqdm
|
5 |
+
from source.networks import Warper3DGS
|
6 |
+
import wandb
|
7 |
+
import sys
|
8 |
+
|
9 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
10 |
+
import lpips
|
11 |
+
from source.losses import ssim, l1_loss, psnr
|
12 |
+
from rich.console import Console
|
13 |
+
from rich.theme import Theme
|
14 |
+
|
15 |
+
custom_theme = Theme({
|
16 |
+
"info": "dim cyan",
|
17 |
+
"warning": "magenta",
|
18 |
+
"danger": "bold red"
|
19 |
+
})
|
20 |
+
|
21 |
+
from source.corr_init import init_gaussians_with_corr, init_gaussians_with_corr_fast
|
22 |
+
from source.utils_aux import log_samples
|
23 |
+
|
24 |
+
from source.timer import Timer
|
25 |
+
|
26 |
+
class EDGSTrainer:
|
27 |
+
def __init__(self,
|
28 |
+
GS: Warper3DGS,
|
29 |
+
training_config,
|
30 |
+
dataset_white_background=False,
|
31 |
+
device=torch.device('cuda'),
|
32 |
+
log_wandb=True,
|
33 |
+
):
|
34 |
+
self.GS = GS
|
35 |
+
self.scene = GS.scene
|
36 |
+
self.viewpoint_stack = GS.viewpoint_stack
|
37 |
+
self.gaussians = GS.gaussians
|
38 |
+
|
39 |
+
self.training_config = training_config
|
40 |
+
self.GS_optimizer = GS.gaussians.optimizer
|
41 |
+
self.dataset_white_background = dataset_white_background
|
42 |
+
|
43 |
+
self.training_step = 1
|
44 |
+
self.gs_step = 0
|
45 |
+
self.CONSOLE = Console(width=120, theme=custom_theme)
|
46 |
+
self.saving_iterations = training_config.save_iterations
|
47 |
+
self.evaluate_iterations = None
|
48 |
+
self.batch_size = training_config.batch_size
|
49 |
+
self.ema_loss_for_log = 0.0
|
50 |
+
|
51 |
+
# Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
|
52 |
+
self.logs_losses = {}
|
53 |
+
self.lpips = lpips.LPIPS(net='vgg').to(device)
|
54 |
+
self.device = device
|
55 |
+
self.timer = Timer()
|
56 |
+
self.log_wandb = log_wandb
|
57 |
+
|
58 |
+
def load_checkpoints(self, load_cfg):
|
59 |
+
# Load 3DGS checkpoint
|
60 |
+
if load_cfg.gs:
|
61 |
+
self.gs.gaussians.restore(
|
62 |
+
torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")[0],
|
63 |
+
self.training_config)
|
64 |
+
self.GS_optimizer = self.GS.gaussians.optimizer
|
65 |
+
self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
|
66 |
+
style="info")
|
67 |
+
self.training_step += load_cfg.gs_step
|
68 |
+
self.gs_step += load_cfg.gs_step
|
69 |
+
|
70 |
+
def train(self, train_cfg):
|
71 |
+
# 3DGS training
|
72 |
+
self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
|
73 |
+
with trange(self.training_step, self.training_step + train_cfg.gs_epochs, desc="[green]Train gaussians") as progress_bar:
|
74 |
+
for self.training_step in progress_bar:
|
75 |
+
radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
|
76 |
+
with torch.no_grad():
|
77 |
+
if train_cfg.no_densify:
|
78 |
+
self.prune(radii)
|
79 |
+
else:
|
80 |
+
self.densify_and_prune(radii)
|
81 |
+
if train_cfg.reduce_opacity:
|
82 |
+
# Slightly reduce opacity every few steps:
|
83 |
+
if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
|
84 |
+
opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
|
85 |
+
self.GS.gaussians._opacity.data = opacities_new
|
86 |
+
self.timer.pause()
|
87 |
+
# Progress bar
|
88 |
+
if self.training_step % 10 == 0:
|
89 |
+
progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
|
90 |
+
# Log and save
|
91 |
+
if self.training_step in self.saving_iterations:
|
92 |
+
self.save_model()
|
93 |
+
if self.evaluate_iterations is not None:
|
94 |
+
if self.training_step in self.evaluate_iterations:
|
95 |
+
self.evaluate()
|
96 |
+
else:
|
97 |
+
if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
|
98 |
+
(self.training_step > 3000 and self.training_step % 1000 == 228) :
|
99 |
+
self.evaluate()
|
100 |
+
|
101 |
+
self.timer.start()
|
102 |
+
|
103 |
+
|
104 |
+
def evaluate(self):
|
105 |
+
torch.cuda.empty_cache()
|
106 |
+
log_gen_images, log_real_images = [], []
|
107 |
+
validation_configs = ({'name': 'test', 'cameras': self.scene.getTestCameras(), 'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
|
108 |
+
{'name': 'train',
|
109 |
+
'cameras': [self.scene.getTrainCameras()[idx % len(self.scene.getTrainCameras())] for idx in
|
110 |
+
range(0, 150, 5)], 'cam_idx': 10})
|
111 |
+
if self.log_wandb:
|
112 |
+
wandb.log({f"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
|
113 |
+
for config in validation_configs:
|
114 |
+
if config['cameras'] and len(config['cameras']) > 0:
|
115 |
+
l1_test = 0.0
|
116 |
+
psnr_test = 0.0
|
117 |
+
ssim_test = 0.0
|
118 |
+
lpips_splat_test = 0.0
|
119 |
+
for idx, viewpoint in enumerate(config['cameras']):
|
120 |
+
image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
|
121 |
+
gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
|
122 |
+
l1_test += l1_loss(image, gt_image).double()
|
123 |
+
psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
|
124 |
+
ssim_test += ssim(image, gt_image).double()
|
125 |
+
lpips_splat_test += self.lpips(image, gt_image).detach().double()
|
126 |
+
if idx in [config['cam_idx']]:
|
127 |
+
log_gen_images.append(image)
|
128 |
+
log_real_images.append(gt_image)
|
129 |
+
psnr_test /= len(config['cameras'])
|
130 |
+
l1_test /= len(config['cameras'])
|
131 |
+
ssim_test /= len(config['cameras'])
|
132 |
+
lpips_splat_test /= len(config['cameras'])
|
133 |
+
if self.log_wandb:
|
134 |
+
wandb.log({f"{config['name']}/L1": l1_test.item(), f"{config['name']}/PSNR": psnr_test.item(), \
|
135 |
+
f"{config['name']}/SSIM": ssim_test.item(), f"{config['name']}/LPIPS_splat": lpips_splat_test.item()}, step = self.training_step)
|
136 |
+
self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
|
137 |
+
self.training_step, len(self.GS.gaussians._xyz), config['name'], l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
|
138 |
+
if self.log_wandb:
|
139 |
+
with torch.no_grad():
|
140 |
+
log_samples(torch.stack((log_real_images[0],log_gen_images[0])) , [], self.training_step, caption="Real and Generated Samples")
|
141 |
+
wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
|
142 |
+
torch.cuda.empty_cache()
|
143 |
+
|
144 |
+
def train_step_gs(self, max_lr = False, no_densify = False):
|
145 |
+
self.gs_step += 1
|
146 |
+
if max_lr:
|
147 |
+
self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
|
148 |
+
else:
|
149 |
+
self.GS.gaussians.update_learning_rate(self.gs_step)
|
150 |
+
# Every 1000 its we increase the levels of SH up to a maximum degree
|
151 |
+
if self.gs_step % 1000 == 0:
|
152 |
+
self.GS.gaussians.oneupSHdegree()
|
153 |
+
|
154 |
+
# Pick a random Camera
|
155 |
+
if not self.viewpoint_stack:
|
156 |
+
self.viewpoint_stack = self.scene.getTrainCameras().copy()
|
157 |
+
viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
|
158 |
+
|
159 |
+
render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
|
160 |
+
image = render_pkg["render"]
|
161 |
+
# Loss
|
162 |
+
gt_image = viewpoint_cam.original_image.to(self.device)
|
163 |
+
L1_loss = l1_loss(image, gt_image)
|
164 |
+
|
165 |
+
ssim_loss = (1.0 - ssim(image, gt_image))
|
166 |
+
loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
|
167 |
+
self.training_config.lambda_dssim * ssim_loss
|
168 |
+
self.timer.pause()
|
169 |
+
self.logs_losses[self.training_step] = {"loss": loss.item(),
|
170 |
+
"L1_loss": L1_loss.item(),
|
171 |
+
"ssim_loss": ssim_loss.item()}
|
172 |
+
|
173 |
+
if self.log_wandb:
|
174 |
+
for k, v in self.logs_losses[self.training_step].items():
|
175 |
+
wandb.log({f"train/{k}": v}, step=self.training_step)
|
176 |
+
self.ema_loss_for_log = 0.4 * self.logs_losses[self.training_step]["loss"] + 0.6 * self.ema_loss_for_log
|
177 |
+
self.timer.start()
|
178 |
+
self.GS_optimizer.zero_grad(set_to_none=True)
|
179 |
+
loss.backward()
|
180 |
+
with torch.no_grad():
|
181 |
+
if self.gs_step < self.training_config.densify_until_iter and not no_densify:
|
182 |
+
self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]] = torch.max(
|
183 |
+
self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]],
|
184 |
+
render_pkg["radii"][render_pkg["visibility_filter"]])
|
185 |
+
self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"],
|
186 |
+
render_pkg["visibility_filter"])
|
187 |
+
|
188 |
+
# Optimizer step
|
189 |
+
self.GS_optimizer.step()
|
190 |
+
self.GS_optimizer.zero_grad(set_to_none=True)
|
191 |
+
return render_pkg["radii"]
|
192 |
+
|
193 |
+
def densify_and_prune(self, radii = None):
|
194 |
+
# Densification or pruning
|
195 |
+
if self.gs_step < self.training_config.densify_until_iter:
|
196 |
+
if (self.gs_step > self.training_config.densify_from_iter) and \
|
197 |
+
(self.gs_step % self.training_config.densification_interval == 0):
|
198 |
+
size_threshold = 20 if self.gs_step > self.training_config.opacity_reset_interval else None
|
199 |
+
self.GS.gaussians.densify_and_prune(self.training_config.densify_grad_threshold,
|
200 |
+
0.005,
|
201 |
+
self.GS.scene.cameras_extent,
|
202 |
+
size_threshold, radii)
|
203 |
+
if self.gs_step % self.training_config.opacity_reset_interval == 0 or (
|
204 |
+
self.dataset_white_background and self.gs_step == self.training_config.densify_from_iter):
|
205 |
+
self.GS.gaussians.reset_opacity()
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
def save_model(self):
|
210 |
+
print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
|
211 |
+
self.scene.save(self.gs_step)
|
212 |
+
print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
|
213 |
+
torch.save((self.GS.gaussians.capture(), self.gs_step),
|
214 |
+
self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")
|
215 |
+
|
216 |
+
|
217 |
+
def init_with_corr(self, cfg, verbose=False, roma_model=None):
|
218 |
+
"""
|
219 |
+
Initializes image with matchings. Also removes SfM init points.
|
220 |
+
Args:
|
221 |
+
cfg: configuration part named init_wC. Check train.yaml
|
222 |
+
verbose: whether you want to print intermediate results. Useful for debug.
|
223 |
+
roma_model: optionally you can pass here preinit RoMA model to avoid reinit
|
224 |
+
it every time.
|
225 |
+
"""
|
226 |
+
if not cfg.use:
|
227 |
+
return None
|
228 |
+
N_splats_at_init = len(self.GS.gaussians._xyz)
|
229 |
+
print("N_splats_at_init:", N_splats_at_init)
|
230 |
+
if cfg.nns_per_ref == 1:
|
231 |
+
init_fn = init_gaussians_with_corr_fast
|
232 |
+
else:
|
233 |
+
init_fn = init_gaussians_with_corr
|
234 |
+
camera_set, selected_indices, visualization_dict = init_fn(
|
235 |
+
self.GS.gaussians,
|
236 |
+
self.scene,
|
237 |
+
cfg,
|
238 |
+
self.device,
|
239 |
+
verbose=verbose,
|
240 |
+
roma_model=roma_model)
|
241 |
+
|
242 |
+
# Remove SfM points and leave only matchings inits
|
243 |
+
if not cfg.add_SfM_init:
|
244 |
+
with torch.no_grad():
|
245 |
+
N_splats_after_init = len(self.GS.gaussians._xyz)
|
246 |
+
print("N_splats_after_init:", N_splats_after_init)
|
247 |
+
self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
|
248 |
+
mask = torch.concat([torch.ones(N_splats_at_init, dtype=torch.bool),
|
249 |
+
torch.zeros(N_splats_after_init-N_splats_at_init, dtype=torch.bool)],
|
250 |
+
axis=0)
|
251 |
+
self.GS.gaussians.prune_points(mask)
|
252 |
+
with torch.no_grad():
|
253 |
+
gaussians = self.gaussians
|
254 |
+
gaussians._scaling = gaussians.scaling_inverse_activation(gaussians.scaling_activation(gaussians._scaling)*0.5)
|
255 |
+
return visualization_dict
|
256 |
+
|
257 |
+
|
258 |
+
def prune(self, radii, min_opacity=0.005):
|
259 |
+
self.GS.gaussians.tmp_radii = radii
|
260 |
+
if self.gs_step < self.training_config.densify_until_iter:
|
261 |
+
prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
|
262 |
+
self.GS.gaussians.prune_points(prune_mask)
|
263 |
+
torch.cuda.empty_cache()
|
264 |
+
self.GS.gaussians.tmp_radii = None
|
265 |
+
|
source/utils_aux.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Perlin noise code taken from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
|
2 |
+
from types import SimpleNamespace
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import wandb
|
8 |
+
import random
|
9 |
+
import torchvision.transforms as T
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
def parse_dict_to_namespace(dict_nested):
|
15 |
+
"""Turns nested dictionary into nested namespaces"""
|
16 |
+
if type(dict_nested) != dict and type(dict_nested) != list: return dict_nested
|
17 |
+
x = SimpleNamespace()
|
18 |
+
for key, val in dict_nested.items():
|
19 |
+
if type(val) == dict:
|
20 |
+
setattr(x, key, parse_dict_to_namespace(val))
|
21 |
+
elif type(val) == list:
|
22 |
+
setattr(x, key, [parse_dict_to_namespace(v) for v in val])
|
23 |
+
else:
|
24 |
+
setattr(x, key, val)
|
25 |
+
return x
|
26 |
+
|
27 |
+
def set_seed(seed=42, cuda=True):
|
28 |
+
random.seed(seed)
|
29 |
+
np.random.seed(seed)
|
30 |
+
torch.manual_seed(seed)
|
31 |
+
if cuda:
|
32 |
+
torch.cuda.manual_seed_all(seed)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def log_samples(samples, scores, iteration, caption="Real Samples"):
|
37 |
+
# Create a grid of images
|
38 |
+
grid = torchvision.utils.make_grid(samples)
|
39 |
+
|
40 |
+
# Log the images and scores to wandb
|
41 |
+
wandb.log({
|
42 |
+
f"{caption}_images": [wandb.Image(grid, caption=f"{caption}: {scores}")],
|
43 |
+
}, step = iteration)
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def pairwise_distances(matrix):
|
48 |
+
"""
|
49 |
+
Computes the pairwise Euclidean distances between all vectors in the input matrix.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
torch.Tensor: Pairwise distance matrix of shape [N, N].
|
56 |
+
"""
|
57 |
+
# Compute squared pairwise distances
|
58 |
+
squared_diff = torch.cdist(matrix, matrix, p=2)
|
59 |
+
return squared_diff
|
60 |
+
|
61 |
+
def k_closest_vectors(matrix, k):
|
62 |
+
"""
|
63 |
+
Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
|
67 |
+
k (int): Number of closest vectors to return for each vector.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
|
71 |
+
"""
|
72 |
+
# Compute pairwise distances
|
73 |
+
distances = pairwise_distances(matrix)
|
74 |
+
|
75 |
+
# For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
|
76 |
+
# Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
|
77 |
+
distances.fill_diagonal_(float('inf'))
|
78 |
+
|
79 |
+
# Get the indices of the k smallest distances (k-closest vectors)
|
80 |
+
_, indices = torch.topk(distances, k, largest=False, dim=1)
|
81 |
+
|
82 |
+
return indices
|
83 |
+
|
84 |
+
def process_image(image_tensor):
|
85 |
+
image_np = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
|
86 |
+
return Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))
|
87 |
+
|
88 |
+
|
89 |
+
def normalize_keypoints(kpts_np, width, height):
|
90 |
+
kpts_np[:, 0] = kpts_np[:, 0] / width * 2. - 1.
|
91 |
+
kpts_np[:, 1] = kpts_np[:, 1] / height * 2. - 1.
|
92 |
+
return kpts_np
|
source/utils_preprocess.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains function for video or image collection preprocessing.
|
2 |
+
# For video we do the preprocessing and select k sharpest frames.
|
3 |
+
# Afterwards scene is constructed
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pycolmap
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
import tempfile
|
11 |
+
from moviepy import VideoFileClip
|
12 |
+
from matplotlib import pyplot as plt
|
13 |
+
from PIL import Image
|
14 |
+
import cv2
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
WORKDIR = "../outputs/"
|
18 |
+
|
19 |
+
|
20 |
+
def get_rotation_moviepy(video_path):
|
21 |
+
clip = VideoFileClip(video_path)
|
22 |
+
rotation = 0
|
23 |
+
|
24 |
+
try:
|
25 |
+
displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '')
|
26 |
+
if 'rotation of' in displaymatrix:
|
27 |
+
angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0])
|
28 |
+
rotation = int(angle) % 360
|
29 |
+
|
30 |
+
except Exception as e:
|
31 |
+
print(f"No displaymatrix rotation found: {e}")
|
32 |
+
|
33 |
+
clip.reader.close()
|
34 |
+
#if clip.audio:
|
35 |
+
# clip.audio.reader.close_proc()
|
36 |
+
|
37 |
+
return rotation
|
38 |
+
|
39 |
+
def resize_max_side(frame, max_size):
|
40 |
+
h, w = frame.shape[:2]
|
41 |
+
scale = max_size / max(h, w)
|
42 |
+
if scale < 1:
|
43 |
+
frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
|
44 |
+
return frame
|
45 |
+
|
46 |
+
def read_video_frames(video_input, k=1, max_size=1024):
|
47 |
+
"""
|
48 |
+
Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
video_input (str, file-like, or list): Path to video file, file-like object, or list of image files.
|
52 |
+
k (int): Interval for frame extraction (every k-th frame).
|
53 |
+
max_size (int): Maximum size for width or height after resizing.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
frames (list): List of resized frames (numpy arrays).
|
57 |
+
"""
|
58 |
+
# Handle list of image files (not single video in a list)
|
59 |
+
if isinstance(video_input, list):
|
60 |
+
# If it's a single video in a list, treat it as video
|
61 |
+
if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')):
|
62 |
+
video_input = video_input[0] # unwrap single video file
|
63 |
+
else:
|
64 |
+
# Treat as list of images
|
65 |
+
frames = []
|
66 |
+
for img_file in video_input:
|
67 |
+
img = Image.open(img_file.name).convert("RGB")
|
68 |
+
img.thumbnail((max_size, max_size))
|
69 |
+
frames.append(np.array(img)[...,::-1])
|
70 |
+
return frames
|
71 |
+
|
72 |
+
# Handle file-like or path
|
73 |
+
if hasattr(video_input, 'name'):
|
74 |
+
video_path = video_input.name
|
75 |
+
elif isinstance(video_input, (str, os.PathLike)):
|
76 |
+
video_path = str(video_input)
|
77 |
+
else:
|
78 |
+
raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.")
|
79 |
+
|
80 |
+
|
81 |
+
cap = cv2.VideoCapture(video_path)
|
82 |
+
if not cap.isOpened():
|
83 |
+
raise ValueError(f"Error: Could not open video {video_path}.")
|
84 |
+
|
85 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
86 |
+
frame_count = 0
|
87 |
+
frames = []
|
88 |
+
|
89 |
+
with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar:
|
90 |
+
while True:
|
91 |
+
ret, frame = cap.read()
|
92 |
+
if not ret:
|
93 |
+
break
|
94 |
+
if frame_count % k == 0:
|
95 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
96 |
+
h, w = frame.shape[:2]
|
97 |
+
scale = max(h, w) / max_size
|
98 |
+
if scale > 1:
|
99 |
+
frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
|
100 |
+
frames.append(frame[...,[2,1,0]])
|
101 |
+
pbar.update(1)
|
102 |
+
frame_count += 1
|
103 |
+
|
104 |
+
cap.release()
|
105 |
+
return frames
|
106 |
+
|
107 |
+
def resize_max_side(frame, max_size):
|
108 |
+
"""
|
109 |
+
Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
|
110 |
+
"""
|
111 |
+
height, width = frame.shape[:2]
|
112 |
+
max_dim = max(height, width)
|
113 |
+
|
114 |
+
if max_dim <= max_size:
|
115 |
+
return frame # No need to resize
|
116 |
+
|
117 |
+
scale = max_size / max_dim
|
118 |
+
new_width = int(width * scale)
|
119 |
+
new_height = int(height * scale)
|
120 |
+
|
121 |
+
resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
122 |
+
return resized_frame
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
def variance_of_laplacian(image):
|
127 |
+
# compute the Laplacian of the image and then return the focus
|
128 |
+
# measure, which is simply the variance of the Laplacian
|
129 |
+
return cv2.Laplacian(image, cv2.CV_64F).var()
|
130 |
+
|
131 |
+
def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames',
|
132 |
+
to_visualize=False,
|
133 |
+
save_images=True):
|
134 |
+
dict_scores = {}
|
135 |
+
for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):
|
136 |
+
|
137 |
+
img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:]
|
138 |
+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
139 |
+
fm = variance_of_laplacian(gray) + \
|
140 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \
|
141 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \
|
142 |
+
variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25))
|
143 |
+
if to_visualize:
|
144 |
+
plt.figure()
|
145 |
+
plt.title(f"Laplacian score: {fm:.2f}")
|
146 |
+
plt.imshow(img[..., [2,1,0]])
|
147 |
+
plt.show()
|
148 |
+
dict_scores[idx] = {"idx" : idx,
|
149 |
+
"img_name" : img_name,
|
150 |
+
"score" : fm}
|
151 |
+
if save_images:
|
152 |
+
dict_scores[idx]["img"] = img
|
153 |
+
|
154 |
+
return dict_scores
|
155 |
+
|
156 |
+
def select_optimal_frames(scores, k):
|
157 |
+
"""
|
158 |
+
Selects a minimal subset of frames while ensuring no gaps exceed k.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
scores (list of float): List of scores where index represents frame number.
|
162 |
+
k (int): Maximum allowed gap between selected frames.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
list of int: Indices of selected frames.
|
166 |
+
"""
|
167 |
+
n = len(scores)
|
168 |
+
selected = [0, n-1]
|
169 |
+
i = 0 # Start at the first frame
|
170 |
+
|
171 |
+
while i < n:
|
172 |
+
# Find the best frame to select within the next k frames
|
173 |
+
best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None)
|
174 |
+
|
175 |
+
if best_idx is None:
|
176 |
+
break # No more frames left
|
177 |
+
|
178 |
+
selected.append(best_idx)
|
179 |
+
i = best_idx + k + 1 # Move forward, ensuring gaps stay within k
|
180 |
+
|
181 |
+
return sorted(selected)
|
182 |
+
|
183 |
+
|
184 |
+
def variance_of_laplacian(image):
|
185 |
+
"""
|
186 |
+
Compute the variance of Laplacian as a focus measure.
|
187 |
+
"""
|
188 |
+
return cv2.Laplacian(image, cv2.CV_64F).var()
|
189 |
+
|
190 |
+
def preprocess_frames(frames, verbose=False):
|
191 |
+
"""
|
192 |
+
Compute sharpness scores for a list of frames using multi-scale Laplacian variance.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
frames (list of np.ndarray): List of frames (BGR images).
|
196 |
+
verbose (bool): If True, print scores.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
list of float: Sharpness scores for each frame.
|
200 |
+
"""
|
201 |
+
scores = []
|
202 |
+
|
203 |
+
for idx, frame in enumerate(tqdm(frames, desc="Scoring frames")):
|
204 |
+
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
205 |
+
|
206 |
+
fm = (
|
207 |
+
variance_of_laplacian(gray) +
|
208 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) +
|
209 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) +
|
210 |
+
variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
|
211 |
+
)
|
212 |
+
|
213 |
+
if verbose:
|
214 |
+
print(f"Frame {idx}: Sharpness Score = {fm:.2f}")
|
215 |
+
|
216 |
+
scores.append(fm)
|
217 |
+
|
218 |
+
return scores
|
219 |
+
|
220 |
+
def select_optimal_frames(scores, k):
|
221 |
+
"""
|
222 |
+
Selects k frames by splitting into k segments and picking the sharpest frame from each.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
scores (list of float): List of sharpness scores.
|
226 |
+
k (int): Number of frames to select.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
list of int: Indices of selected frames.
|
230 |
+
"""
|
231 |
+
n = len(scores)
|
232 |
+
selected_indices = []
|
233 |
+
segment_size = n // k
|
234 |
+
|
235 |
+
for i in range(k):
|
236 |
+
start = i * segment_size
|
237 |
+
end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger
|
238 |
+
segment_scores = scores[start:end]
|
239 |
+
|
240 |
+
if len(segment_scores) == 0:
|
241 |
+
continue # Safety check if some segment is empty
|
242 |
+
|
243 |
+
best_in_segment = start + np.argmax(segment_scores)
|
244 |
+
selected_indices.append(best_in_segment)
|
245 |
+
|
246 |
+
return sorted(selected_indices)
|
247 |
+
|
248 |
+
def save_frames_to_scene_dir(frames, scene_dir):
|
249 |
+
"""
|
250 |
+
Saves a list of frames into the target scene directory under 'images/' subfolder.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
frames (list of np.ndarray): List of frames (BGR images) to save.
|
254 |
+
scene_dir (str): Target path where 'images/' subfolder will be created.
|
255 |
+
"""
|
256 |
+
images_dir = os.path.join(scene_dir, "images")
|
257 |
+
os.makedirs(images_dir, exist_ok=True)
|
258 |
+
|
259 |
+
for idx, frame in enumerate(frames):
|
260 |
+
filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc.
|
261 |
+
cv2.imwrite(filename, frame)
|
262 |
+
|
263 |
+
print(f"Saved {len(frames)} frames to {images_dir}")
|
264 |
+
|
265 |
+
|
266 |
+
def run_colmap_on_scene(scene_dir):
|
267 |
+
"""
|
268 |
+
Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
scene_dir (str): Path to scene directory containing 'images' folder.
|
272 |
+
|
273 |
+
TODO: if the function hasn't managed to match all the frames either increase image size,
|
274 |
+
increase number of features or just remove those frames from the folder scene_dir/images
|
275 |
+
"""
|
276 |
+
start_time = time.time()
|
277 |
+
print(f"Running COLMAP pipeline on all images inside {scene_dir}")
|
278 |
+
|
279 |
+
# Setup paths
|
280 |
+
database_path = os.path.join(scene_dir, "database.db")
|
281 |
+
sparse_path = os.path.join(scene_dir, "sparse")
|
282 |
+
image_dir = os.path.join(scene_dir, "images")
|
283 |
+
|
284 |
+
# Make sure output directories exist
|
285 |
+
os.makedirs(sparse_path, exist_ok=True)
|
286 |
+
|
287 |
+
# Step 1: Feature Extraction
|
288 |
+
pycolmap.extract_features(
|
289 |
+
database_path,
|
290 |
+
image_dir,
|
291 |
+
sift_options={
|
292 |
+
"max_num_features": 512 * 2,
|
293 |
+
"max_image_size": 512 * 1,
|
294 |
+
}
|
295 |
+
)
|
296 |
+
print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.")
|
297 |
+
|
298 |
+
# Step 2: Feature Matching
|
299 |
+
pycolmap.match_exhaustive(database_path)
|
300 |
+
print(f"Finished feature matching in {(time.time() - start_time):.2f}s.")
|
301 |
+
|
302 |
+
# Step 3: Mapping
|
303 |
+
pipeline_options = pycolmap.IncrementalPipelineOptions()
|
304 |
+
pipeline_options.min_num_matches = 15
|
305 |
+
pipeline_options.multiple_models = True
|
306 |
+
pipeline_options.max_num_models = 50
|
307 |
+
pipeline_options.max_model_overlap = 20
|
308 |
+
pipeline_options.min_model_size = 10
|
309 |
+
pipeline_options.extract_colors = True
|
310 |
+
pipeline_options.num_threads = 8
|
311 |
+
pipeline_options.mapper.init_min_num_inliers = 30
|
312 |
+
pipeline_options.mapper.init_max_error = 8.0
|
313 |
+
pipeline_options.mapper.init_min_tri_angle = 5.0
|
314 |
+
|
315 |
+
reconstruction = pycolmap.incremental_mapping(
|
316 |
+
database_path=database_path,
|
317 |
+
image_path=image_dir,
|
318 |
+
output_path=sparse_path,
|
319 |
+
options=pipeline_options,
|
320 |
+
)
|
321 |
+
print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.")
|
322 |
+
|
323 |
+
# Step 4: Post-process Cameras to SIMPLE_PINHOLE
|
324 |
+
recon_path = os.path.join(sparse_path, "0")
|
325 |
+
reconstruction = pycolmap.Reconstruction(recon_path)
|
326 |
+
|
327 |
+
for cam in reconstruction.cameras.values():
|
328 |
+
cam.model = 'SIMPLE_PINHOLE'
|
329 |
+
cam.params = cam.params[:3] # Keep only [f, cx, cy]
|
330 |
+
|
331 |
+
reconstruction.write(recon_path)
|
332 |
+
|
333 |
+
print(f"Total pipeline time: {(time.time() - start_time):.2f}s.")
|
334 |
+
|
source/visualization.py
ADDED
@@ -0,0 +1,1072 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from matplotlib import pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from typing import List
|
7 |
+
import sys
|
8 |
+
sys.path.append('./submodules/gaussian-splatting/')
|
9 |
+
from scene.cameras import Camera
|
10 |
+
from PIL import Image
|
11 |
+
import imageio
|
12 |
+
from scipy.interpolate import splprep, splev
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import plotly.graph_objects as go
|
17 |
+
import numpy as np
|
18 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
19 |
+
from scipy.spatial import distance_matrix
|
20 |
+
from sklearn.decomposition import PCA
|
21 |
+
from scipy.interpolate import splprep, splev
|
22 |
+
from typing import List
|
23 |
+
from sklearn.mixture import GaussianMixture
|
24 |
+
|
25 |
+
def render_gaussians_rgb(generator3DGS, viewpoint_cam, visualize=False):
|
26 |
+
"""
|
27 |
+
Simply render gaussians from the generator3DGS from the viewpoint_cam.
|
28 |
+
Args:
|
29 |
+
generator3DGS : instance of the Generator3DGS class from the networks.py file
|
30 |
+
viewpoint_cam : camera instance
|
31 |
+
visualize : boolean flag. If True, will call pyplot function and render image inplace
|
32 |
+
Returns:
|
33 |
+
uint8 numpy array with shape (H, W, 3) representing the image
|
34 |
+
"""
|
35 |
+
with torch.no_grad():
|
36 |
+
render_pkg = generator3DGS(viewpoint_cam)
|
37 |
+
image = render_pkg["render"]
|
38 |
+
image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
|
39 |
+
|
40 |
+
# Clip values to be in the range [0, 1]
|
41 |
+
image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
|
42 |
+
if visualize:
|
43 |
+
plt.figure(figsize=(12, 8))
|
44 |
+
plt.imshow(image_np)
|
45 |
+
plt.show()
|
46 |
+
|
47 |
+
return image_np
|
48 |
+
|
49 |
+
def render_gaussians_D_scores(generator3DGS, viewpoint_cam, mask=None, mask_channel=0, visualize=False):
|
50 |
+
"""
|
51 |
+
Simply render D_scores of gaussians from the generator3DGS from the viewpoint_cam.
|
52 |
+
Args:
|
53 |
+
generator3DGS : instance of the Generator3DGS class from the networks.py file
|
54 |
+
viewpoint_cam : camera instance
|
55 |
+
visualize : boolean flag. If True, will call pyplot function and render image inplace
|
56 |
+
mask : optional mask to highlight specific gaussians. Must be of shape (N) where N is the numnber
|
57 |
+
of gaussians in generator3DGS.gaussians. Must be a torch tensor of floats, please scale according
|
58 |
+
to how much color you want to have. Recommended mask value is 10.
|
59 |
+
mask_channel: to which color channel should we add mask
|
60 |
+
Returns:
|
61 |
+
uint8 numpy array with shape (H, W, 3) representing the generator3DGS.gaussians.D_scores rendered as colors
|
62 |
+
"""
|
63 |
+
with torch.no_grad():
|
64 |
+
# Visualize D_scores
|
65 |
+
generator3DGS.gaussians._features_dc = generator3DGS.gaussians._features_dc * 1e-4 + \
|
66 |
+
torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)
|
67 |
+
generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e-4
|
68 |
+
if mask is not None:
|
69 |
+
generator3DGS.gaussians._features_dc[..., mask_channel] += mask.unsqueeze(-1)
|
70 |
+
render_pkg = generator3DGS(viewpoint_cam)
|
71 |
+
image = render_pkg["render"]
|
72 |
+
image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
|
73 |
+
|
74 |
+
# Clip values to be in the range [0, 1]
|
75 |
+
image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
|
76 |
+
if visualize:
|
77 |
+
plt.figure(figsize=(12, 8))
|
78 |
+
plt.imshow(image_np)
|
79 |
+
plt.show()
|
80 |
+
|
81 |
+
if mask is not None:
|
82 |
+
generator3DGS.gaussians._features_dc[..., mask_channel] -= mask.unsqueeze(-1)
|
83 |
+
|
84 |
+
generator3DGS.gaussians._features_dc = (generator3DGS.gaussians._features_dc - \
|
85 |
+
torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)) * 1e4
|
86 |
+
generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e4
|
87 |
+
|
88 |
+
return image_np
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def normalize(v):
|
93 |
+
"""
|
94 |
+
Normalize a vector to unit length.
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
v (np.ndarray): Input vector.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
np.ndarray: Unit vector in the same direction as `v`.
|
101 |
+
"""
|
102 |
+
return v / np.linalg.norm(v)
|
103 |
+
|
104 |
+
def look_at_rotation(camera_position: np.ndarray, target: np.ndarray, world_up=np.array([0, 1, 0])):
|
105 |
+
"""
|
106 |
+
Compute a rotation matrix for a camera looking at a target point.
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
camera_position (np.ndarray): The 3D position of the camera.
|
110 |
+
target (np.ndarray): The point the camera should look at.
|
111 |
+
world_up (np.ndarray): A vector that defines the global 'up' direction.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
np.ndarray: A 3x3 rotation matrix (camera-to-world) with columns [right, up, forward].
|
115 |
+
"""
|
116 |
+
z_axis = normalize(target - camera_position) # Forward direction
|
117 |
+
x_axis = normalize(np.cross(world_up, z_axis)) # Right direction
|
118 |
+
y_axis = np.cross(z_axis, x_axis) # Recomputed up
|
119 |
+
return np.stack([x_axis, y_axis, z_axis], axis=1)
|
120 |
+
|
121 |
+
|
122 |
+
def generate_circular_camera_path(existing_cameras: List[Camera], N: int = 12, radius_scale: float = 1.0, d: float = 2.0) -> List[Camera]:
|
123 |
+
"""
|
124 |
+
Generate a circular path of cameras around an existing camera group,
|
125 |
+
with each new camera oriented to look at the average viewing direction.
|
126 |
+
|
127 |
+
Parameters:
|
128 |
+
existing_cameras (List[Camera]): List of existing camera objects to estimate average orientation and layout.
|
129 |
+
N (int): Number of new cameras to generate along the circular path.
|
130 |
+
radius_scale (float): Scale factor to adjust the radius of the circle.
|
131 |
+
d (float): Distance ahead of each camera used to estimate its look-at point.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
List[Camera]: A list of newly generated Camera objects forming a circular path and oriented toward a shared view center.
|
135 |
+
"""
|
136 |
+
# Step 1: Compute average camera position
|
137 |
+
center = np.mean([cam.T for cam in existing_cameras], axis=0)
|
138 |
+
|
139 |
+
# Estimate where each camera is looking
|
140 |
+
# d denotes how far ahead each camera sees — you can scale this
|
141 |
+
look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
|
142 |
+
center_of_view = np.mean(look_targets, axis=0)
|
143 |
+
|
144 |
+
# Step 2: Define circular plane basis using fixed up vector
|
145 |
+
avg_forward = normalize(np.mean([cam.R[:, 2] for cam in existing_cameras], axis=0))
|
146 |
+
up_guess = np.array([0, 1, 0])
|
147 |
+
right = normalize(np.cross(avg_forward, up_guess))
|
148 |
+
up = normalize(np.cross(right, avg_forward))
|
149 |
+
|
150 |
+
# Step 3: Estimate radius
|
151 |
+
avg_radius = np.mean([np.linalg.norm(cam.T - center) for cam in existing_cameras]) * radius_scale
|
152 |
+
|
153 |
+
# Step 4: Create cameras on a circular path
|
154 |
+
angles = np.linspace(0, 2 * np.pi, N, endpoint=False)
|
155 |
+
reference_cam = existing_cameras[0]
|
156 |
+
new_cameras = []
|
157 |
+
|
158 |
+
|
159 |
+
for i, a in enumerate(angles):
|
160 |
+
position = center + avg_radius * (np.cos(a) * right + np.sin(a) * up)
|
161 |
+
|
162 |
+
if d < 1e-5 or radius_scale < 1e-5:
|
163 |
+
# Use same orientation as the first camera
|
164 |
+
R = reference_cam.R.copy()
|
165 |
+
else:
|
166 |
+
# Change orientation
|
167 |
+
R = look_at_rotation(position, center_of_view)
|
168 |
+
new_cameras.append(Camera(
|
169 |
+
R=R,
|
170 |
+
T=position, # New position
|
171 |
+
FoVx=reference_cam.FoVx,
|
172 |
+
FoVy=reference_cam.FoVy,
|
173 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
174 |
+
colmap_id=-1,
|
175 |
+
depth_params=None,
|
176 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
177 |
+
invdepthmap=None,
|
178 |
+
image_name=f"circular_a={a:.3f}",
|
179 |
+
uid=i
|
180 |
+
))
|
181 |
+
|
182 |
+
return new_cameras
|
183 |
+
|
184 |
+
|
185 |
+
def save_numpy_frames_as_gif(frames, output_path="animation.gif", duration=100):
|
186 |
+
"""
|
187 |
+
Save a list of RGB NumPy frames as a looping GIF animation.
|
188 |
+
|
189 |
+
Parameters:
|
190 |
+
frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
|
191 |
+
output_path (str): Path to save the output GIF.
|
192 |
+
duration (int): Duration per frame in milliseconds.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
None
|
196 |
+
"""
|
197 |
+
pil_frames = [Image.fromarray(f) for f in frames]
|
198 |
+
pil_frames[0].save(
|
199 |
+
output_path,
|
200 |
+
save_all=True,
|
201 |
+
append_images=pil_frames[1:],
|
202 |
+
duration=duration, # duration per frame in ms
|
203 |
+
loop=0
|
204 |
+
)
|
205 |
+
print(f"GIF saved to: {output_path}")
|
206 |
+
|
207 |
+
def center_crop_frame(frame: np.ndarray, crop_fraction: float) -> np.ndarray:
|
208 |
+
"""
|
209 |
+
Crop the central region of the frame by the given fraction.
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
frame (np.ndarray): Input RGB image (H, W, 3).
|
213 |
+
crop_fraction (float): Fraction of the original size to retain (e.g., 0.8 keeps 80%).
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
np.ndarray: Cropped RGB image.
|
217 |
+
"""
|
218 |
+
if crop_fraction >= 1.0:
|
219 |
+
return frame
|
220 |
+
|
221 |
+
h, w, _ = frame.shape
|
222 |
+
new_h, new_w = int(h * crop_fraction), int(w * crop_fraction)
|
223 |
+
start_y = (h - new_h) // 2
|
224 |
+
start_x = (w - new_w) // 2
|
225 |
+
return frame[start_y:start_y + new_h, start_x:start_x + new_w, :]
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
def generate_smooth_closed_camera_path(existing_cameras: List[Camera], N: int = 120, d: float = 2.0, s=.25) -> List[Camera]:
|
230 |
+
"""
|
231 |
+
Generate a smooth, closed path interpolating the positions of existing cameras.
|
232 |
+
|
233 |
+
Parameters:
|
234 |
+
existing_cameras (List[Camera]): List of existing cameras.
|
235 |
+
N (int): Number of points (cameras) to sample along the smooth path.
|
236 |
+
d (float): Distance ahead for estimating the center of view.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
List[Camera]: A list of smoothly moving Camera objects along a closed loop.
|
240 |
+
"""
|
241 |
+
# Step 1: Extract camera positions
|
242 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
243 |
+
|
244 |
+
# Step 2: Estimate center of view
|
245 |
+
look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
|
246 |
+
center_of_view = np.mean(look_targets, axis=0)
|
247 |
+
|
248 |
+
# Step 3: Fit a smooth closed spline through the positions
|
249 |
+
positions = np.vstack([positions, positions[0]]) # close the loop
|
250 |
+
tck, u = splprep(positions.T, s=s, per=True) # periodic=True for closed loop
|
251 |
+
|
252 |
+
# Step 4: Sample points along the spline
|
253 |
+
u_fine = np.linspace(0, 1, N)
|
254 |
+
smooth_path = np.stack(splev(u_fine, tck), axis=-1)
|
255 |
+
|
256 |
+
# Step 5: Generate cameras along the smooth path
|
257 |
+
reference_cam = existing_cameras[0]
|
258 |
+
new_cameras = []
|
259 |
+
|
260 |
+
for i, pos in enumerate(smooth_path):
|
261 |
+
R = look_at_rotation(pos, center_of_view)
|
262 |
+
new_cameras.append(Camera(
|
263 |
+
R=R,
|
264 |
+
T=pos,
|
265 |
+
FoVx=reference_cam.FoVx,
|
266 |
+
FoVy=reference_cam.FoVy,
|
267 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
268 |
+
colmap_id=-1,
|
269 |
+
depth_params=None,
|
270 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
271 |
+
invdepthmap=None,
|
272 |
+
image_name=f"smooth_path_i={i}",
|
273 |
+
uid=i
|
274 |
+
))
|
275 |
+
|
276 |
+
return new_cameras
|
277 |
+
|
278 |
+
|
279 |
+
def save_numpy_frames_as_mp4(frames, output_path="animation.mp4", fps=10, center_crop: float = 1.0):
|
280 |
+
"""
|
281 |
+
Save a list of RGB NumPy frames as an MP4 video with optional center cropping.
|
282 |
+
|
283 |
+
Parameters:
|
284 |
+
frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
|
285 |
+
output_path (str): Path to save the output MP4.
|
286 |
+
fps (int): Frames per second for playback speed.
|
287 |
+
center_crop (float): Fraction (0 < center_crop <= 1.0) of central region to retain.
|
288 |
+
Use 1.0 for no cropping; 0.8 to crop to 80% center region.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
None
|
292 |
+
"""
|
293 |
+
with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
|
294 |
+
for frame in frames:
|
295 |
+
cropped = center_crop_frame(frame, center_crop)
|
296 |
+
writer.append_data(cropped)
|
297 |
+
print(f"MP4 saved to: {output_path}")
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
def put_text_on_image(img: np.ndarray, text: str) -> np.ndarray:
|
302 |
+
"""
|
303 |
+
Draws multiline white text on a copy of the input image, positioned near the bottom
|
304 |
+
and around 80% of the image width. Handles '\n' characters to split text into multiple lines.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
img (np.ndarray): Input image as a (H, W, 3) uint8 numpy array.
|
308 |
+
text (str): Text string to draw on the image. Newlines '\n' are treated as line breaks.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
np.ndarray: The output image with the text drawn on it.
|
312 |
+
|
313 |
+
Notes:
|
314 |
+
- The function automatically adjusts line spacing and prevents text from going outside the image.
|
315 |
+
- Text is drawn in white with small font size (0.5) for minimal visual impact.
|
316 |
+
"""
|
317 |
+
img = img.copy()
|
318 |
+
height, width, _ = img.shape
|
319 |
+
|
320 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
321 |
+
font_scale = 1.
|
322 |
+
color = (255, 255, 255)
|
323 |
+
thickness = 2
|
324 |
+
line_spacing = 5 # extra pixels between lines
|
325 |
+
|
326 |
+
lines = text.split('\n')
|
327 |
+
|
328 |
+
# Precompute the maximum text width to adjust starting x
|
329 |
+
max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in lines)
|
330 |
+
|
331 |
+
x = int(0.8 * width)
|
332 |
+
x = min(x, width - max_text_width - 30) # margin on right
|
333 |
+
#x = int(0.03 * width)
|
334 |
+
|
335 |
+
# Start near the bottom, but move up depending on number of lines
|
336 |
+
total_text_height = len(lines) * (cv2.getTextSize('A', font, font_scale, thickness)[0][1] + line_spacing)
|
337 |
+
y_start = int(height*0.9) - total_text_height # 30 pixels from bottom
|
338 |
+
|
339 |
+
for i, line in enumerate(lines):
|
340 |
+
y = y_start + i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing)
|
341 |
+
cv2.putText(img, line, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
|
342 |
+
|
343 |
+
return img
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
def catmull_rom_spline(P0, P1, P2, P3, n_points=20):
|
349 |
+
"""
|
350 |
+
Compute Catmull-Rom spline segment between P1 and P2.
|
351 |
+
"""
|
352 |
+
t = np.linspace(0, 1, n_points)[:, None]
|
353 |
+
|
354 |
+
M = 0.5 * np.array([
|
355 |
+
[-1, 3, -3, 1],
|
356 |
+
[ 2, -5, 4, -1],
|
357 |
+
[-1, 0, 1, 0],
|
358 |
+
[ 0, 2, 0, 0]
|
359 |
+
])
|
360 |
+
|
361 |
+
G = np.stack([P0, P1, P2, P3], axis=0)
|
362 |
+
T = np.concatenate([t**3, t**2, t, np.ones_like(t)], axis=1)
|
363 |
+
|
364 |
+
return T @ M @ G
|
365 |
+
|
366 |
+
def sort_cameras_pca(existing_cameras: List[Camera]):
|
367 |
+
"""
|
368 |
+
Sort cameras along the main PCA axis.
|
369 |
+
"""
|
370 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
371 |
+
pca = PCA(n_components=1)
|
372 |
+
scores = pca.fit_transform(positions)
|
373 |
+
sorted_indices = np.argsort(scores[:, 0])
|
374 |
+
return sorted_indices
|
375 |
+
|
376 |
+
def generate_fully_smooth_cameras(existing_cameras: List[Camera],
|
377 |
+
n_selected: int = 30,
|
378 |
+
n_points_per_segment: int = 20,
|
379 |
+
d: float = 2.0,
|
380 |
+
closed: bool = False) -> List[Camera]:
|
381 |
+
"""
|
382 |
+
Generate a fully smooth camera path using PCA ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
existing_cameras (List[Camera]): List of input cameras.
|
386 |
+
n_selected (int): Number of cameras to select after sorting.
|
387 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
388 |
+
d (float): Distance ahead for estimating center of view.
|
389 |
+
closed (bool): Whether to close the path.
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
List[Camera]: List of smoothly moving Camera objects.
|
393 |
+
"""
|
394 |
+
# 1. Sort cameras along PCA axis
|
395 |
+
sorted_indices = sort_cameras_pca(existing_cameras)
|
396 |
+
sorted_cameras = [existing_cameras[i] for i in sorted_indices]
|
397 |
+
positions = np.array([cam.T for cam in sorted_cameras])
|
398 |
+
|
399 |
+
# 2. Subsample uniformly
|
400 |
+
idx = np.linspace(0, len(positions) - 1, n_selected).astype(int)
|
401 |
+
sampled_positions = positions[idx]
|
402 |
+
sampled_cameras = [sorted_cameras[i] for i in idx]
|
403 |
+
|
404 |
+
# 3. Prepare for Catmull-Rom
|
405 |
+
if closed:
|
406 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
407 |
+
else:
|
408 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
409 |
+
|
410 |
+
# 4. Generate smooth path positions
|
411 |
+
path_positions = []
|
412 |
+
for i in range(1, len(sampled_positions) - 2):
|
413 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
414 |
+
path_positions.append(segment)
|
415 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
416 |
+
|
417 |
+
# 5. Global SLERP for rotations
|
418 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
419 |
+
key_times = np.linspace(0, 1, len(rotations))
|
420 |
+
slerp = Slerp(key_times, rotations)
|
421 |
+
|
422 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
423 |
+
interpolated_rotations = slerp(query_times)
|
424 |
+
|
425 |
+
# 6. Generate Camera objects
|
426 |
+
reference_cam = existing_cameras[0]
|
427 |
+
smooth_cameras = []
|
428 |
+
|
429 |
+
for i, pos in enumerate(path_positions):
|
430 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
431 |
+
|
432 |
+
smooth_cameras.append(Camera(
|
433 |
+
R=R_interp,
|
434 |
+
T=pos,
|
435 |
+
FoVx=reference_cam.FoVx,
|
436 |
+
FoVy=reference_cam.FoVy,
|
437 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
438 |
+
colmap_id=-1,
|
439 |
+
depth_params=None,
|
440 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
441 |
+
invdepthmap=None,
|
442 |
+
image_name=f"fully_smooth_path_i={i}",
|
443 |
+
uid=i
|
444 |
+
))
|
445 |
+
|
446 |
+
return smooth_cameras
|
447 |
+
|
448 |
+
|
449 |
+
def plot_cameras_and_smooth_path_with_orientation(existing_cameras: List[Camera], smooth_cameras: List[Camera], scale: float = 0.1):
|
450 |
+
"""
|
451 |
+
Plot input cameras and smooth path cameras with their orientations in 3D.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
existing_cameras (List[Camera]): List of original input cameras.
|
455 |
+
smooth_cameras (List[Camera]): List of smooth path cameras.
|
456 |
+
scale (float): Length of orientation arrows.
|
457 |
+
|
458 |
+
Returns:
|
459 |
+
None
|
460 |
+
"""
|
461 |
+
# Input cameras
|
462 |
+
input_positions = np.array([cam.T for cam in existing_cameras])
|
463 |
+
|
464 |
+
# Smooth cameras
|
465 |
+
smooth_positions = np.array([cam.T for cam in smooth_cameras])
|
466 |
+
|
467 |
+
fig = go.Figure()
|
468 |
+
|
469 |
+
# Plot input camera positions
|
470 |
+
fig.add_trace(go.Scatter3d(
|
471 |
+
x=input_positions[:, 0], y=input_positions[:, 1], z=input_positions[:, 2],
|
472 |
+
mode='markers',
|
473 |
+
marker=dict(size=4, color='blue'),
|
474 |
+
name='Input Cameras'
|
475 |
+
))
|
476 |
+
|
477 |
+
# Plot smooth path positions
|
478 |
+
fig.add_trace(go.Scatter3d(
|
479 |
+
x=smooth_positions[:, 0], y=smooth_positions[:, 1], z=smooth_positions[:, 2],
|
480 |
+
mode='lines+markers',
|
481 |
+
line=dict(color='red', width=3),
|
482 |
+
marker=dict(size=2, color='red'),
|
483 |
+
name='Smooth Path Cameras'
|
484 |
+
))
|
485 |
+
|
486 |
+
# Plot input camera orientations
|
487 |
+
for cam in existing_cameras:
|
488 |
+
origin = cam.T
|
489 |
+
forward = cam.R[:, 2] # Forward direction
|
490 |
+
|
491 |
+
fig.add_trace(go.Cone(
|
492 |
+
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
493 |
+
u=[forward[0]], v=[forward[1]], w=[forward[2]],
|
494 |
+
colorscale=[[0, 'blue'], [1, 'blue']],
|
495 |
+
sizemode="absolute",
|
496 |
+
sizeref=scale,
|
497 |
+
anchor="tail",
|
498 |
+
showscale=False,
|
499 |
+
name='Input Camera Direction'
|
500 |
+
))
|
501 |
+
|
502 |
+
# Plot smooth camera orientations
|
503 |
+
for cam in smooth_cameras:
|
504 |
+
origin = cam.T
|
505 |
+
forward = cam.R[:, 2] # Forward direction
|
506 |
+
|
507 |
+
fig.add_trace(go.Cone(
|
508 |
+
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
509 |
+
u=[forward[0]], v=[forward[1]], w=[forward[2]],
|
510 |
+
colorscale=[[0, 'red'], [1, 'red']],
|
511 |
+
sizemode="absolute",
|
512 |
+
sizeref=scale,
|
513 |
+
anchor="tail",
|
514 |
+
showscale=False,
|
515 |
+
name='Smooth Camera Direction'
|
516 |
+
))
|
517 |
+
|
518 |
+
fig.update_layout(
|
519 |
+
scene=dict(
|
520 |
+
xaxis_title='X',
|
521 |
+
yaxis_title='Y',
|
522 |
+
zaxis_title='Z',
|
523 |
+
aspectmode='data'
|
524 |
+
),
|
525 |
+
title="Input Cameras and Smooth Path with Orientations",
|
526 |
+
margin=dict(l=0, r=0, b=0, t=30)
|
527 |
+
)
|
528 |
+
|
529 |
+
fig.show()
|
530 |
+
|
531 |
+
|
532 |
+
def solve_tsp_nearest_neighbor(points: np.ndarray):
|
533 |
+
"""
|
534 |
+
Solve TSP approximately using nearest neighbor heuristic.
|
535 |
+
|
536 |
+
Args:
|
537 |
+
points (np.ndarray): (N, 3) array of points.
|
538 |
+
|
539 |
+
Returns:
|
540 |
+
List[int]: Optimal visiting order of points.
|
541 |
+
"""
|
542 |
+
N = points.shape[0]
|
543 |
+
dist = distance_matrix(points, points)
|
544 |
+
visited = [0]
|
545 |
+
unvisited = set(range(1, N))
|
546 |
+
|
547 |
+
while unvisited:
|
548 |
+
last = visited[-1]
|
549 |
+
next_city = min(unvisited, key=lambda city: dist[last, city])
|
550 |
+
visited.append(next_city)
|
551 |
+
unvisited.remove(next_city)
|
552 |
+
|
553 |
+
return visited
|
554 |
+
|
555 |
+
def solve_tsp_2opt(points: np.ndarray, n_iter: int = 1000) -> np.ndarray:
|
556 |
+
"""
|
557 |
+
Solve TSP approximately using Nearest Neighbor + 2-Opt.
|
558 |
+
|
559 |
+
Args:
|
560 |
+
points (np.ndarray): Array of shape (N, D) with points.
|
561 |
+
n_iter (int): Number of 2-opt iterations.
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
np.ndarray: Ordered list of indices.
|
565 |
+
"""
|
566 |
+
n_points = points.shape[0]
|
567 |
+
|
568 |
+
# === 1. Start with Nearest Neighbor
|
569 |
+
unvisited = list(range(n_points))
|
570 |
+
current = unvisited.pop(0)
|
571 |
+
path = [current]
|
572 |
+
|
573 |
+
while unvisited:
|
574 |
+
dists = np.linalg.norm(points[unvisited] - points[current], axis=1)
|
575 |
+
next_idx = unvisited[np.argmin(dists)]
|
576 |
+
unvisited.remove(next_idx)
|
577 |
+
path.append(next_idx)
|
578 |
+
current = next_idx
|
579 |
+
|
580 |
+
# === 2. Apply 2-Opt improvements
|
581 |
+
def path_length(path):
|
582 |
+
return np.sum(np.linalg.norm(points[path[i]] - points[path[i+1]], axis=0) for i in range(len(path)-1))
|
583 |
+
|
584 |
+
best_length = path_length(path)
|
585 |
+
improved = True
|
586 |
+
|
587 |
+
for _ in range(n_iter):
|
588 |
+
if not improved:
|
589 |
+
break
|
590 |
+
improved = False
|
591 |
+
for i in range(1, n_points - 2):
|
592 |
+
for j in range(i + 1, n_points):
|
593 |
+
if j - i == 1: continue
|
594 |
+
new_path = path[:i] + path[i:j][::-1] + path[j:]
|
595 |
+
new_length = path_length(new_path)
|
596 |
+
if new_length < best_length:
|
597 |
+
path = new_path
|
598 |
+
best_length = new_length
|
599 |
+
improved = True
|
600 |
+
break
|
601 |
+
if improved:
|
602 |
+
break
|
603 |
+
|
604 |
+
return np.array(path)
|
605 |
+
|
606 |
+
def generate_fully_smooth_cameras_with_tsp(existing_cameras: List[Camera],
|
607 |
+
n_selected: int = 30,
|
608 |
+
n_points_per_segment: int = 20,
|
609 |
+
d: float = 2.0,
|
610 |
+
closed: bool = False) -> List[Camera]:
|
611 |
+
"""
|
612 |
+
Generate a fully smooth camera path using TSP ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
|
613 |
+
|
614 |
+
Args:
|
615 |
+
existing_cameras (List[Camera]): List of input cameras.
|
616 |
+
n_selected (int): Number of cameras to select after ordering.
|
617 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
618 |
+
d (float): Distance ahead for estimating center of view.
|
619 |
+
closed (bool): Whether to close the path.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
List[Camera]: List of smoothly moving Camera objects.
|
623 |
+
"""
|
624 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
625 |
+
|
626 |
+
# 1. Solve approximate TSP
|
627 |
+
order = solve_tsp_nearest_neighbor(positions)
|
628 |
+
ordered_cameras = [existing_cameras[i] for i in order]
|
629 |
+
ordered_positions = positions[order]
|
630 |
+
|
631 |
+
# 2. Subsample uniformly
|
632 |
+
idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
|
633 |
+
sampled_positions = ordered_positions[idx]
|
634 |
+
sampled_cameras = [ordered_cameras[i] for i in idx]
|
635 |
+
|
636 |
+
# 3. Prepare for Catmull-Rom
|
637 |
+
if closed:
|
638 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
639 |
+
else:
|
640 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
641 |
+
|
642 |
+
# 4. Generate smooth path positions
|
643 |
+
path_positions = []
|
644 |
+
for i in range(1, len(sampled_positions) - 2):
|
645 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
646 |
+
path_positions.append(segment)
|
647 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
648 |
+
|
649 |
+
# 5. Global SLERP for rotations
|
650 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
651 |
+
key_times = np.linspace(0, 1, len(rotations))
|
652 |
+
slerp = Slerp(key_times, rotations)
|
653 |
+
|
654 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
655 |
+
interpolated_rotations = slerp(query_times)
|
656 |
+
|
657 |
+
# 6. Generate Camera objects
|
658 |
+
reference_cam = existing_cameras[0]
|
659 |
+
smooth_cameras = []
|
660 |
+
|
661 |
+
for i, pos in enumerate(path_positions):
|
662 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
663 |
+
|
664 |
+
smooth_cameras.append(Camera(
|
665 |
+
R=R_interp,
|
666 |
+
T=pos,
|
667 |
+
FoVx=reference_cam.FoVx,
|
668 |
+
FoVy=reference_cam.FoVy,
|
669 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
670 |
+
colmap_id=-1,
|
671 |
+
depth_params=None,
|
672 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
673 |
+
invdepthmap=None,
|
674 |
+
image_name=f"fully_smooth_path_i={i}",
|
675 |
+
uid=i
|
676 |
+
))
|
677 |
+
|
678 |
+
return smooth_cameras
|
679 |
+
|
680 |
+
from typing import List
|
681 |
+
import numpy as np
|
682 |
+
from sklearn.mixture import GaussianMixture
|
683 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
684 |
+
from PIL import Image
|
685 |
+
|
686 |
+
def generate_clustered_smooth_cameras_with_tsp(existing_cameras: List[Camera],
|
687 |
+
n_selected: int = 30,
|
688 |
+
n_points_per_segment: int = 20,
|
689 |
+
d: float = 2.0,
|
690 |
+
n_clusters: int = 5,
|
691 |
+
closed: bool = False) -> List[Camera]:
|
692 |
+
"""
|
693 |
+
Generate a fully smooth camera path using clustering + TSP between nearest cluster centers + TSP inside clusters.
|
694 |
+
Positions are normalized before clustering and denormalized before generating final cameras.
|
695 |
+
|
696 |
+
Args:
|
697 |
+
existing_cameras (List[Camera]): List of input cameras.
|
698 |
+
n_selected (int): Number of cameras to select after ordering.
|
699 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
700 |
+
d (float): Distance ahead for estimating center of view.
|
701 |
+
n_clusters (int): Number of GMM clusters.
|
702 |
+
closed (bool): Whether to close the path.
|
703 |
+
|
704 |
+
Returns:
|
705 |
+
List[Camera]: Smooth path of Camera objects.
|
706 |
+
"""
|
707 |
+
# Extract positions and rotations
|
708 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
709 |
+
rotations = np.array([R.from_matrix(cam.R).as_quat() for cam in existing_cameras])
|
710 |
+
|
711 |
+
# === Normalize positions
|
712 |
+
mean_pos = np.mean(positions, axis=0)
|
713 |
+
scale_pos = np.std(positions, axis=0)
|
714 |
+
scale_pos[scale_pos == 0] = 1.0 # avoid division by zero
|
715 |
+
|
716 |
+
positions_normalized = (positions - mean_pos) / scale_pos
|
717 |
+
|
718 |
+
# === Features for clustering (only positions, not rotations)
|
719 |
+
features = positions_normalized
|
720 |
+
|
721 |
+
# === 1. GMM clustering
|
722 |
+
gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
|
723 |
+
cluster_labels = gmm.fit_predict(features)
|
724 |
+
|
725 |
+
clusters = {}
|
726 |
+
cluster_centers = []
|
727 |
+
|
728 |
+
for cluster_id in range(n_clusters):
|
729 |
+
cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
730 |
+
if len(cluster_indices) == 0:
|
731 |
+
continue
|
732 |
+
clusters[cluster_id] = cluster_indices
|
733 |
+
cluster_center = np.mean(features[cluster_indices], axis=0)
|
734 |
+
cluster_centers.append(cluster_center)
|
735 |
+
|
736 |
+
cluster_centers = np.stack(cluster_centers)
|
737 |
+
|
738 |
+
# === 2. Remap cluster centers to nearest existing cameras
|
739 |
+
if False:
|
740 |
+
mapped_centers = []
|
741 |
+
for center in cluster_centers:
|
742 |
+
dists = np.linalg.norm(features - center, axis=1)
|
743 |
+
nearest_idx = np.argmin(dists)
|
744 |
+
mapped_centers.append(features[nearest_idx])
|
745 |
+
mapped_centers = np.stack(mapped_centers)
|
746 |
+
cluster_centers = mapped_centers
|
747 |
+
# === 3. Solve TSP between mapped cluster centers
|
748 |
+
cluster_order = solve_tsp_2opt(cluster_centers)
|
749 |
+
|
750 |
+
# === 4. For each cluster, solve TSP inside cluster
|
751 |
+
final_indices = []
|
752 |
+
for cluster_id in cluster_order:
|
753 |
+
cluster_indices = clusters[cluster_id]
|
754 |
+
cluster_positions = features[cluster_indices]
|
755 |
+
|
756 |
+
if len(cluster_positions) == 1:
|
757 |
+
final_indices.append(cluster_indices[0])
|
758 |
+
continue
|
759 |
+
|
760 |
+
local_order = solve_tsp_nearest_neighbor(cluster_positions)
|
761 |
+
ordered_cluster_indices = cluster_indices[local_order]
|
762 |
+
final_indices.extend(ordered_cluster_indices)
|
763 |
+
|
764 |
+
ordered_cameras = [existing_cameras[i] for i in final_indices]
|
765 |
+
ordered_positions = positions_normalized[final_indices]
|
766 |
+
|
767 |
+
# === 5. Subsample uniformly
|
768 |
+
idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
|
769 |
+
sampled_positions = ordered_positions[idx]
|
770 |
+
sampled_cameras = [ordered_cameras[i] for i in idx]
|
771 |
+
|
772 |
+
# === 6. Prepare for Catmull-Rom spline
|
773 |
+
if closed:
|
774 |
+
sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
|
775 |
+
else:
|
776 |
+
sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
|
777 |
+
|
778 |
+
# === 7. Smooth path positions
|
779 |
+
path_positions = []
|
780 |
+
for i in range(1, len(sampled_positions) - 2):
|
781 |
+
segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
|
782 |
+
path_positions.append(segment)
|
783 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
784 |
+
|
785 |
+
# === 8. Denormalize
|
786 |
+
path_positions = path_positions * scale_pos + mean_pos
|
787 |
+
|
788 |
+
# === 9. SLERP for rotations
|
789 |
+
rotations = R.from_matrix([cam.R for cam in sampled_cameras])
|
790 |
+
key_times = np.linspace(0, 1, len(rotations))
|
791 |
+
slerp = Slerp(key_times, rotations)
|
792 |
+
|
793 |
+
query_times = np.linspace(0, 1, len(path_positions))
|
794 |
+
interpolated_rotations = slerp(query_times)
|
795 |
+
|
796 |
+
# === 10. Generate Camera objects
|
797 |
+
reference_cam = existing_cameras[0]
|
798 |
+
smooth_cameras = []
|
799 |
+
|
800 |
+
for i, pos in enumerate(path_positions):
|
801 |
+
R_interp = interpolated_rotations[i].as_matrix()
|
802 |
+
|
803 |
+
smooth_cameras.append(Camera(
|
804 |
+
R=R_interp,
|
805 |
+
T=pos,
|
806 |
+
FoVx=reference_cam.FoVx,
|
807 |
+
FoVy=reference_cam.FoVy,
|
808 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
809 |
+
colmap_id=-1,
|
810 |
+
depth_params=None,
|
811 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
812 |
+
invdepthmap=None,
|
813 |
+
image_name=f"clustered_smooth_path_i={i}",
|
814 |
+
uid=i
|
815 |
+
))
|
816 |
+
|
817 |
+
return smooth_cameras
|
818 |
+
|
819 |
+
|
820 |
+
# def generate_clustered_path(existing_cameras: List[Camera],
|
821 |
+
# n_points_per_segment: int = 20,
|
822 |
+
# d: float = 2.0,
|
823 |
+
# n_clusters: int = 5,
|
824 |
+
# closed: bool = False) -> List[Camera]:
|
825 |
+
# """
|
826 |
+
# Generate a smooth camera path using GMM clustering and TSP on cluster centers.
|
827 |
+
|
828 |
+
# Args:
|
829 |
+
# existing_cameras (List[Camera]): List of input cameras.
|
830 |
+
# n_points_per_segment (int): Number of interpolated points per spline segment.
|
831 |
+
# d (float): Distance ahead for estimating center of view.
|
832 |
+
# n_clusters (int): Number of GMM clusters (zones).
|
833 |
+
# closed (bool): Whether to close the path.
|
834 |
+
|
835 |
+
# Returns:
|
836 |
+
# List[Camera]: Smooth path of Camera objects.
|
837 |
+
# """
|
838 |
+
# # Extract positions and rotations
|
839 |
+
# positions = np.array([cam.T for cam in existing_cameras])
|
840 |
+
|
841 |
+
# # === Normalize positions
|
842 |
+
# mean_pos = np.mean(positions, axis=0)
|
843 |
+
# scale_pos = np.std(positions, axis=0)
|
844 |
+
# scale_pos[scale_pos == 0] = 1.0
|
845 |
+
|
846 |
+
# positions_normalized = (positions - mean_pos) / scale_pos
|
847 |
+
|
848 |
+
# # === 1. GMM clustering (only positions)
|
849 |
+
# gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
|
850 |
+
# cluster_labels = gmm.fit_predict(positions_normalized)
|
851 |
+
|
852 |
+
# cluster_centers = []
|
853 |
+
# for cluster_id in range(n_clusters):
|
854 |
+
# cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
855 |
+
# if len(cluster_indices) == 0:
|
856 |
+
# continue
|
857 |
+
# cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
|
858 |
+
# cluster_centers.append(cluster_center)
|
859 |
+
|
860 |
+
# cluster_centers = np.stack(cluster_centers)
|
861 |
+
|
862 |
+
# # === 2. Solve TSP between cluster centers
|
863 |
+
# cluster_order = solve_tsp_2opt(cluster_centers)
|
864 |
+
|
865 |
+
# # === 3. Reorder cluster centers
|
866 |
+
# ordered_centers = cluster_centers[cluster_order]
|
867 |
+
|
868 |
+
# # === 4. Prepare Catmull-Rom spline
|
869 |
+
# if closed:
|
870 |
+
# ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
|
871 |
+
# else:
|
872 |
+
# ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
|
873 |
+
|
874 |
+
# # === 5. Generate smooth path positions
|
875 |
+
# path_positions = []
|
876 |
+
# for i in range(1, len(ordered_centers) - 2):
|
877 |
+
# segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
|
878 |
+
# path_positions.append(segment)
|
879 |
+
# path_positions = np.concatenate(path_positions, axis=0)
|
880 |
+
|
881 |
+
# # === 6. Denormalize back
|
882 |
+
# path_positions = path_positions * scale_pos + mean_pos
|
883 |
+
|
884 |
+
# # === 7. Generate dummy rotations (constant forward facing)
|
885 |
+
# reference_cam = existing_cameras[0]
|
886 |
+
# default_rotation = R.from_matrix(reference_cam.R)
|
887 |
+
|
888 |
+
# # For simplicity, fixed rotation for all
|
889 |
+
# smooth_cameras = []
|
890 |
+
|
891 |
+
# for i, pos in enumerate(path_positions):
|
892 |
+
# R_interp = default_rotation.as_matrix()
|
893 |
+
|
894 |
+
# smooth_cameras.append(Camera(
|
895 |
+
# R=R_interp,
|
896 |
+
# T=pos,
|
897 |
+
# FoVx=reference_cam.FoVx,
|
898 |
+
# FoVy=reference_cam.FoVy,
|
899 |
+
# resolution=(reference_cam.image_width, reference_cam.image_height),
|
900 |
+
# colmap_id=-1,
|
901 |
+
# depth_params=None,
|
902 |
+
# image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
903 |
+
# invdepthmap=None,
|
904 |
+
# image_name=f"cluster_path_i={i}",
|
905 |
+
# uid=i
|
906 |
+
# ))
|
907 |
+
|
908 |
+
# return smooth_cameras
|
909 |
+
|
910 |
+
from typing import List
|
911 |
+
import numpy as np
|
912 |
+
from sklearn.cluster import KMeans
|
913 |
+
from scipy.spatial.transform import Rotation as R, Slerp
|
914 |
+
from PIL import Image
|
915 |
+
|
916 |
+
def generate_clustered_path(existing_cameras: List[Camera],
|
917 |
+
n_points_per_segment: int = 20,
|
918 |
+
d: float = 2.0,
|
919 |
+
n_clusters: int = 5,
|
920 |
+
closed: bool = False) -> List[Camera]:
|
921 |
+
"""
|
922 |
+
Generate a smooth camera path using K-Means clustering and TSP on cluster centers.
|
923 |
+
|
924 |
+
Args:
|
925 |
+
existing_cameras (List[Camera]): List of input cameras.
|
926 |
+
n_points_per_segment (int): Number of interpolated points per spline segment.
|
927 |
+
d (float): Distance ahead for estimating center of view.
|
928 |
+
n_clusters (int): Number of KMeans clusters (zones).
|
929 |
+
closed (bool): Whether to close the path.
|
930 |
+
|
931 |
+
Returns:
|
932 |
+
List[Camera]: Smooth path of Camera objects.
|
933 |
+
"""
|
934 |
+
# Extract positions
|
935 |
+
positions = np.array([cam.T for cam in existing_cameras])
|
936 |
+
|
937 |
+
# === Normalize positions
|
938 |
+
mean_pos = np.mean(positions, axis=0)
|
939 |
+
scale_pos = np.std(positions, axis=0)
|
940 |
+
scale_pos[scale_pos == 0] = 1.0
|
941 |
+
|
942 |
+
positions_normalized = (positions - mean_pos) / scale_pos
|
943 |
+
|
944 |
+
# === 1. K-Means clustering (only positions)
|
945 |
+
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
|
946 |
+
cluster_labels = kmeans.fit_predict(positions_normalized)
|
947 |
+
|
948 |
+
cluster_centers = []
|
949 |
+
for cluster_id in range(n_clusters):
|
950 |
+
cluster_indices = np.where(cluster_labels == cluster_id)[0]
|
951 |
+
if len(cluster_indices) == 0:
|
952 |
+
continue
|
953 |
+
cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
|
954 |
+
cluster_centers.append(cluster_center)
|
955 |
+
|
956 |
+
cluster_centers = np.stack(cluster_centers)
|
957 |
+
|
958 |
+
# === 2. Solve TSP between cluster centers
|
959 |
+
cluster_order = solve_tsp_2opt(cluster_centers)
|
960 |
+
|
961 |
+
# === 3. Reorder cluster centers
|
962 |
+
ordered_centers = cluster_centers[cluster_order]
|
963 |
+
|
964 |
+
# === 4. Prepare Catmull-Rom spline
|
965 |
+
if closed:
|
966 |
+
ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
|
967 |
+
else:
|
968 |
+
ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
|
969 |
+
|
970 |
+
# === 5. Generate smooth path positions
|
971 |
+
path_positions = []
|
972 |
+
for i in range(1, len(ordered_centers) - 2):
|
973 |
+
segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
|
974 |
+
path_positions.append(segment)
|
975 |
+
path_positions = np.concatenate(path_positions, axis=0)
|
976 |
+
|
977 |
+
# === 6. Denormalize back
|
978 |
+
path_positions = path_positions * scale_pos + mean_pos
|
979 |
+
|
980 |
+
# === 7. Generate dummy rotations (constant forward facing)
|
981 |
+
reference_cam = existing_cameras[0]
|
982 |
+
default_rotation = R.from_matrix(reference_cam.R)
|
983 |
+
|
984 |
+
# For simplicity, fixed rotation for all
|
985 |
+
smooth_cameras = []
|
986 |
+
|
987 |
+
for i, pos in enumerate(path_positions):
|
988 |
+
R_interp = default_rotation.as_matrix()
|
989 |
+
|
990 |
+
smooth_cameras.append(Camera(
|
991 |
+
R=R_interp,
|
992 |
+
T=pos,
|
993 |
+
FoVx=reference_cam.FoVx,
|
994 |
+
FoVy=reference_cam.FoVy,
|
995 |
+
resolution=(reference_cam.image_width, reference_cam.image_height),
|
996 |
+
colmap_id=-1,
|
997 |
+
depth_params=None,
|
998 |
+
image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
|
999 |
+
invdepthmap=None,
|
1000 |
+
image_name=f"cluster_path_i={i}",
|
1001 |
+
uid=i
|
1002 |
+
))
|
1003 |
+
|
1004 |
+
return smooth_cameras
|
1005 |
+
|
1006 |
+
|
1007 |
+
|
1008 |
+
|
1009 |
+
def visualize_image_with_points(image, points):
|
1010 |
+
"""
|
1011 |
+
Visualize an image with points overlaid on top. This is useful for correspondences visualizations
|
1012 |
+
|
1013 |
+
Parameters:
|
1014 |
+
- image: PIL Image object
|
1015 |
+
- points: Numpy array of shape [N, 2] containing (x, y) coordinates of points
|
1016 |
+
|
1017 |
+
Returns:
|
1018 |
+
- None (displays the visualization)
|
1019 |
+
"""
|
1020 |
+
|
1021 |
+
# Convert PIL image to numpy array
|
1022 |
+
img_array = np.array(image)
|
1023 |
+
|
1024 |
+
# Create a figure and axis
|
1025 |
+
fig, ax = plt.subplots(figsize=(7,7))
|
1026 |
+
|
1027 |
+
# Display the image
|
1028 |
+
ax.imshow(img_array)
|
1029 |
+
|
1030 |
+
# Scatter plot the points on top of the image
|
1031 |
+
ax.scatter(points[:, 0], points[:, 1], color='red', marker='o', s=1)
|
1032 |
+
|
1033 |
+
# Show the plot
|
1034 |
+
plt.show()
|
1035 |
+
|
1036 |
+
|
1037 |
+
def visualize_correspondences(image1, points1, image2, points2):
|
1038 |
+
"""
|
1039 |
+
Visualize two images concatenated horizontally with key points and correspondences.
|
1040 |
+
|
1041 |
+
Parameters:
|
1042 |
+
- image1: PIL Image object (left image)
|
1043 |
+
- points1: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image1
|
1044 |
+
- image2: PIL Image object (right image)
|
1045 |
+
- points2: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image2
|
1046 |
+
|
1047 |
+
Returns:
|
1048 |
+
- None (displays the visualization)
|
1049 |
+
"""
|
1050 |
+
|
1051 |
+
# Concatenate images horizontally
|
1052 |
+
concatenated_image = np.concatenate((np.array(image1), np.array(image2)), axis=1)
|
1053 |
+
|
1054 |
+
# Create a figure and axis
|
1055 |
+
fig, ax = plt.subplots(figsize=(10,10))
|
1056 |
+
|
1057 |
+
# Display the concatenated image
|
1058 |
+
ax.imshow(concatenated_image)
|
1059 |
+
|
1060 |
+
# Plot key points on the left image
|
1061 |
+
ax.scatter(points1[:, 0], points1[:, 1], color='red', marker='o', s=10)
|
1062 |
+
|
1063 |
+
# Plot key points on the right image
|
1064 |
+
ax.scatter(points2[:, 0] + image1.width, points2[:, 1], color='blue', marker='o', s=10)
|
1065 |
+
|
1066 |
+
# Draw lines connecting corresponding key points
|
1067 |
+
for i in range(len(points1)):
|
1068 |
+
ax.plot([points1[i, 0], points2[i, 0] + image1.width], [points1[i, 1], points2[i, 1]])#, color='green')
|
1069 |
+
|
1070 |
+
# Show the plot
|
1071 |
+
plt.show()
|
1072 |
+
|
submodules/RoMa/.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.egg-info*
|
2 |
+
*.vscode*
|
3 |
+
*__pycache__*
|
4 |
+
vis*
|
5 |
+
workspace*
|
6 |
+
.venv
|
7 |
+
.DS_Store
|
8 |
+
jobs/*
|
9 |
+
*ignore_me*
|
10 |
+
*.pth
|
11 |
+
wandb*
|
submodules/RoMa/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Johan Edstedt
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
submodules/RoMa/README.md
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
<p align="center">
|
3 |
+
<h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
|
4 |
+
<p align="center">
|
5 |
+
<a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
|
6 |
+
·
|
7 |
+
<a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
|
8 |
+
·
|
9 |
+
<a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
|
10 |
+
·
|
11 |
+
<a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
|
12 |
+
·
|
13 |
+
<a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
|
14 |
+
</p>
|
15 |
+
<h2 align="center"><p>
|
16 |
+
<a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
|
17 |
+
<a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
|
18 |
+
</p></h2>
|
19 |
+
<div align="center"></div>
|
20 |
+
</p>
|
21 |
+
<br/>
|
22 |
+
<p align="center">
|
23 |
+
<img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
|
24 |
+
<br>
|
25 |
+
<em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
|
26 |
+
</p>
|
27 |
+
|
28 |
+
## Setup/Install
|
29 |
+
In your python environment (tested on Linux python 3.10), run:
|
30 |
+
```bash
|
31 |
+
pip install -e .
|
32 |
+
```
|
33 |
+
## Demo / How to Use
|
34 |
+
We provide two demos in the [demos folder](demo).
|
35 |
+
Here's the gist of it:
|
36 |
+
```python
|
37 |
+
from romatch import roma_outdoor
|
38 |
+
roma_model = roma_outdoor(device=device)
|
39 |
+
# Match
|
40 |
+
warp, certainty = roma_model.match(imA_path, imB_path, device=device)
|
41 |
+
# Sample matches for estimation
|
42 |
+
matches, certainty = roma_model.sample(warp, certainty)
|
43 |
+
# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
|
44 |
+
kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
45 |
+
# Find a fundamental matrix (or anything else of interest)
|
46 |
+
F, mask = cv2.findFundamentalMat(
|
47 |
+
kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
48 |
+
)
|
49 |
+
```
|
50 |
+
|
51 |
+
**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
|
52 |
+
|
53 |
+
## Settings
|
54 |
+
|
55 |
+
### Resolution
|
56 |
+
By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
|
57 |
+
You can change this at construction (see roma_outdoor kwargs).
|
58 |
+
You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
|
59 |
+
|
60 |
+
### Sampling
|
61 |
+
roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
|
62 |
+
|
63 |
+
|
64 |
+
## Reproducing Results
|
65 |
+
The experiments in the paper are provided in the [experiments folder](experiments).
|
66 |
+
|
67 |
+
### Training
|
68 |
+
1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
|
69 |
+
2. Run the relevant experiment, e.g.,
|
70 |
+
```bash
|
71 |
+
torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
|
72 |
+
```
|
73 |
+
### Testing
|
74 |
+
```bash
|
75 |
+
python experiments/roma_outdoor.py --only_test --benchmark mega-1500
|
76 |
+
```
|
77 |
+
## License
|
78 |
+
All our code except DINOv2 is MIT license.
|
79 |
+
DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
|
80 |
+
|
81 |
+
## Acknowledgement
|
82 |
+
Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
|
83 |
+
|
84 |
+
## Tiny RoMa
|
85 |
+
If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
|
86 |
+
```python
|
87 |
+
from romatch import tiny_roma_v1_outdoor
|
88 |
+
tiny_roma_model = tiny_roma_v1_outdoor(device=device)
|
89 |
+
```
|
90 |
+
Mega1500:
|
91 |
+
| | AUC@5 | AUC@10 | AUC@20 |
|
92 |
+
|----------|----------|----------|----------|
|
93 |
+
| XFeat | 46.4 | 58.9 | 69.2 |
|
94 |
+
| XFeat* | 51.9 | 67.2 | 78.9 |
|
95 |
+
| Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
|
96 |
+
| RoMa | - | - | - |
|
97 |
+
|
98 |
+
Mega-8-Scenes (See DKM):
|
99 |
+
| | AUC@5 | AUC@10 | AUC@20 |
|
100 |
+
|----------|----------|----------|----------|
|
101 |
+
| XFeat | - | - | - |
|
102 |
+
| XFeat* | 50.1 | 64.4 | 75.2 |
|
103 |
+
| Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
|
104 |
+
| RoMa | - | - | - |
|
105 |
+
|
106 |
+
IMC22 :'):
|
107 |
+
| | mAA@10 |
|
108 |
+
|----------|----------|
|
109 |
+
| XFeat | 42.1 |
|
110 |
+
| XFeat* | - |
|
111 |
+
| Tiny RoMa v1 | 42.2 |
|
112 |
+
| RoMa | - |
|
113 |
+
|
114 |
+
## BibTeX
|
115 |
+
If you find our models useful, please consider citing our paper!
|
116 |
+
```
|
117 |
+
@article{edstedt2024roma,
|
118 |
+
title={{RoMa: Robust Dense Feature Matching}},
|
119 |
+
author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
|
120 |
+
journal={IEEE Conference on Computer Vision and Pattern Recognition},
|
121 |
+
year={2024}
|
122 |
+
}
|
123 |
+
```
|
submodules/RoMa/data/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
submodules/RoMa/demo/demo_3D_effect.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from romatch.utils.utils import tensor_to_pil
|
6 |
+
|
7 |
+
from romatch import roma_outdoor
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
if torch.backends.mps.is_available():
|
11 |
+
device = torch.device('mps')
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
from argparse import ArgumentParser
|
15 |
+
parser = ArgumentParser()
|
16 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
17 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
18 |
+
parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
|
19 |
+
|
20 |
+
args, _ = parser.parse_known_args()
|
21 |
+
im1_path = args.im_A_path
|
22 |
+
im2_path = args.im_B_path
|
23 |
+
save_path = args.save_path
|
24 |
+
|
25 |
+
# Create model
|
26 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
27 |
+
roma_model.symmetric = False
|
28 |
+
|
29 |
+
H, W = roma_model.get_output_resolution()
|
30 |
+
|
31 |
+
im1 = Image.open(im1_path).resize((W, H))
|
32 |
+
im2 = Image.open(im2_path).resize((W, H))
|
33 |
+
|
34 |
+
# Match
|
35 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
36 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
37 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
38 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
|
40 |
+
coords_A, coords_B = warp[...,:2], warp[...,2:]
|
41 |
+
for i, x in enumerate(np.linspace(0,2*np.pi,200)):
|
42 |
+
t = (1 + np.cos(x))/2
|
43 |
+
interp_warp = (1-t)*coords_A + t*coords_B
|
44 |
+
im2_transfer_rgb = F.grid_sample(
|
45 |
+
x2[None], interp_warp[None], mode="bilinear", align_corners=False
|
46 |
+
)[0]
|
47 |
+
tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
|
submodules/RoMa/demo/demo_fundamental.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from romatch import roma_outdoor
|
5 |
+
|
6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
+
if torch.backends.mps.is_available():
|
8 |
+
device = torch.device('mps')
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
14 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
15 |
+
|
16 |
+
args, _ = parser.parse_known_args()
|
17 |
+
im1_path = args.im_A_path
|
18 |
+
im2_path = args.im_B_path
|
19 |
+
|
20 |
+
# Create model
|
21 |
+
roma_model = roma_outdoor(device=device)
|
22 |
+
|
23 |
+
|
24 |
+
W_A, H_A = Image.open(im1_path).size
|
25 |
+
W_B, H_B = Image.open(im2_path).size
|
26 |
+
|
27 |
+
# Match
|
28 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
29 |
+
# Sample matches for estimation
|
30 |
+
matches, certainty = roma_model.sample(warp, certainty)
|
31 |
+
kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
32 |
+
F, mask = cv2.findFundamentalMat(
|
33 |
+
kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
34 |
+
)
|
submodules/RoMa/demo/demo_match.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from romatch.utils.utils import tensor_to_pil
|
8 |
+
|
9 |
+
from romatch import roma_outdoor
|
10 |
+
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
if torch.backends.mps.is_available():
|
13 |
+
device = torch.device('mps')
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
parser = ArgumentParser()
|
18 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
19 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
20 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
21 |
+
|
22 |
+
args, _ = parser.parse_known_args()
|
23 |
+
im1_path = args.im_A_path
|
24 |
+
im2_path = args.im_B_path
|
25 |
+
save_path = args.save_path
|
26 |
+
|
27 |
+
# Create model
|
28 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
29 |
+
|
30 |
+
H, W = roma_model.get_output_resolution()
|
31 |
+
|
32 |
+
im1 = Image.open(im1_path).resize((W, H))
|
33 |
+
im2 = Image.open(im2_path).resize((W, H))
|
34 |
+
|
35 |
+
# Match
|
36 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
37 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
38 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
40 |
+
|
41 |
+
im2_transfer_rgb = F.grid_sample(
|
42 |
+
x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
43 |
+
)[0]
|
44 |
+
im1_transfer_rgb = F.grid_sample(
|
45 |
+
x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
46 |
+
)[0]
|
47 |
+
warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
|
48 |
+
white_im = torch.ones((H,2*W),device=device)
|
49 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
50 |
+
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
submodules/RoMa/demo/demo_match_opencv_sift.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2 as cv
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
14 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
15 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
16 |
+
|
17 |
+
args, _ = parser.parse_known_args()
|
18 |
+
im1_path = args.im_A_path
|
19 |
+
im2_path = args.im_B_path
|
20 |
+
save_path = args.save_path
|
21 |
+
|
22 |
+
img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
|
23 |
+
img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
|
24 |
+
# Initiate SIFT detector
|
25 |
+
sift = cv.SIFT_create()
|
26 |
+
# find the keypoints and descriptors with SIFT
|
27 |
+
kp1, des1 = sift.detectAndCompute(img1,None)
|
28 |
+
kp2, des2 = sift.detectAndCompute(img2,None)
|
29 |
+
# BFMatcher with default params
|
30 |
+
bf = cv.BFMatcher()
|
31 |
+
matches = bf.knnMatch(des1,des2,k=2)
|
32 |
+
# Apply ratio test
|
33 |
+
good = []
|
34 |
+
for m,n in matches:
|
35 |
+
if m.distance < 0.75*n.distance:
|
36 |
+
good.append([m])
|
37 |
+
# cv.drawMatchesKnn expects list of lists as matches.
|
38 |
+
draw_params = dict(matchColor = (255,0,0), # draw matches in red color
|
39 |
+
singlePointColor = None,
|
40 |
+
flags = 2)
|
41 |
+
|
42 |
+
img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
|
43 |
+
Image.fromarray(img3).save("demo/sift_matches.png")
|
submodules/RoMa/demo/demo_match_tiny.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from romatch.utils.utils import tensor_to_pil
|
8 |
+
|
9 |
+
from romatch import tiny_roma_v1_outdoor
|
10 |
+
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
if torch.backends.mps.is_available():
|
13 |
+
device = torch.device('mps')
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
parser = ArgumentParser()
|
18 |
+
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
19 |
+
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
20 |
+
parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
|
21 |
+
parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
|
22 |
+
|
23 |
+
args, _ = parser.parse_known_args()
|
24 |
+
im1_path = args.im_A_path
|
25 |
+
im2_path = args.im_B_path
|
26 |
+
|
27 |
+
# Create model
|
28 |
+
roma_model = tiny_roma_v1_outdoor(device=device)
|
29 |
+
|
30 |
+
# Match
|
31 |
+
warp, certainty1 = roma_model.match(im1_path, im2_path)
|
32 |
+
|
33 |
+
h1, w1 = warp.shape[:2]
|
34 |
+
|
35 |
+
# maybe im1.size != im2.size
|
36 |
+
im1 = Image.open(im1_path).resize((w1, h1))
|
37 |
+
im2 = Image.open(im2_path)
|
38 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
39 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
40 |
+
|
41 |
+
h2, w2 = x2.shape[1:]
|
42 |
+
g1_p2x = w2 / 2 * (warp[..., 2] + 1)
|
43 |
+
g1_p2y = h2 / 2 * (warp[..., 3] + 1)
|
44 |
+
g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
|
45 |
+
g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
|
46 |
+
|
47 |
+
x, y = torch.meshgrid(
|
48 |
+
torch.arange(w1, device=device),
|
49 |
+
torch.arange(h1, device=device),
|
50 |
+
indexing="xy",
|
51 |
+
)
|
52 |
+
g2x = torch.round(g1_p2x[y, x]).long()
|
53 |
+
g2y = torch.round(g1_p2y[y, x]).long()
|
54 |
+
idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
|
55 |
+
idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
|
56 |
+
idx = torch.bitwise_and(idx_x, idx_y)
|
57 |
+
g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
|
58 |
+
g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
|
59 |
+
|
60 |
+
certainty2 = F.grid_sample(
|
61 |
+
certainty1[None][None],
|
62 |
+
torch.stack([g2_p1x, g2_p1y], dim=2)[None],
|
63 |
+
mode="bilinear",
|
64 |
+
align_corners=False,
|
65 |
+
)[0]
|
66 |
+
|
67 |
+
white_im1 = torch.ones((h1, w1), device = device)
|
68 |
+
white_im2 = torch.ones((h2, w2), device = device)
|
69 |
+
|
70 |
+
certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
|
71 |
+
certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
|
72 |
+
|
73 |
+
vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
|
74 |
+
vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
|
75 |
+
|
76 |
+
tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
|
77 |
+
tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
|
submodules/RoMa/demo/gif/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
submodules/RoMa/experiments/eval_roma_outdoor.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from romatch.benchmarks import MegadepthDenseBenchmark
|
4 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
|
5 |
+
from romatch.benchmarks import Mega1500PoseLibBenchmark
|
6 |
+
|
7 |
+
def test_mega_8_scenes(model, name):
|
8 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
9 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
10 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
11 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
12 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
13 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
14 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
15 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
16 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
17 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
18 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
19 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
20 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
21 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
22 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
23 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
24 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
25 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
26 |
+
print(mega_8_scenes_results)
|
27 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
28 |
+
|
29 |
+
def test_mega1500(model, name):
|
30 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
31 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
32 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
33 |
+
|
34 |
+
def test_mega1500_poselib(model, name):
|
35 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
|
36 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
37 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
38 |
+
|
39 |
+
def test_mega_dense(model, name):
|
40 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
|
41 |
+
megadense_results = megadense_benchmark.benchmark(model)
|
42 |
+
json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
|
43 |
+
|
44 |
+
def test_hpatches(model, name):
|
45 |
+
hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
|
46 |
+
hpatches_results = hpatches_benchmark.benchmark(model)
|
47 |
+
json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
from romatch import roma_outdoor
|
52 |
+
device = "cuda"
|
53 |
+
model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
|
54 |
+
experiment_name = "roma_latest"
|
55 |
+
test_mega1500(model, experiment_name)
|
56 |
+
#test_mega1500_poselib(model, experiment_name)
|
57 |
+
|
submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import json
|
5 |
+
from romatch.benchmarks import ScanNetBenchmark
|
6 |
+
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
|
7 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
|
8 |
+
|
9 |
+
def test_mega_8_scenes(model, name):
|
10 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
11 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
12 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
13 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
14 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
15 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
16 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
17 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
18 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
19 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
20 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
21 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
22 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
23 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
24 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
25 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
26 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
27 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
28 |
+
print(mega_8_scenes_results)
|
29 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
30 |
+
|
31 |
+
def test_mega1500(model, name):
|
32 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
33 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
34 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
35 |
+
|
36 |
+
def test_mega1500_poselib(model, name):
|
37 |
+
#model.exact_softmax = True
|
38 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
|
39 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
40 |
+
json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
|
41 |
+
|
42 |
+
def test_mega_8_scenes_poselib(model, name):
|
43 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
|
44 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
45 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
46 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
47 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
48 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
49 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
50 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
51 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
52 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
53 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
54 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
55 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
56 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
57 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
58 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
59 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
60 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
61 |
+
json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
|
62 |
+
|
63 |
+
def test_scannet_poselib(model, name):
|
64 |
+
scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
|
65 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
66 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
67 |
+
|
68 |
+
def test_scannet(model, name):
|
69 |
+
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
70 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
71 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
75 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
76 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
77 |
+
from romatch import tiny_roma_v1_outdoor
|
78 |
+
|
79 |
+
experiment_name = Path(__file__).stem
|
80 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
81 |
+
model = tiny_roma_v1_outdoor(device)
|
82 |
+
#test_mega1500_poselib(model, experiment_name)
|
83 |
+
test_mega_8_scenes_poselib(model, experiment_name)
|
84 |
+
|
submodules/RoMa/experiments/roma_indoor.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
9 |
+
|
10 |
+
import json
|
11 |
+
import wandb
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from romatch.benchmarks import MegadepthDenseBenchmark
|
15 |
+
from romatch.datasets.megadepth import MegadepthBuilder
|
16 |
+
from romatch.datasets.scannet import ScanNetBuilder
|
17 |
+
from romatch.losses.robust_loss import RobustLosses
|
18 |
+
from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
|
19 |
+
from romatch.train.train import train_k_steps
|
20 |
+
from romatch.models.matcher import *
|
21 |
+
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
|
22 |
+
from romatch.models.encoders import *
|
23 |
+
from romatch.checkpointing import CheckPoint
|
24 |
+
|
25 |
+
resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
|
26 |
+
|
27 |
+
def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
|
28 |
+
gp_dim = 512
|
29 |
+
feat_dim = 512
|
30 |
+
decoder_dim = gp_dim + feat_dim
|
31 |
+
cls_to_coord_res = 64
|
32 |
+
coordinate_decoder = TransformerDecoder(
|
33 |
+
nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
|
34 |
+
decoder_dim,
|
35 |
+
cls_to_coord_res**2 + 1,
|
36 |
+
is_classifier=True,
|
37 |
+
amp = True,
|
38 |
+
pos_enc = False,)
|
39 |
+
dw = True
|
40 |
+
hidden_blocks = 8
|
41 |
+
kernel_size = 5
|
42 |
+
displacement_emb = "linear"
|
43 |
+
disable_local_corr_grad = True
|
44 |
+
|
45 |
+
conv_refiner = nn.ModuleDict(
|
46 |
+
{
|
47 |
+
"16": ConvRefiner(
|
48 |
+
2 * 512+128+(2*7+1)**2,
|
49 |
+
2 * 512+128+(2*7+1)**2,
|
50 |
+
2 + 1,
|
51 |
+
kernel_size=kernel_size,
|
52 |
+
dw=dw,
|
53 |
+
hidden_blocks=hidden_blocks,
|
54 |
+
displacement_emb=displacement_emb,
|
55 |
+
displacement_emb_dim=128,
|
56 |
+
local_corr_radius = 7,
|
57 |
+
corr_in_other = True,
|
58 |
+
amp = True,
|
59 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
60 |
+
bn_momentum = 0.01,
|
61 |
+
),
|
62 |
+
"8": ConvRefiner(
|
63 |
+
2 * 512+64+(2*3+1)**2,
|
64 |
+
2 * 512+64+(2*3+1)**2,
|
65 |
+
2 + 1,
|
66 |
+
kernel_size=kernel_size,
|
67 |
+
dw=dw,
|
68 |
+
hidden_blocks=hidden_blocks,
|
69 |
+
displacement_emb=displacement_emb,
|
70 |
+
displacement_emb_dim=64,
|
71 |
+
local_corr_radius = 3,
|
72 |
+
corr_in_other = True,
|
73 |
+
amp = True,
|
74 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
75 |
+
bn_momentum = 0.01,
|
76 |
+
),
|
77 |
+
"4": ConvRefiner(
|
78 |
+
2 * 256+32+(2*2+1)**2,
|
79 |
+
2 * 256+32+(2*2+1)**2,
|
80 |
+
2 + 1,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
dw=dw,
|
83 |
+
hidden_blocks=hidden_blocks,
|
84 |
+
displacement_emb=displacement_emb,
|
85 |
+
displacement_emb_dim=32,
|
86 |
+
local_corr_radius = 2,
|
87 |
+
corr_in_other = True,
|
88 |
+
amp = True,
|
89 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
90 |
+
bn_momentum = 0.01,
|
91 |
+
),
|
92 |
+
"2": ConvRefiner(
|
93 |
+
2 * 64+16,
|
94 |
+
128+16,
|
95 |
+
2 + 1,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
dw=dw,
|
98 |
+
hidden_blocks=hidden_blocks,
|
99 |
+
displacement_emb=displacement_emb,
|
100 |
+
displacement_emb_dim=16,
|
101 |
+
amp = True,
|
102 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
103 |
+
bn_momentum = 0.01,
|
104 |
+
),
|
105 |
+
"1": ConvRefiner(
|
106 |
+
2 * 9 + 6,
|
107 |
+
24,
|
108 |
+
2 + 1,
|
109 |
+
kernel_size=kernel_size,
|
110 |
+
dw=dw,
|
111 |
+
hidden_blocks = hidden_blocks,
|
112 |
+
displacement_emb = displacement_emb,
|
113 |
+
displacement_emb_dim = 6,
|
114 |
+
amp = True,
|
115 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
116 |
+
bn_momentum = 0.01,
|
117 |
+
),
|
118 |
+
}
|
119 |
+
)
|
120 |
+
kernel_temperature = 0.2
|
121 |
+
learn_temperature = False
|
122 |
+
no_cov = True
|
123 |
+
kernel = CosKernel
|
124 |
+
only_attention = False
|
125 |
+
basis = "fourier"
|
126 |
+
gp16 = GP(
|
127 |
+
kernel,
|
128 |
+
T=kernel_temperature,
|
129 |
+
learn_temperature=learn_temperature,
|
130 |
+
only_attention=only_attention,
|
131 |
+
gp_dim=gp_dim,
|
132 |
+
basis=basis,
|
133 |
+
no_cov=no_cov,
|
134 |
+
)
|
135 |
+
gps = nn.ModuleDict({"16": gp16})
|
136 |
+
proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
|
137 |
+
proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
|
138 |
+
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
139 |
+
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
140 |
+
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
141 |
+
proj = nn.ModuleDict({
|
142 |
+
"16": proj16,
|
143 |
+
"8": proj8,
|
144 |
+
"4": proj4,
|
145 |
+
"2": proj2,
|
146 |
+
"1": proj1,
|
147 |
+
})
|
148 |
+
displacement_dropout_p = 0.0
|
149 |
+
gm_warp_dropout_p = 0.0
|
150 |
+
decoder = Decoder(coordinate_decoder,
|
151 |
+
gps,
|
152 |
+
proj,
|
153 |
+
conv_refiner,
|
154 |
+
detach=True,
|
155 |
+
scales=["16", "8", "4", "2", "1"],
|
156 |
+
displacement_dropout_p = displacement_dropout_p,
|
157 |
+
gm_warp_dropout_p = gm_warp_dropout_p)
|
158 |
+
h,w = resolutions[resolution]
|
159 |
+
encoder = CNNandDinov2(
|
160 |
+
cnn_kwargs = dict(
|
161 |
+
pretrained=pretrained_backbone,
|
162 |
+
amp = True),
|
163 |
+
amp = True,
|
164 |
+
use_vgg = True,
|
165 |
+
)
|
166 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
|
167 |
+
return matcher
|
168 |
+
|
169 |
+
def train(args):
|
170 |
+
dist.init_process_group('nccl')
|
171 |
+
#torch._dynamo.config.verbose=True
|
172 |
+
gpus = int(os.environ['WORLD_SIZE'])
|
173 |
+
# create model and move it to GPU with id rank
|
174 |
+
rank = dist.get_rank()
|
175 |
+
print(f"Start running DDP on rank {rank}")
|
176 |
+
device_id = rank % torch.cuda.device_count()
|
177 |
+
romatch.LOCAL_RANK = device_id
|
178 |
+
torch.cuda.set_device(device_id)
|
179 |
+
|
180 |
+
resolution = args.train_resolution
|
181 |
+
wandb_log = not args.dont_log_wandb
|
182 |
+
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
183 |
+
wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
|
184 |
+
wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
185 |
+
checkpoint_dir = "workspace/checkpoints/"
|
186 |
+
h,w = resolutions[resolution]
|
187 |
+
model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
|
188 |
+
# Num steps
|
189 |
+
global_step = 0
|
190 |
+
batch_size = args.gpu_batch_size
|
191 |
+
step_size = gpus*batch_size
|
192 |
+
romatch.STEP_SIZE = step_size
|
193 |
+
|
194 |
+
N = (32 * 250000) # 250k steps of batch size 32
|
195 |
+
# checkpoint every
|
196 |
+
k = 25000 // romatch.STEP_SIZE
|
197 |
+
|
198 |
+
# Data
|
199 |
+
mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
|
200 |
+
use_horizontal_flip_aug = True
|
201 |
+
rot_prob = 0
|
202 |
+
depth_interpolation_mode = "bilinear"
|
203 |
+
megadepth_train1 = mega.build_scenes(
|
204 |
+
split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
205 |
+
ht=h,wt=w,
|
206 |
+
)
|
207 |
+
megadepth_train2 = mega.build_scenes(
|
208 |
+
split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
209 |
+
ht=h,wt=w,
|
210 |
+
)
|
211 |
+
megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
|
212 |
+
mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
|
213 |
+
|
214 |
+
scannet = ScanNetBuilder(data_root="data/scannet")
|
215 |
+
scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
|
216 |
+
scannet_train = ConcatDataset(scannet_train)
|
217 |
+
scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
|
218 |
+
|
219 |
+
# Loss and optimizer
|
220 |
+
depth_loss_scannet = RobustLosses(
|
221 |
+
ce_weight=0.0,
|
222 |
+
local_dist={1:4, 2:4, 4:8, 8:8},
|
223 |
+
local_largest_scale=8,
|
224 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
225 |
+
alpha = 0.5,
|
226 |
+
c = 1e-4,)
|
227 |
+
# Loss and optimizer
|
228 |
+
depth_loss_mega = RobustLosses(
|
229 |
+
ce_weight=0.01,
|
230 |
+
local_dist={1:4, 2:4, 4:8, 8:8},
|
231 |
+
local_largest_scale=8,
|
232 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
233 |
+
alpha = 0.5,
|
234 |
+
c = 1e-4,)
|
235 |
+
parameters = [
|
236 |
+
{"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
|
237 |
+
{"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
|
238 |
+
]
|
239 |
+
optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
|
240 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
241 |
+
optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
|
242 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
|
243 |
+
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
|
244 |
+
model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
|
245 |
+
romatch.GLOBAL_STEP = global_step
|
246 |
+
ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
|
247 |
+
grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
|
248 |
+
grad_clip_norm = 0.01
|
249 |
+
for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
|
250 |
+
mega_sampler = torch.utils.data.WeightedRandomSampler(
|
251 |
+
mega_ws, num_samples = batch_size * k, replacement=False
|
252 |
+
)
|
253 |
+
mega_dataloader = iter(
|
254 |
+
torch.utils.data.DataLoader(
|
255 |
+
megadepth_train,
|
256 |
+
batch_size = batch_size,
|
257 |
+
sampler = mega_sampler,
|
258 |
+
num_workers = 8,
|
259 |
+
)
|
260 |
+
)
|
261 |
+
scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
|
262 |
+
scannet_ws, num_samples=batch_size * k, replacement=False
|
263 |
+
)
|
264 |
+
scannet_dataloader = iter(
|
265 |
+
torch.utils.data.DataLoader(
|
266 |
+
scannet_train,
|
267 |
+
batch_size=batch_size,
|
268 |
+
sampler=scannet_ws_sampler,
|
269 |
+
num_workers=gpus * 8,
|
270 |
+
)
|
271 |
+
)
|
272 |
+
for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
|
273 |
+
train_k_steps(
|
274 |
+
n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
275 |
+
)
|
276 |
+
train_k_steps(
|
277 |
+
n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
278 |
+
)
|
279 |
+
checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
|
280 |
+
wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
|
281 |
+
|
282 |
+
def test_scannet(model, name, resolution, sample_mode):
|
283 |
+
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
284 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
285 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
286 |
+
|
287 |
+
if __name__ == "__main__":
|
288 |
+
import warnings
|
289 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
290 |
+
warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
|
291 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
292 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
293 |
+
|
294 |
+
import romatch
|
295 |
+
parser = ArgumentParser()
|
296 |
+
parser.add_argument("--test", action='store_true')
|
297 |
+
parser.add_argument("--debug_mode", action='store_true')
|
298 |
+
parser.add_argument("--dont_log_wandb", action='store_true')
|
299 |
+
parser.add_argument("--train_resolution", default='medium')
|
300 |
+
parser.add_argument("--gpu_batch_size", default=4, type=int)
|
301 |
+
parser.add_argument("--wandb_entity", required = False)
|
302 |
+
|
303 |
+
args, _ = parser.parse_known_args()
|
304 |
+
romatch.DEBUG_MODE = args.debug_mode
|
305 |
+
if not args.test:
|
306 |
+
train(args)
|
307 |
+
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
308 |
+
checkpoint_dir = "workspace/"
|
309 |
+
checkpoint_name = checkpoint_dir + experiment_name + ".pth"
|
310 |
+
test_resolution = "medium"
|
311 |
+
sample_mode = "threshold_balanced"
|
312 |
+
symmetric = True
|
313 |
+
upsample_preds = False
|
314 |
+
attenuate_cert = True
|
315 |
+
|
316 |
+
model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
|
317 |
+
model = model.cuda()
|
318 |
+
states = torch.load(checkpoint_name)
|
319 |
+
model.load_state_dict(states["model"])
|
320 |
+
test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
|
submodules/RoMa/experiments/train_roma_outdoor.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
9 |
+
import json
|
10 |
+
import wandb
|
11 |
+
|
12 |
+
from romatch.benchmarks import MegadepthDenseBenchmark
|
13 |
+
from romatch.datasets.megadepth import MegadepthBuilder
|
14 |
+
from romatch.losses.robust_loss import RobustLosses
|
15 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
|
16 |
+
|
17 |
+
from romatch.train.train import train_k_steps
|
18 |
+
from romatch.models.matcher import *
|
19 |
+
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
|
20 |
+
from romatch.models.encoders import *
|
21 |
+
from romatch.checkpointing import CheckPoint
|
22 |
+
|
23 |
+
resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
|
24 |
+
|
25 |
+
def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
|
26 |
+
import warnings
|
27 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
28 |
+
gp_dim = 512
|
29 |
+
feat_dim = 512
|
30 |
+
decoder_dim = gp_dim + feat_dim
|
31 |
+
cls_to_coord_res = 64
|
32 |
+
coordinate_decoder = TransformerDecoder(
|
33 |
+
nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
|
34 |
+
decoder_dim,
|
35 |
+
cls_to_coord_res**2 + 1,
|
36 |
+
is_classifier=True,
|
37 |
+
amp = True,
|
38 |
+
pos_enc = False,)
|
39 |
+
dw = True
|
40 |
+
hidden_blocks = 8
|
41 |
+
kernel_size = 5
|
42 |
+
displacement_emb = "linear"
|
43 |
+
disable_local_corr_grad = True
|
44 |
+
|
45 |
+
conv_refiner = nn.ModuleDict(
|
46 |
+
{
|
47 |
+
"16": ConvRefiner(
|
48 |
+
2 * 512+128+(2*7+1)**2,
|
49 |
+
2 * 512+128+(2*7+1)**2,
|
50 |
+
2 + 1,
|
51 |
+
kernel_size=kernel_size,
|
52 |
+
dw=dw,
|
53 |
+
hidden_blocks=hidden_blocks,
|
54 |
+
displacement_emb=displacement_emb,
|
55 |
+
displacement_emb_dim=128,
|
56 |
+
local_corr_radius = 7,
|
57 |
+
corr_in_other = True,
|
58 |
+
amp = True,
|
59 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
60 |
+
bn_momentum = 0.01,
|
61 |
+
),
|
62 |
+
"8": ConvRefiner(
|
63 |
+
2 * 512+64+(2*3+1)**2,
|
64 |
+
2 * 512+64+(2*3+1)**2,
|
65 |
+
2 + 1,
|
66 |
+
kernel_size=kernel_size,
|
67 |
+
dw=dw,
|
68 |
+
hidden_blocks=hidden_blocks,
|
69 |
+
displacement_emb=displacement_emb,
|
70 |
+
displacement_emb_dim=64,
|
71 |
+
local_corr_radius = 3,
|
72 |
+
corr_in_other = True,
|
73 |
+
amp = True,
|
74 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
75 |
+
bn_momentum = 0.01,
|
76 |
+
),
|
77 |
+
"4": ConvRefiner(
|
78 |
+
2 * 256+32+(2*2+1)**2,
|
79 |
+
2 * 256+32+(2*2+1)**2,
|
80 |
+
2 + 1,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
dw=dw,
|
83 |
+
hidden_blocks=hidden_blocks,
|
84 |
+
displacement_emb=displacement_emb,
|
85 |
+
displacement_emb_dim=32,
|
86 |
+
local_corr_radius = 2,
|
87 |
+
corr_in_other = True,
|
88 |
+
amp = True,
|
89 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
90 |
+
bn_momentum = 0.01,
|
91 |
+
),
|
92 |
+
"2": ConvRefiner(
|
93 |
+
2 * 64+16,
|
94 |
+
128+16,
|
95 |
+
2 + 1,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
dw=dw,
|
98 |
+
hidden_blocks=hidden_blocks,
|
99 |
+
displacement_emb=displacement_emb,
|
100 |
+
displacement_emb_dim=16,
|
101 |
+
amp = True,
|
102 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
103 |
+
bn_momentum = 0.01,
|
104 |
+
),
|
105 |
+
"1": ConvRefiner(
|
106 |
+
2 * 9 + 6,
|
107 |
+
24,
|
108 |
+
2 + 1,
|
109 |
+
kernel_size=kernel_size,
|
110 |
+
dw=dw,
|
111 |
+
hidden_blocks = hidden_blocks,
|
112 |
+
displacement_emb = displacement_emb,
|
113 |
+
displacement_emb_dim = 6,
|
114 |
+
amp = True,
|
115 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
116 |
+
bn_momentum = 0.01,
|
117 |
+
),
|
118 |
+
}
|
119 |
+
)
|
120 |
+
kernel_temperature = 0.2
|
121 |
+
learn_temperature = False
|
122 |
+
no_cov = True
|
123 |
+
kernel = CosKernel
|
124 |
+
only_attention = False
|
125 |
+
basis = "fourier"
|
126 |
+
gp16 = GP(
|
127 |
+
kernel,
|
128 |
+
T=kernel_temperature,
|
129 |
+
learn_temperature=learn_temperature,
|
130 |
+
only_attention=only_attention,
|
131 |
+
gp_dim=gp_dim,
|
132 |
+
basis=basis,
|
133 |
+
no_cov=no_cov,
|
134 |
+
)
|
135 |
+
gps = nn.ModuleDict({"16": gp16})
|
136 |
+
proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
|
137 |
+
proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
|
138 |
+
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
139 |
+
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
140 |
+
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
141 |
+
proj = nn.ModuleDict({
|
142 |
+
"16": proj16,
|
143 |
+
"8": proj8,
|
144 |
+
"4": proj4,
|
145 |
+
"2": proj2,
|
146 |
+
"1": proj1,
|
147 |
+
})
|
148 |
+
displacement_dropout_p = 0.0
|
149 |
+
gm_warp_dropout_p = 0.0
|
150 |
+
decoder = Decoder(coordinate_decoder,
|
151 |
+
gps,
|
152 |
+
proj,
|
153 |
+
conv_refiner,
|
154 |
+
detach=True,
|
155 |
+
scales=["16", "8", "4", "2", "1"],
|
156 |
+
displacement_dropout_p = displacement_dropout_p,
|
157 |
+
gm_warp_dropout_p = gm_warp_dropout_p)
|
158 |
+
h,w = resolutions[resolution]
|
159 |
+
encoder = CNNandDinov2(
|
160 |
+
cnn_kwargs = dict(
|
161 |
+
pretrained=pretrained_backbone,
|
162 |
+
amp = True),
|
163 |
+
amp = True,
|
164 |
+
use_vgg = True,
|
165 |
+
)
|
166 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
|
167 |
+
return matcher
|
168 |
+
|
169 |
+
def train(args):
|
170 |
+
dist.init_process_group('nccl')
|
171 |
+
#torch._dynamo.config.verbose=True
|
172 |
+
gpus = int(os.environ['WORLD_SIZE'])
|
173 |
+
# create model and move it to GPU with id rank
|
174 |
+
rank = dist.get_rank()
|
175 |
+
print(f"Start running DDP on rank {rank}")
|
176 |
+
device_id = rank % torch.cuda.device_count()
|
177 |
+
romatch.LOCAL_RANK = device_id
|
178 |
+
torch.cuda.set_device(device_id)
|
179 |
+
|
180 |
+
resolution = args.train_resolution
|
181 |
+
wandb_log = not args.dont_log_wandb
|
182 |
+
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
183 |
+
wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
|
184 |
+
wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
185 |
+
checkpoint_dir = "workspace/checkpoints/"
|
186 |
+
h,w = resolutions[resolution]
|
187 |
+
model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
|
188 |
+
# Num steps
|
189 |
+
global_step = 0
|
190 |
+
batch_size = args.gpu_batch_size
|
191 |
+
step_size = gpus*batch_size
|
192 |
+
romatch.STEP_SIZE = step_size
|
193 |
+
|
194 |
+
N = (32 * 250000) # 250k steps of batch size 32
|
195 |
+
# checkpoint every
|
196 |
+
k = 25000 // romatch.STEP_SIZE
|
197 |
+
|
198 |
+
# Data
|
199 |
+
mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
|
200 |
+
use_horizontal_flip_aug = True
|
201 |
+
rot_prob = 0
|
202 |
+
depth_interpolation_mode = "bilinear"
|
203 |
+
megadepth_train1 = mega.build_scenes(
|
204 |
+
split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
205 |
+
ht=h,wt=w,
|
206 |
+
)
|
207 |
+
megadepth_train2 = mega.build_scenes(
|
208 |
+
split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
209 |
+
ht=h,wt=w,
|
210 |
+
)
|
211 |
+
megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
|
212 |
+
mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
|
213 |
+
# Loss and optimizer
|
214 |
+
depth_loss = RobustLosses(
|
215 |
+
ce_weight=0.01,
|
216 |
+
local_dist={1:4, 2:4, 4:8, 8:8},
|
217 |
+
local_largest_scale=8,
|
218 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
219 |
+
alpha = 0.5,
|
220 |
+
c = 1e-4,)
|
221 |
+
parameters = [
|
222 |
+
{"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
|
223 |
+
{"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
|
224 |
+
]
|
225 |
+
optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
|
226 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
227 |
+
optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
|
228 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
|
229 |
+
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
|
230 |
+
model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
|
231 |
+
romatch.GLOBAL_STEP = global_step
|
232 |
+
ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
|
233 |
+
grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
|
234 |
+
grad_clip_norm = 0.01
|
235 |
+
for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
|
236 |
+
mega_sampler = torch.utils.data.WeightedRandomSampler(
|
237 |
+
mega_ws, num_samples = batch_size * k, replacement=False
|
238 |
+
)
|
239 |
+
mega_dataloader = iter(
|
240 |
+
torch.utils.data.DataLoader(
|
241 |
+
megadepth_train,
|
242 |
+
batch_size = batch_size,
|
243 |
+
sampler = mega_sampler,
|
244 |
+
num_workers = 8,
|
245 |
+
)
|
246 |
+
)
|
247 |
+
train_k_steps(
|
248 |
+
n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
|
249 |
+
)
|
250 |
+
checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
|
251 |
+
wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
|
252 |
+
|
253 |
+
def test_mega_8_scenes(model, name):
|
254 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
255 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
256 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
257 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
258 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
259 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
260 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
261 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
262 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
263 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
264 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
265 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
266 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
267 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
268 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
269 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
270 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
271 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
272 |
+
print(mega_8_scenes_results)
|
273 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
274 |
+
|
275 |
+
def test_mega1500(model, name):
|
276 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
277 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
278 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
279 |
+
|
280 |
+
def test_mega_dense(model, name):
|
281 |
+
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
|
282 |
+
megadense_results = megadense_benchmark.benchmark(model)
|
283 |
+
json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
|
284 |
+
|
285 |
+
def test_hpatches(model, name):
|
286 |
+
hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
|
287 |
+
hpatches_results = hpatches_benchmark.benchmark(model)
|
288 |
+
json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
|
289 |
+
|
290 |
+
|
291 |
+
if __name__ == "__main__":
|
292 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
293 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
294 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
295 |
+
import romatch
|
296 |
+
parser = ArgumentParser()
|
297 |
+
parser.add_argument("--only_test", action='store_true')
|
298 |
+
parser.add_argument("--debug_mode", action='store_true')
|
299 |
+
parser.add_argument("--dont_log_wandb", action='store_true')
|
300 |
+
parser.add_argument("--train_resolution", default='medium')
|
301 |
+
parser.add_argument("--gpu_batch_size", default=8, type=int)
|
302 |
+
parser.add_argument("--wandb_entity", required = False)
|
303 |
+
|
304 |
+
args, _ = parser.parse_known_args()
|
305 |
+
romatch.DEBUG_MODE = args.debug_mode
|
306 |
+
if not args.only_test:
|
307 |
+
train(args)
|
submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pathlib import Path
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
from torch.utils.data import ConcatDataset
|
13 |
+
import torch.distributed as dist
|
14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
15 |
+
import json
|
16 |
+
import wandb
|
17 |
+
from PIL import Image
|
18 |
+
from torchvision.transforms import ToTensor
|
19 |
+
|
20 |
+
from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
|
21 |
+
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
|
22 |
+
from romatch.datasets.megadepth import MegadepthBuilder
|
23 |
+
from romatch.losses.robust_loss_tiny_roma import RobustLosses
|
24 |
+
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
|
25 |
+
from romatch.train.train import train_k_steps
|
26 |
+
from romatch.checkpointing import CheckPoint
|
27 |
+
|
28 |
+
resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
|
29 |
+
|
30 |
+
def kde(x, std = 0.1):
|
31 |
+
# use a gaussian kernel to estimate density
|
32 |
+
x = x.half() # Do it in half precision TODO: remove hardcoding
|
33 |
+
scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
|
34 |
+
density = scores.sum(dim=-1)
|
35 |
+
return density
|
36 |
+
|
37 |
+
class BasicLayer(nn.Module):
|
38 |
+
"""
|
39 |
+
Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
|
40 |
+
"""
|
41 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
|
42 |
+
super().__init__()
|
43 |
+
self.layer = nn.Sequential(
|
44 |
+
nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
|
45 |
+
nn.BatchNorm2d(out_channels, affine=False),
|
46 |
+
nn.ReLU(inplace = True) if relu else nn.Identity()
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return self.layer(x)
|
51 |
+
|
52 |
+
class XFeatModel(nn.Module):
|
53 |
+
"""
|
54 |
+
Implementation of architecture described in
|
55 |
+
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, xfeat = None,
|
59 |
+
freeze_xfeat = True,
|
60 |
+
sample_mode = "threshold_balanced",
|
61 |
+
symmetric = False,
|
62 |
+
exact_softmax = False):
|
63 |
+
super().__init__()
|
64 |
+
if xfeat is None:
|
65 |
+
xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
|
66 |
+
del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
|
67 |
+
if freeze_xfeat:
|
68 |
+
xfeat.train(False)
|
69 |
+
self.xfeat = [xfeat]# hide params from ddp
|
70 |
+
else:
|
71 |
+
self.xfeat = nn.ModuleList([xfeat])
|
72 |
+
self.freeze_xfeat = freeze_xfeat
|
73 |
+
match_dim = 256
|
74 |
+
self.coarse_matcher = nn.Sequential(
|
75 |
+
BasicLayer(64+64+2, match_dim,),
|
76 |
+
BasicLayer(match_dim, match_dim,),
|
77 |
+
BasicLayer(match_dim, match_dim,),
|
78 |
+
BasicLayer(match_dim, match_dim,),
|
79 |
+
nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
|
80 |
+
fine_match_dim = 64
|
81 |
+
self.fine_matcher = nn.Sequential(
|
82 |
+
BasicLayer(24+24+2, fine_match_dim,),
|
83 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
84 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
85 |
+
BasicLayer(fine_match_dim, fine_match_dim,),
|
86 |
+
nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
|
87 |
+
self.sample_mode = sample_mode
|
88 |
+
self.sample_thresh = 0.2
|
89 |
+
self.symmetric = symmetric
|
90 |
+
self.exact_softmax = exact_softmax
|
91 |
+
|
92 |
+
@property
|
93 |
+
def device(self):
|
94 |
+
return self.fine_matcher[-1].weight.device
|
95 |
+
|
96 |
+
def preprocess_tensor(self, x):
|
97 |
+
""" Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
|
98 |
+
H, W = x.shape[-2:]
|
99 |
+
_H, _W = (H//32) * 32, (W//32) * 32
|
100 |
+
rh, rw = H/_H, W/_W
|
101 |
+
|
102 |
+
x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
|
103 |
+
return x, rh, rw
|
104 |
+
|
105 |
+
def forward_single(self, x):
|
106 |
+
with torch.inference_mode(self.freeze_xfeat or not self.training):
|
107 |
+
xfeat = self.xfeat[0]
|
108 |
+
with torch.no_grad():
|
109 |
+
x = x.mean(dim=1, keepdim = True)
|
110 |
+
x = xfeat.norm(x)
|
111 |
+
|
112 |
+
#main backbone
|
113 |
+
x1 = xfeat.block1(x)
|
114 |
+
x2 = xfeat.block2(x1 + xfeat.skip1(x))
|
115 |
+
x3 = xfeat.block3(x2)
|
116 |
+
x4 = xfeat.block4(x3)
|
117 |
+
x5 = xfeat.block5(x4)
|
118 |
+
x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
|
119 |
+
x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
|
120 |
+
feats = xfeat.block_fusion( x3 + x4 + x5 )
|
121 |
+
if self.freeze_xfeat:
|
122 |
+
return x2.clone(), feats.clone()
|
123 |
+
return x2, feats
|
124 |
+
|
125 |
+
def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
|
126 |
+
if coords.shape[-1] == 2:
|
127 |
+
return self._to_pixel_coordinates(coords, H_A, W_A)
|
128 |
+
|
129 |
+
if isinstance(coords, (list, tuple)):
|
130 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
131 |
+
else:
|
132 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
133 |
+
return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
|
134 |
+
|
135 |
+
def _to_pixel_coordinates(self, coords, H, W):
|
136 |
+
kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
|
137 |
+
return kpts
|
138 |
+
|
139 |
+
def pos_embed(self, corr_volume: torch.Tensor):
|
140 |
+
B, H1, W1, H0, W0 = corr_volume.shape
|
141 |
+
grid = torch.stack(
|
142 |
+
torch.meshgrid(
|
143 |
+
torch.linspace(-1+1/W1,1-1/W1, W1),
|
144 |
+
torch.linspace(-1+1/H1,1-1/H1, H1),
|
145 |
+
indexing = "xy"),
|
146 |
+
dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
|
147 |
+
down = 4
|
148 |
+
if not self.training and not self.exact_softmax:
|
149 |
+
grid_lr = torch.stack(
|
150 |
+
torch.meshgrid(
|
151 |
+
torch.linspace(-1+down/W1,1-down/W1, W1//down),
|
152 |
+
torch.linspace(-1+down/H1,1-down/H1, H1//down),
|
153 |
+
indexing = "xy"),
|
154 |
+
dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
|
155 |
+
cv = corr_volume
|
156 |
+
best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
|
157 |
+
P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
|
158 |
+
pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
|
159 |
+
pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
|
160 |
+
else:
|
161 |
+
P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
|
162 |
+
pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
|
163 |
+
return pos_embeddings
|
164 |
+
|
165 |
+
def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
|
166 |
+
im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
|
167 |
+
device = warp.device
|
168 |
+
H,W2,_ = warp.shape
|
169 |
+
W = W2//2 if symmetric else W2
|
170 |
+
if im_A is None:
|
171 |
+
from PIL import Image
|
172 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
173 |
+
if not isinstance(im_A, torch.Tensor):
|
174 |
+
im_A = im_A.resize((W,H))
|
175 |
+
im_B = im_B.resize((W,H))
|
176 |
+
x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
|
177 |
+
if symmetric:
|
178 |
+
x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
|
179 |
+
else:
|
180 |
+
if symmetric:
|
181 |
+
x_A = im_A
|
182 |
+
x_B = im_B
|
183 |
+
im_A_transfer_rgb = F.grid_sample(
|
184 |
+
x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
185 |
+
)[0]
|
186 |
+
if symmetric:
|
187 |
+
im_B_transfer_rgb = F.grid_sample(
|
188 |
+
x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
189 |
+
)[0]
|
190 |
+
warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
|
191 |
+
white_im = torch.ones((H,2*W),device=device)
|
192 |
+
else:
|
193 |
+
warp_im = im_A_transfer_rgb
|
194 |
+
white_im = torch.ones((H, W), device = device)
|
195 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
196 |
+
if save_path is not None:
|
197 |
+
from romatch.utils import tensor_to_pil
|
198 |
+
tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
|
199 |
+
return vis_im
|
200 |
+
|
201 |
+
def corr_volume(self, feat0, feat1):
|
202 |
+
"""
|
203 |
+
input:
|
204 |
+
feat0 -> torch.Tensor(B, C, H, W)
|
205 |
+
feat1 -> torch.Tensor(B, C, H, W)
|
206 |
+
return:
|
207 |
+
corr_volume -> torch.Tensor(B, H, W, H, W)
|
208 |
+
"""
|
209 |
+
B, C, H0, W0 = feat0.shape
|
210 |
+
B, C, H1, W1 = feat1.shape
|
211 |
+
feat0 = feat0.view(B, C, H0*W0)
|
212 |
+
feat1 = feat1.view(B, C, H1*W1)
|
213 |
+
corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
|
214 |
+
return corr_volume
|
215 |
+
|
216 |
+
@torch.inference_mode()
|
217 |
+
def match_from_path(self, im0_path, im1_path):
|
218 |
+
device = self.device
|
219 |
+
im0 = ToTensor()(Image.open(im0_path))[None].to(device)
|
220 |
+
im1 = ToTensor()(Image.open(im1_path))[None].to(device)
|
221 |
+
return self.match(im0, im1, batched = False)
|
222 |
+
|
223 |
+
@torch.inference_mode()
|
224 |
+
def match(self, im0, im1, *args, batched = True):
|
225 |
+
# stupid
|
226 |
+
if isinstance(im0, (str, Path)):
|
227 |
+
return self.match_from_path(im0, im1)
|
228 |
+
elif isinstance(im0, Image.Image):
|
229 |
+
batched = False
|
230 |
+
device = self.device
|
231 |
+
im0 = ToTensor()(im0)[None].to(device)
|
232 |
+
im1 = ToTensor()(im1)[None].to(device)
|
233 |
+
|
234 |
+
B,C,H0,W0 = im0.shape
|
235 |
+
B,C,H1,W1 = im1.shape
|
236 |
+
self.train(False)
|
237 |
+
corresps = self.forward({"im_A":im0, "im_B":im1})
|
238 |
+
#return 1,1
|
239 |
+
flow = F.interpolate(
|
240 |
+
corresps[4]["flow"],
|
241 |
+
size = (H0, W0),
|
242 |
+
mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
|
243 |
+
grid = torch.stack(
|
244 |
+
torch.meshgrid(
|
245 |
+
torch.linspace(-1+1/W0,1-1/W0, W0),
|
246 |
+
torch.linspace(-1+1/H0,1-1/H0, H0),
|
247 |
+
indexing = "xy"),
|
248 |
+
dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
|
249 |
+
|
250 |
+
certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
|
251 |
+
warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
|
252 |
+
if batched:
|
253 |
+
return warp, cert
|
254 |
+
else:
|
255 |
+
return warp[0], cert[0]
|
256 |
+
|
257 |
+
def sample(
|
258 |
+
self,
|
259 |
+
matches,
|
260 |
+
certainty,
|
261 |
+
num=10000,
|
262 |
+
):
|
263 |
+
if "threshold" in self.sample_mode:
|
264 |
+
upper_thresh = self.sample_thresh
|
265 |
+
certainty = certainty.clone()
|
266 |
+
certainty[certainty > upper_thresh] = 1
|
267 |
+
matches, certainty = (
|
268 |
+
matches.reshape(-1, 4),
|
269 |
+
certainty.reshape(-1),
|
270 |
+
)
|
271 |
+
expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
272 |
+
good_samples = torch.multinomial(certainty,
|
273 |
+
num_samples = min(expansion_factor*num, len(certainty)),
|
274 |
+
replacement=False)
|
275 |
+
good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
276 |
+
if "balanced" not in self.sample_mode:
|
277 |
+
return good_matches, good_certainty
|
278 |
+
density = kde(good_matches, std=0.1)
|
279 |
+
p = 1 / (density+1)
|
280 |
+
p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
281 |
+
balanced_samples = torch.multinomial(p,
|
282 |
+
num_samples = min(num,len(good_certainty)),
|
283 |
+
replacement=False)
|
284 |
+
return good_matches[balanced_samples], good_certainty[balanced_samples]
|
285 |
+
|
286 |
+
def forward(self, batch):
|
287 |
+
"""
|
288 |
+
input:
|
289 |
+
x -> torch.Tensor(B, C, H, W) grayscale or rgb images
|
290 |
+
return:
|
291 |
+
|
292 |
+
"""
|
293 |
+
im0 = batch["im_A"]
|
294 |
+
im1 = batch["im_B"]
|
295 |
+
corresps = {}
|
296 |
+
im0, rh0, rw0 = self.preprocess_tensor(im0)
|
297 |
+
im1, rh1, rw1 = self.preprocess_tensor(im1)
|
298 |
+
B, C, H0, W0 = im0.shape
|
299 |
+
B, C, H1, W1 = im1.shape
|
300 |
+
to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
|
301 |
+
|
302 |
+
if im0.shape[-2:] == im1.shape[-2:]:
|
303 |
+
x = torch.cat([im0, im1], dim=0)
|
304 |
+
x = self.forward_single(x)
|
305 |
+
feats_x0_c, feats_x1_c = x[1].chunk(2)
|
306 |
+
feats_x0_f, feats_x1_f = x[0].chunk(2)
|
307 |
+
else:
|
308 |
+
feats_x0_f, feats_x0_c = self.forward_single(im0)
|
309 |
+
feats_x1_f, feats_x1_c = self.forward_single(im1)
|
310 |
+
corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
|
311 |
+
coarse_warp = self.pos_embed(corr_volume)
|
312 |
+
coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
|
313 |
+
feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
|
314 |
+
coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
|
315 |
+
coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
|
316 |
+
corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
|
317 |
+
coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
|
318 |
+
coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
|
319 |
+
feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
|
320 |
+
fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
|
321 |
+
fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
|
322 |
+
corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
|
323 |
+
return corresps
|
324 |
+
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
|
329 |
+
def train(args):
|
330 |
+
rank = 0
|
331 |
+
gpus = 1
|
332 |
+
device_id = rank % torch.cuda.device_count()
|
333 |
+
romatch.LOCAL_RANK = 0
|
334 |
+
torch.cuda.set_device(device_id)
|
335 |
+
|
336 |
+
resolution = "big"
|
337 |
+
wandb_log = not args.dont_log_wandb
|
338 |
+
experiment_name = Path(__file__).stem
|
339 |
+
wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
|
340 |
+
wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
341 |
+
checkpoint_dir = "workspace/checkpoints/"
|
342 |
+
h,w = resolutions[resolution]
|
343 |
+
model = XFeatModel(freeze_xfeat = False).to(device_id)
|
344 |
+
# Num steps
|
345 |
+
global_step = 0
|
346 |
+
batch_size = args.gpu_batch_size
|
347 |
+
step_size = gpus*batch_size
|
348 |
+
romatch.STEP_SIZE = step_size
|
349 |
+
|
350 |
+
N = 2_000_000 # 2M pairs
|
351 |
+
# checkpoint every
|
352 |
+
k = 25000 // romatch.STEP_SIZE
|
353 |
+
|
354 |
+
# Data
|
355 |
+
mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
|
356 |
+
use_horizontal_flip_aug = True
|
357 |
+
normalize = False # don't imgnet normalize
|
358 |
+
rot_prob = 0
|
359 |
+
depth_interpolation_mode = "bilinear"
|
360 |
+
megadepth_train1 = mega.build_scenes(
|
361 |
+
split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
362 |
+
ht=h,wt=w, normalize = normalize
|
363 |
+
)
|
364 |
+
megadepth_train2 = mega.build_scenes(
|
365 |
+
split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
|
366 |
+
ht=h,wt=w, normalize = normalize
|
367 |
+
)
|
368 |
+
megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
|
369 |
+
mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
|
370 |
+
# Loss and optimizer
|
371 |
+
depth_loss = RobustLosses(
|
372 |
+
ce_weight=0.01,
|
373 |
+
local_dist={4:4},
|
374 |
+
depth_interpolation_mode=depth_interpolation_mode,
|
375 |
+
alpha = {4:0.15, 8:0.15},
|
376 |
+
c = 1e-4,
|
377 |
+
epe_mask_prob_th = 0.001,
|
378 |
+
)
|
379 |
+
parameters = [
|
380 |
+
{"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
|
381 |
+
]
|
382 |
+
optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
|
383 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
384 |
+
optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
|
385 |
+
#megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
|
386 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
|
387 |
+
|
388 |
+
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
|
389 |
+
model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
|
390 |
+
romatch.GLOBAL_STEP = global_step
|
391 |
+
grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
|
392 |
+
grad_clip_norm = 0.01
|
393 |
+
#megadense_benchmark.benchmark(model)
|
394 |
+
for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
|
395 |
+
mega_sampler = torch.utils.data.WeightedRandomSampler(
|
396 |
+
mega_ws, num_samples = batch_size * k, replacement=False
|
397 |
+
)
|
398 |
+
mega_dataloader = iter(
|
399 |
+
torch.utils.data.DataLoader(
|
400 |
+
megadepth_train,
|
401 |
+
batch_size = batch_size,
|
402 |
+
sampler = mega_sampler,
|
403 |
+
num_workers = 8,
|
404 |
+
)
|
405 |
+
)
|
406 |
+
train_k_steps(
|
407 |
+
n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
|
408 |
+
)
|
409 |
+
checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
|
410 |
+
wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
|
411 |
+
|
412 |
+
def test_mega_8_scenes(model, name):
|
413 |
+
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
414 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
415 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
416 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
417 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
418 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
419 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
420 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
421 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
422 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
423 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
424 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
425 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
426 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
427 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
428 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
429 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
430 |
+
mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
431 |
+
print(mega_8_scenes_results)
|
432 |
+
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
433 |
+
|
434 |
+
def test_mega1500(model, name):
|
435 |
+
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
436 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
437 |
+
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
438 |
+
|
439 |
+
def test_mega1500_poselib(model, name):
|
440 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
|
441 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
442 |
+
json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
|
443 |
+
|
444 |
+
def test_mega_8_scenes_poselib(model, name):
|
445 |
+
mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
|
446 |
+
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
447 |
+
'mega_8_scenes_0025_0.1_0.3.npz',
|
448 |
+
'mega_8_scenes_0021_0.1_0.3.npz',
|
449 |
+
'mega_8_scenes_0008_0.1_0.3.npz',
|
450 |
+
'mega_8_scenes_0032_0.1_0.3.npz',
|
451 |
+
'mega_8_scenes_1589_0.1_0.3.npz',
|
452 |
+
'mega_8_scenes_0063_0.1_0.3.npz',
|
453 |
+
'mega_8_scenes_0024_0.1_0.3.npz',
|
454 |
+
'mega_8_scenes_0019_0.3_0.5.npz',
|
455 |
+
'mega_8_scenes_0025_0.3_0.5.npz',
|
456 |
+
'mega_8_scenes_0021_0.3_0.5.npz',
|
457 |
+
'mega_8_scenes_0008_0.3_0.5.npz',
|
458 |
+
'mega_8_scenes_0032_0.3_0.5.npz',
|
459 |
+
'mega_8_scenes_1589_0.3_0.5.npz',
|
460 |
+
'mega_8_scenes_0063_0.3_0.5.npz',
|
461 |
+
'mega_8_scenes_0024_0.3_0.5.npz'])
|
462 |
+
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
463 |
+
json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
|
464 |
+
|
465 |
+
def test_scannet_poselib(model, name):
|
466 |
+
scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
|
467 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
468 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
469 |
+
|
470 |
+
def test_scannet(model, name):
|
471 |
+
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
472 |
+
scannet_results = scannet_benchmark.benchmark(model)
|
473 |
+
json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
|
474 |
+
|
475 |
+
if __name__ == "__main__":
|
476 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
477 |
+
os.environ["OMP_NUM_THREADS"] = "16"
|
478 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
479 |
+
import romatch
|
480 |
+
parser = ArgumentParser()
|
481 |
+
parser.add_argument("--only_test", action='store_true')
|
482 |
+
parser.add_argument("--debug_mode", action='store_true')
|
483 |
+
parser.add_argument("--dont_log_wandb", action='store_true')
|
484 |
+
parser.add_argument("--train_resolution", default='medium')
|
485 |
+
parser.add_argument("--gpu_batch_size", default=8, type=int)
|
486 |
+
parser.add_argument("--wandb_entity", required = False)
|
487 |
+
|
488 |
+
args, _ = parser.parse_known_args()
|
489 |
+
romatch.DEBUG_MODE = args.debug_mode
|
490 |
+
if not args.only_test:
|
491 |
+
train(args)
|
492 |
+
|
493 |
+
experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
|
494 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
495 |
+
model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
|
496 |
+
model.load_state_dict(torch.load(f"{experiment_name}.pth"))
|
497 |
+
test_mega1500_poselib(model, experiment_name)
|
498 |
+
|
submodules/RoMa/requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
einops
|
3 |
+
torchvision
|
4 |
+
opencv-python
|
5 |
+
kornia
|
6 |
+
albumentations
|
7 |
+
loguru
|
8 |
+
tqdm
|
9 |
+
matplotlib
|
10 |
+
h5py
|
11 |
+
wandb
|
12 |
+
timm
|
13 |
+
poselib
|
14 |
+
#xformers # Optional, used for memefficient attention
|
submodules/RoMa/romatch/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
|
3 |
+
|
4 |
+
DEBUG_MODE = False
|
5 |
+
RANK = int(os.environ.get('RANK', default = 0))
|
6 |
+
GLOBAL_STEP = 0
|
7 |
+
STEP_SIZE = 1
|
8 |
+
LOCAL_RANK = -1
|
submodules/RoMa/romatch/benchmarks/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
|
2 |
+
from .scannet_benchmark import ScanNetBenchmark
|
3 |
+
from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
|
4 |
+
from .megadepth_dense_benchmark import MegadepthDenseBenchmark
|
5 |
+
from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
|
6 |
+
#from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
|
submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
from romatch.utils import pose_auc
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
class HpatchesHomogBenchmark:
|
12 |
+
"""Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
|
13 |
+
|
14 |
+
def __init__(self, dataset_path) -> None:
|
15 |
+
seqs_dir = "hpatches-sequences-release"
|
16 |
+
self.seqs_path = os.path.join(dataset_path, seqs_dir)
|
17 |
+
self.seq_names = sorted(os.listdir(self.seqs_path))
|
18 |
+
# Ignore seqs is same as LoFTR.
|
19 |
+
self.ignore_seqs = set(
|
20 |
+
[
|
21 |
+
"i_contruction",
|
22 |
+
"i_crownnight",
|
23 |
+
"i_dc",
|
24 |
+
"i_pencils",
|
25 |
+
"i_whitebuilding",
|
26 |
+
"v_artisans",
|
27 |
+
"v_astronautis",
|
28 |
+
"v_talent",
|
29 |
+
]
|
30 |
+
)
|
31 |
+
|
32 |
+
def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
|
33 |
+
offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
|
34 |
+
im_A_coords = (
|
35 |
+
np.stack(
|
36 |
+
(
|
37 |
+
wq * (im_A_coords[..., 0] + 1) / 2,
|
38 |
+
hq * (im_A_coords[..., 1] + 1) / 2,
|
39 |
+
),
|
40 |
+
axis=-1,
|
41 |
+
)
|
42 |
+
- offset
|
43 |
+
)
|
44 |
+
im_A_to_im_B = (
|
45 |
+
np.stack(
|
46 |
+
(
|
47 |
+
wsup * (im_A_to_im_B[..., 0] + 1) / 2,
|
48 |
+
hsup * (im_A_to_im_B[..., 1] + 1) / 2,
|
49 |
+
),
|
50 |
+
axis=-1,
|
51 |
+
)
|
52 |
+
- offset
|
53 |
+
)
|
54 |
+
return im_A_coords, im_A_to_im_B
|
55 |
+
|
56 |
+
def benchmark(self, model, model_name = None):
|
57 |
+
n_matches = []
|
58 |
+
homog_dists = []
|
59 |
+
for seq_idx, seq_name in tqdm(
|
60 |
+
enumerate(self.seq_names), total=len(self.seq_names)
|
61 |
+
):
|
62 |
+
im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
|
63 |
+
im_A = Image.open(im_A_path)
|
64 |
+
w1, h1 = im_A.size
|
65 |
+
for im_idx in range(2, 7):
|
66 |
+
im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
|
67 |
+
im_B = Image.open(im_B_path)
|
68 |
+
w2, h2 = im_B.size
|
69 |
+
H = np.loadtxt(
|
70 |
+
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
|
71 |
+
)
|
72 |
+
dense_matches, dense_certainty = model.match(
|
73 |
+
im_A_path, im_B_path
|
74 |
+
)
|
75 |
+
good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
|
76 |
+
pos_a, pos_b = self.convert_coordinates(
|
77 |
+
good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
|
78 |
+
)
|
79 |
+
try:
|
80 |
+
H_pred, inliers = cv2.findHomography(
|
81 |
+
pos_a,
|
82 |
+
pos_b,
|
83 |
+
method = cv2.RANSAC,
|
84 |
+
confidence = 0.99999,
|
85 |
+
ransacReprojThreshold = 3 * min(w2, h2) / 480,
|
86 |
+
)
|
87 |
+
except:
|
88 |
+
H_pred = None
|
89 |
+
if H_pred is None:
|
90 |
+
H_pred = np.zeros((3, 3))
|
91 |
+
H_pred[2, 2] = 1.0
|
92 |
+
corners = np.array(
|
93 |
+
[[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
|
94 |
+
)
|
95 |
+
real_warped_corners = np.dot(corners, np.transpose(H))
|
96 |
+
real_warped_corners = (
|
97 |
+
real_warped_corners[:, :2] / real_warped_corners[:, 2:]
|
98 |
+
)
|
99 |
+
warped_corners = np.dot(corners, np.transpose(H_pred))
|
100 |
+
warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
|
101 |
+
mean_dist = np.mean(
|
102 |
+
np.linalg.norm(real_warped_corners - warped_corners, axis=1)
|
103 |
+
) / (min(w2, h2) / 480.0)
|
104 |
+
homog_dists.append(mean_dist)
|
105 |
+
|
106 |
+
n_matches = np.array(n_matches)
|
107 |
+
thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
108 |
+
auc = pose_auc(np.array(homog_dists), thresholds)
|
109 |
+
return {
|
110 |
+
"hpatches_homog_auc_3": auc[2],
|
111 |
+
"hpatches_homog_auc_5": auc[4],
|
112 |
+
"hpatches_homog_auc_10": auc[9],
|
113 |
+
}
|
submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
from romatch.datasets import MegadepthBuilder
|
5 |
+
from romatch.utils import warp_kpts
|
6 |
+
from torch.utils.data import ConcatDataset
|
7 |
+
import romatch
|
8 |
+
|
9 |
+
class MegadepthDenseBenchmark:
|
10 |
+
def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
|
11 |
+
mega = MegadepthBuilder(data_root=data_root)
|
12 |
+
self.dataset = ConcatDataset(
|
13 |
+
mega.build_scenes(split="test_loftr", ht=h, wt=w)
|
14 |
+
) # fixed resolution of 384,512
|
15 |
+
self.num_samples = num_samples
|
16 |
+
|
17 |
+
def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
|
18 |
+
b, h1, w1, d = dense_matches.shape
|
19 |
+
with torch.no_grad():
|
20 |
+
x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
|
21 |
+
mask, x2 = warp_kpts(
|
22 |
+
x1.double(),
|
23 |
+
depth1.double(),
|
24 |
+
depth2.double(),
|
25 |
+
T_1to2.double(),
|
26 |
+
K1.double(),
|
27 |
+
K2.double(),
|
28 |
+
)
|
29 |
+
x2 = torch.stack(
|
30 |
+
(w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
|
31 |
+
)
|
32 |
+
prob = mask.float().reshape(b, h1, w1)
|
33 |
+
x2_hat = dense_matches[..., 2:]
|
34 |
+
x2_hat = torch.stack(
|
35 |
+
(w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
|
36 |
+
)
|
37 |
+
gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
|
38 |
+
gd = gd[prob == 1]
|
39 |
+
pck_1 = (gd < 1.0).float().mean()
|
40 |
+
pck_3 = (gd < 3.0).float().mean()
|
41 |
+
pck_5 = (gd < 5.0).float().mean()
|
42 |
+
return gd, pck_1, pck_3, pck_5, prob
|
43 |
+
|
44 |
+
def benchmark(self, model, batch_size=8):
|
45 |
+
model.train(False)
|
46 |
+
with torch.no_grad():
|
47 |
+
gd_tot = 0.0
|
48 |
+
pck_1_tot = 0.0
|
49 |
+
pck_3_tot = 0.0
|
50 |
+
pck_5_tot = 0.0
|
51 |
+
sampler = torch.utils.data.WeightedRandomSampler(
|
52 |
+
torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
|
53 |
+
)
|
54 |
+
B = batch_size
|
55 |
+
dataloader = torch.utils.data.DataLoader(
|
56 |
+
self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
|
57 |
+
)
|
58 |
+
for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
|
59 |
+
im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
|
60 |
+
data["im_A"].cuda(),
|
61 |
+
data["im_B"].cuda(),
|
62 |
+
data["im_A_depth"].cuda(),
|
63 |
+
data["im_B_depth"].cuda(),
|
64 |
+
data["T_1to2"].cuda(),
|
65 |
+
data["K1"].cuda(),
|
66 |
+
data["K2"].cuda(),
|
67 |
+
)
|
68 |
+
matches, certainty = model.match(im_A, im_B, batched=True)
|
69 |
+
gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
|
70 |
+
depth1, depth2, T_1to2, K1, K2, matches
|
71 |
+
)
|
72 |
+
if romatch.DEBUG_MODE:
|
73 |
+
from romatch.utils.utils import tensor_to_pil
|
74 |
+
import torch.nn.functional as F
|
75 |
+
path = "vis"
|
76 |
+
H, W = model.get_output_resolution()
|
77 |
+
white_im = torch.ones((B,1,H,W),device="cuda")
|
78 |
+
im_B_transfer_rgb = F.grid_sample(
|
79 |
+
im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
|
80 |
+
)
|
81 |
+
warp_im = im_B_transfer_rgb
|
82 |
+
c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
|
83 |
+
vis_im = c_b * warp_im + (1 - c_b) * white_im
|
84 |
+
for b in range(B):
|
85 |
+
import os
|
86 |
+
os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
|
87 |
+
tensor_to_pil(vis_im[b], unnormalize=True).save(
|
88 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
|
89 |
+
tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
|
90 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
|
91 |
+
tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
|
92 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
|
93 |
+
|
94 |
+
|
95 |
+
gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
|
96 |
+
gd_tot + gd.mean(),
|
97 |
+
pck_1_tot + pck_1,
|
98 |
+
pck_3_tot + pck_3,
|
99 |
+
pck_5_tot + pck_5,
|
100 |
+
)
|
101 |
+
return {
|
102 |
+
"epe": gd_tot.item() / len(dataloader),
|
103 |
+
"mega_pck_1": pck_1_tot.item() / len(dataloader),
|
104 |
+
"mega_pck_3": pck_3_tot.item() / len(dataloader),
|
105 |
+
"mega_pck_5": pck_5_tot.item() / len(dataloader),
|
106 |
+
}
|
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from romatch.utils import *
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import romatch
|
8 |
+
import kornia.geometry.epipolar as kepi
|
9 |
+
|
10 |
+
class MegaDepthPoseEstimationBenchmark:
|
11 |
+
def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
|
12 |
+
if scene_names is None:
|
13 |
+
self.scene_names = [
|
14 |
+
"0015_0.1_0.3.npz",
|
15 |
+
"0015_0.3_0.5.npz",
|
16 |
+
"0022_0.1_0.3.npz",
|
17 |
+
"0022_0.3_0.5.npz",
|
18 |
+
"0022_0.5_0.7.npz",
|
19 |
+
]
|
20 |
+
else:
|
21 |
+
self.scene_names = scene_names
|
22 |
+
self.scenes = [
|
23 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
24 |
+
for scene in self.scene_names
|
25 |
+
]
|
26 |
+
self.data_root = data_root
|
27 |
+
|
28 |
+
def benchmark(self, model, model_name = None):
|
29 |
+
with torch.no_grad():
|
30 |
+
data_root = self.data_root
|
31 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
32 |
+
thresholds = [5, 10, 20]
|
33 |
+
for scene_ind in range(len(self.scenes)):
|
34 |
+
import os
|
35 |
+
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
36 |
+
scene = self.scenes[scene_ind]
|
37 |
+
pairs = scene["pair_infos"]
|
38 |
+
intrinsics = scene["intrinsics"]
|
39 |
+
poses = scene["poses"]
|
40 |
+
im_paths = scene["image_paths"]
|
41 |
+
pair_inds = range(len(pairs))
|
42 |
+
for pairind in tqdm(pair_inds):
|
43 |
+
idx1, idx2 = pairs[pairind][0]
|
44 |
+
K1 = intrinsics[idx1].copy()
|
45 |
+
T1 = poses[idx1].copy()
|
46 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
47 |
+
K2 = intrinsics[idx2].copy()
|
48 |
+
T2 = poses[idx2].copy()
|
49 |
+
R2, t2 = T2[:3, :3], T2[:3, 3]
|
50 |
+
R, t = compute_relative_pose(R1, t1, R2, t2)
|
51 |
+
T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
|
52 |
+
im_A_path = f"{data_root}/{im_paths[idx1]}"
|
53 |
+
im_B_path = f"{data_root}/{im_paths[idx2]}"
|
54 |
+
dense_matches, dense_certainty = model.match(
|
55 |
+
im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
|
56 |
+
)
|
57 |
+
sparse_matches,_ = model.sample(
|
58 |
+
dense_matches, dense_certainty, 5_000
|
59 |
+
)
|
60 |
+
|
61 |
+
im_A = Image.open(im_A_path)
|
62 |
+
w1, h1 = im_A.size
|
63 |
+
im_B = Image.open(im_B_path)
|
64 |
+
w2, h2 = im_B.size
|
65 |
+
if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
|
66 |
+
scale1 = 1200 / max(w1, h1)
|
67 |
+
scale2 = 1200 / max(w2, h2)
|
68 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
69 |
+
w2, h2 = scale2 * w2, scale2 * h2
|
70 |
+
K1, K2 = K1.copy(), K2.copy()
|
71 |
+
K1[:2] = K1[:2] * scale1
|
72 |
+
K2[:2] = K2[:2] * scale2
|
73 |
+
|
74 |
+
kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
|
75 |
+
kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
|
76 |
+
for _ in range(5):
|
77 |
+
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
78 |
+
kpts1 = kpts1[shuffling]
|
79 |
+
kpts2 = kpts2[shuffling]
|
80 |
+
try:
|
81 |
+
threshold = 0.5
|
82 |
+
norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
83 |
+
R_est, t_est, mask = estimate_pose(
|
84 |
+
kpts1,
|
85 |
+
kpts2,
|
86 |
+
K1,
|
87 |
+
K2,
|
88 |
+
norm_threshold,
|
89 |
+
conf=0.99999,
|
90 |
+
)
|
91 |
+
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
92 |
+
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
93 |
+
e_pose = max(e_t, e_R)
|
94 |
+
except Exception as e:
|
95 |
+
print(repr(e))
|
96 |
+
e_t, e_R = 90, 90
|
97 |
+
e_pose = max(e_t, e_R)
|
98 |
+
tot_e_t.append(e_t)
|
99 |
+
tot_e_R.append(e_R)
|
100 |
+
tot_e_pose.append(e_pose)
|
101 |
+
tot_e_pose = np.array(tot_e_pose)
|
102 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
103 |
+
acc_5 = (tot_e_pose < 5).mean()
|
104 |
+
acc_10 = (tot_e_pose < 10).mean()
|
105 |
+
acc_15 = (tot_e_pose < 15).mean()
|
106 |
+
acc_20 = (tot_e_pose < 20).mean()
|
107 |
+
map_5 = acc_5
|
108 |
+
map_10 = np.mean([acc_5, acc_10])
|
109 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
110 |
+
print(f"{model_name} auc: {auc}")
|
111 |
+
return {
|
112 |
+
"auc_5": auc[0],
|
113 |
+
"auc_10": auc[1],
|
114 |
+
"auc_20": auc[2],
|
115 |
+
"map_5": map_5,
|
116 |
+
"map_10": map_10,
|
117 |
+
"map_20": map_20,
|
118 |
+
}
|
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from romatch.utils import *
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import romatch
|
8 |
+
import kornia.geometry.epipolar as kepi
|
9 |
+
|
10 |
+
# wrap cause pyposelib is still in dev
|
11 |
+
# will add in deps later
|
12 |
+
import poselib
|
13 |
+
|
14 |
+
class Mega1500PoseLibBenchmark:
|
15 |
+
def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
|
16 |
+
if scene_names is None:
|
17 |
+
self.scene_names = [
|
18 |
+
"0015_0.1_0.3.npz",
|
19 |
+
"0015_0.3_0.5.npz",
|
20 |
+
"0022_0.1_0.3.npz",
|
21 |
+
"0022_0.3_0.5.npz",
|
22 |
+
"0022_0.5_0.7.npz",
|
23 |
+
]
|
24 |
+
else:
|
25 |
+
self.scene_names = scene_names
|
26 |
+
self.scenes = [
|
27 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
28 |
+
for scene in self.scene_names
|
29 |
+
]
|
30 |
+
self.data_root = data_root
|
31 |
+
self.num_ransac_iter = num_ransac_iter
|
32 |
+
self.test_every = test_every
|
33 |
+
|
34 |
+
def benchmark(self, model, model_name = None):
|
35 |
+
with torch.no_grad():
|
36 |
+
data_root = self.data_root
|
37 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
38 |
+
thresholds = [5, 10, 20]
|
39 |
+
for scene_ind in range(len(self.scenes)):
|
40 |
+
import os
|
41 |
+
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
42 |
+
scene = self.scenes[scene_ind]
|
43 |
+
pairs = scene["pair_infos"]
|
44 |
+
intrinsics = scene["intrinsics"]
|
45 |
+
poses = scene["poses"]
|
46 |
+
im_paths = scene["image_paths"]
|
47 |
+
pair_inds = range(len(pairs))[::self.test_every]
|
48 |
+
for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
|
49 |
+
idx1, idx2 = pairs[pairind][0]
|
50 |
+
K1 = intrinsics[idx1].copy()
|
51 |
+
T1 = poses[idx1].copy()
|
52 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
53 |
+
K2 = intrinsics[idx2].copy()
|
54 |
+
T2 = poses[idx2].copy()
|
55 |
+
R2, t2 = T2[:3, :3], T2[:3, 3]
|
56 |
+
R, t = compute_relative_pose(R1, t1, R2, t2)
|
57 |
+
T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
|
58 |
+
im_A_path = f"{data_root}/{im_paths[idx1]}"
|
59 |
+
im_B_path = f"{data_root}/{im_paths[idx2]}"
|
60 |
+
dense_matches, dense_certainty = model.match(
|
61 |
+
im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
|
62 |
+
)
|
63 |
+
sparse_matches,_ = model.sample(
|
64 |
+
dense_matches, dense_certainty, 5_000
|
65 |
+
)
|
66 |
+
|
67 |
+
im_A = Image.open(im_A_path)
|
68 |
+
w1, h1 = im_A.size
|
69 |
+
im_B = Image.open(im_B_path)
|
70 |
+
w2, h2 = im_B.size
|
71 |
+
kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
|
72 |
+
kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
|
73 |
+
for _ in range(self.num_ransac_iter):
|
74 |
+
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
75 |
+
kpts1 = kpts1[shuffling]
|
76 |
+
kpts2 = kpts2[shuffling]
|
77 |
+
try:
|
78 |
+
threshold = 1
|
79 |
+
camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
|
80 |
+
camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
|
81 |
+
relpose, res = poselib.estimate_relative_pose(
|
82 |
+
kpts1,
|
83 |
+
kpts2,
|
84 |
+
camera1,
|
85 |
+
camera2,
|
86 |
+
ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
|
87 |
+
)
|
88 |
+
Rt_est = relpose.Rt
|
89 |
+
R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
|
90 |
+
mask = np.array(res['inliers']).astype(np.float32)
|
91 |
+
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
92 |
+
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
93 |
+
e_pose = max(e_t, e_R)
|
94 |
+
except Exception as e:
|
95 |
+
print(repr(e))
|
96 |
+
e_t, e_R = 90, 90
|
97 |
+
e_pose = max(e_t, e_R)
|
98 |
+
tot_e_t.append(e_t)
|
99 |
+
tot_e_R.append(e_R)
|
100 |
+
tot_e_pose.append(e_pose)
|
101 |
+
pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
|
102 |
+
tot_e_pose = np.array(tot_e_pose)
|
103 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
104 |
+
acc_5 = (tot_e_pose < 5).mean()
|
105 |
+
acc_10 = (tot_e_pose < 10).mean()
|
106 |
+
acc_15 = (tot_e_pose < 15).mean()
|
107 |
+
acc_20 = (tot_e_pose < 20).mean()
|
108 |
+
map_5 = acc_5
|
109 |
+
map_10 = np.mean([acc_5, acc_10])
|
110 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
111 |
+
print(f"{model_name} auc: {auc}")
|
112 |
+
return {
|
113 |
+
"auc_5": auc[0],
|
114 |
+
"auc_10": auc[1],
|
115 |
+
"auc_20": auc[2],
|
116 |
+
"map_5": map_5,
|
117 |
+
"map_10": map_10,
|
118 |
+
"map_20": map_20,
|
119 |
+
}
|
submodules/RoMa/romatch/benchmarks/scannet_benchmark.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from romatch.utils import *
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
class ScanNetBenchmark:
|
10 |
+
def __init__(self, data_root="data/scannet") -> None:
|
11 |
+
self.data_root = data_root
|
12 |
+
|
13 |
+
def benchmark(self, model, model_name = None):
|
14 |
+
model.train(False)
|
15 |
+
with torch.no_grad():
|
16 |
+
data_root = self.data_root
|
17 |
+
tmp = np.load(osp.join(data_root, "test.npz"))
|
18 |
+
pairs, rel_pose = tmp["name"], tmp["rel_pose"]
|
19 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
20 |
+
pair_inds = np.random.choice(
|
21 |
+
range(len(pairs)), size=len(pairs), replace=False
|
22 |
+
)
|
23 |
+
for pairind in tqdm(pair_inds, smoothing=0.9):
|
24 |
+
scene = pairs[pairind]
|
25 |
+
scene_name = f"scene0{scene[0]}_00"
|
26 |
+
im_A_path = osp.join(
|
27 |
+
self.data_root,
|
28 |
+
"scans_test",
|
29 |
+
scene_name,
|
30 |
+
"color",
|
31 |
+
f"{scene[2]}.jpg",
|
32 |
+
)
|
33 |
+
im_A = Image.open(im_A_path)
|
34 |
+
im_B_path = osp.join(
|
35 |
+
self.data_root,
|
36 |
+
"scans_test",
|
37 |
+
scene_name,
|
38 |
+
"color",
|
39 |
+
f"{scene[3]}.jpg",
|
40 |
+
)
|
41 |
+
im_B = Image.open(im_B_path)
|
42 |
+
T_gt = rel_pose[pairind].reshape(3, 4)
|
43 |
+
R, t = T_gt[:3, :3], T_gt[:3, 3]
|
44 |
+
K = np.stack(
|
45 |
+
[
|
46 |
+
np.array([float(i) for i in r.split()])
|
47 |
+
for r in open(
|
48 |
+
osp.join(
|
49 |
+
self.data_root,
|
50 |
+
"scans_test",
|
51 |
+
scene_name,
|
52 |
+
"intrinsic",
|
53 |
+
"intrinsic_color.txt",
|
54 |
+
),
|
55 |
+
"r",
|
56 |
+
)
|
57 |
+
.read()
|
58 |
+
.split("\n")
|
59 |
+
if r
|
60 |
+
]
|
61 |
+
)
|
62 |
+
w1, h1 = im_A.size
|
63 |
+
w2, h2 = im_B.size
|
64 |
+
K1 = K.copy()
|
65 |
+
K2 = K.copy()
|
66 |
+
dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
|
67 |
+
sparse_matches, sparse_certainty = model.sample(
|
68 |
+
dense_matches, dense_certainty, 5000
|
69 |
+
)
|
70 |
+
scale1 = 480 / min(w1, h1)
|
71 |
+
scale2 = 480 / min(w2, h2)
|
72 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
73 |
+
w2, h2 = scale2 * w2, scale2 * h2
|
74 |
+
K1 = K1 * scale1
|
75 |
+
K2 = K2 * scale2
|
76 |
+
|
77 |
+
offset = 0.5
|
78 |
+
kpts1 = sparse_matches[:, :2]
|
79 |
+
kpts1 = (
|
80 |
+
np.stack(
|
81 |
+
(
|
82 |
+
w1 * (kpts1[:, 0] + 1) / 2 - offset,
|
83 |
+
h1 * (kpts1[:, 1] + 1) / 2 - offset,
|
84 |
+
),
|
85 |
+
axis=-1,
|
86 |
+
)
|
87 |
+
)
|
88 |
+
kpts2 = sparse_matches[:, 2:]
|
89 |
+
kpts2 = (
|
90 |
+
np.stack(
|
91 |
+
(
|
92 |
+
w2 * (kpts2[:, 0] + 1) / 2 - offset,
|
93 |
+
h2 * (kpts2[:, 1] + 1) / 2 - offset,
|
94 |
+
),
|
95 |
+
axis=-1,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
for _ in range(5):
|
99 |
+
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
100 |
+
kpts1 = kpts1[shuffling]
|
101 |
+
kpts2 = kpts2[shuffling]
|
102 |
+
try:
|
103 |
+
norm_threshold = 0.5 / (
|
104 |
+
np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
105 |
+
R_est, t_est, mask = estimate_pose(
|
106 |
+
kpts1,
|
107 |
+
kpts2,
|
108 |
+
K1,
|
109 |
+
K2,
|
110 |
+
norm_threshold,
|
111 |
+
conf=0.99999,
|
112 |
+
)
|
113 |
+
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
114 |
+
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
115 |
+
e_pose = max(e_t, e_R)
|
116 |
+
except Exception as e:
|
117 |
+
print(repr(e))
|
118 |
+
e_t, e_R = 90, 90
|
119 |
+
e_pose = max(e_t, e_R)
|
120 |
+
tot_e_t.append(e_t)
|
121 |
+
tot_e_R.append(e_R)
|
122 |
+
tot_e_pose.append(e_pose)
|
123 |
+
tot_e_t.append(e_t)
|
124 |
+
tot_e_R.append(e_R)
|
125 |
+
tot_e_pose.append(e_pose)
|
126 |
+
tot_e_pose = np.array(tot_e_pose)
|
127 |
+
thresholds = [5, 10, 20]
|
128 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
129 |
+
acc_5 = (tot_e_pose < 5).mean()
|
130 |
+
acc_10 = (tot_e_pose < 10).mean()
|
131 |
+
acc_15 = (tot_e_pose < 15).mean()
|
132 |
+
acc_20 = (tot_e_pose < 20).mean()
|
133 |
+
map_5 = acc_5
|
134 |
+
map_10 = np.mean([acc_5, acc_10])
|
135 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
136 |
+
return {
|
137 |
+
"auc_5": auc[0],
|
138 |
+
"auc_10": auc[1],
|
139 |
+
"auc_20": auc[2],
|
140 |
+
"map_5": map_5,
|
141 |
+
"map_10": map_10,
|
142 |
+
"map_20": map_20,
|
143 |
+
}
|
submodules/RoMa/romatch/checkpointing/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .checkpoint import CheckPoint
|
submodules/RoMa/romatch/checkpointing/checkpoint.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
4 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
5 |
+
from loguru import logger
|
6 |
+
import gc
|
7 |
+
|
8 |
+
import romatch
|
9 |
+
|
10 |
+
class CheckPoint:
|
11 |
+
def __init__(self, dir=None, name="tmp"):
|
12 |
+
self.name = name
|
13 |
+
self.dir = dir
|
14 |
+
os.makedirs(self.dir, exist_ok=True)
|
15 |
+
|
16 |
+
def save(
|
17 |
+
self,
|
18 |
+
model,
|
19 |
+
optimizer,
|
20 |
+
lr_scheduler,
|
21 |
+
n,
|
22 |
+
):
|
23 |
+
if romatch.RANK == 0:
|
24 |
+
assert model is not None
|
25 |
+
if isinstance(model, (DataParallel, DistributedDataParallel)):
|
26 |
+
model = model.module
|
27 |
+
states = {
|
28 |
+
"model": model.state_dict(),
|
29 |
+
"n": n,
|
30 |
+
"optimizer": optimizer.state_dict(),
|
31 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
32 |
+
}
|
33 |
+
torch.save(states, self.dir + self.name + f"_latest.pth")
|
34 |
+
logger.info(f"Saved states {list(states.keys())}, at step {n}")
|
35 |
+
|
36 |
+
def load(
|
37 |
+
self,
|
38 |
+
model,
|
39 |
+
optimizer,
|
40 |
+
lr_scheduler,
|
41 |
+
n,
|
42 |
+
):
|
43 |
+
if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
|
44 |
+
states = torch.load(self.dir + self.name + f"_latest.pth")
|
45 |
+
if "model" in states:
|
46 |
+
model.load_state_dict(states["model"])
|
47 |
+
if "n" in states:
|
48 |
+
n = states["n"] if states["n"] else n
|
49 |
+
if "optimizer" in states:
|
50 |
+
try:
|
51 |
+
optimizer.load_state_dict(states["optimizer"])
|
52 |
+
except Exception as e:
|
53 |
+
print(f"Failed to load states for optimizer, with error {e}")
|
54 |
+
if "lr_scheduler" in states:
|
55 |
+
lr_scheduler.load_state_dict(states["lr_scheduler"])
|
56 |
+
print(f"Loaded states {list(states.keys())}, at step {n}")
|
57 |
+
del states
|
58 |
+
gc.collect()
|
59 |
+
torch.cuda.empty_cache()
|
60 |
+
return model, optimizer, lr_scheduler, n
|