Jaocs committed on
Commit
c096a7a
·
1 Parent(s): 0f0079f

application3

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +72 -0
  2. LICENSE.txt +12 -0
  3. README.md +232 -8
  4. configs/gs/base.yaml +51 -0
  5. configs/train.yaml +38 -0
  6. extra/archive/rasterizer_impl.h +75 -0
  7. extra/archive/simple-knn.patch.txt +13 -0
  8. full_eval.py +102 -0
  9. gradio_demo.py +424 -0
  10. install.sh +35 -0
  11. main.py +268 -0
  12. metrics.py +115 -0
  13. requirements.txt +1 -0
  14. script.bash +0 -0
  15. source/EDGS.code-workspace +11 -0
  16. source/__init__.py +0 -0
  17. source/corr_init.py +907 -0
  18. source/data_utils.py +28 -0
  19. source/losses.py +100 -0
  20. source/networks.py +48 -0
  21. source/timer.py +24 -0
  22. source/trainer.py +265 -0
  23. source/utils_aux.py +92 -0
  24. source/utils_preprocess.py +334 -0
  25. source/visualization.py +1072 -0
  26. submodules/RoMa/.gitignore +11 -0
  27. submodules/RoMa/LICENSE +21 -0
  28. submodules/RoMa/README.md +123 -0
  29. submodules/RoMa/data/.gitignore +2 -0
  30. submodules/RoMa/demo/demo_3D_effect.py +47 -0
  31. submodules/RoMa/demo/demo_fundamental.py +34 -0
  32. submodules/RoMa/demo/demo_match.py +50 -0
  33. submodules/RoMa/demo/demo_match_opencv_sift.py +43 -0
  34. submodules/RoMa/demo/demo_match_tiny.py +77 -0
  35. submodules/RoMa/demo/gif/.gitignore +2 -0
  36. submodules/RoMa/experiments/eval_roma_outdoor.py +57 -0
  37. submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
  38. submodules/RoMa/experiments/roma_indoor.py +320 -0
  39. submodules/RoMa/experiments/train_roma_outdoor.py +307 -0
  40. submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
  41. submodules/RoMa/requirements.txt +14 -0
  42. submodules/RoMa/romatch/__init__.py +8 -0
  43. submodules/RoMa/romatch/benchmarks/__init__.py +6 -0
  44. submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  45. submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
  46. submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
  47. submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
  48. submodules/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
  49. submodules/RoMa/romatch/checkpointing/__init__.py +1 -0
  50. submodules/RoMa/romatch/checkpointing/checkpoint.py +60 -0
Dockerfile ADDED
@@ -0,0 +1,72 @@
1
+ FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 AS builder
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . /app/
6
+
7
+ ENV CUDA_HOME=/usr/local/cuda-12.1
8
+ ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
9
+ ENV PATH=$CUDA_HOME/bin:$PATH
10
+
11
+ ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
12
+
13
+ RUN apt-get update && \
14
+ DEBIAN_FRONTEND=noninteractive apt-get install -y \
15
+ build-essential wget curl nano ninja-build unzip libgl-dev ffmpeg && \
16
+ apt-get clean && \
17
+ rm -rf /var/lib/apt/lists/*
18
+
19
+ ENV CONDA_DIR=/opt/conda
20
+
21
+ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
22
+ /bin/bash ~/miniconda.sh -b -p /opt/conda && \
23
+ rm ~/miniconda.sh
24
+
25
+ ENV PATH=$CONDA_DIR/bin:$PATH
26
+
27
+ RUN conda update -n base conda -y && \
28
+ conda install -n base conda-libmamba-solver -y && \
29
+ conda config --set solver libmamba
30
+
31
+ RUN conda create -y -n edgs python=3.10 pip
32
+
33
+ SHELL ["conda", "run", "-n", "edgs", "/bin/bash", "-c"]
34
+
35
+
36
+ RUN conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
37
+
38
+ RUN pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization --no-build-isolation && \
39
+ pip install -e submodules/gaussian-splatting/submodules/simple-knn --no-build-isolation
40
+
41
+ RUN pip install pycolmap
42
+
43
+ RUN pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg && \
44
+ conda install numpy=1.26.4 -y -c conda-forge --override-channels
45
+
46
+ RUN pip install -e submodules/RoMa
47
+
48
+ RUN pip install plotly scikit-learn moviepy==2.1.1 ffmpeg fastapi[standard]
49
+
50
+ # Final image
51
+
52
+ FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04 AS final
53
+
54
+ WORKDIR /app
55
+
56
+ RUN apt-get update && \
57
+ DEBIAN_FRONTEND=noninteractive apt-get install -y \
58
+ libgl1-mesa-glx libsm6 libxext6 ffmpeg && \
59
+ apt-get clean && \
60
+ rm -rf /var/lib/apt/lists/*
61
+
62
+ COPY --from=builder /opt/conda /opt/conda
63
+
64
+ COPY --from=builder /app /app
65
+
66
+ ENV PATH="/opt/conda/bin:/opt/conda/envs/edgs/bin:$PATH"
67
+
68
+ ENV CUDA_HOME=/usr/local/cuda-12.1
69
+ ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
70
+
71
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
72
+
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
1
+ Copyright 2025, Dmytro Kotovenko, Olga Grebenkova, Björn Ommer
2
+ Redistribution and use in source and binary forms, with or without modification, are permitted for non-commercial academic research and/or non-commercial personal use only provided that the following conditions are met:
3
+
4
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
5
+
6
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
7
+
8
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9
+
10
+ Any use of this software beyond the above specified conditions requires a separate license. Please contact the copyright holders to discuss license terms.
11
+
12
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,10 +1,234 @@
1
  ---
2
- title: Gs Final
3
- emoji: 🐢
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
1
+ <h1 align="center">EDGS: Eliminating Densification for Efficient Convergence of 3DGS</h1>
2
+
3
+ <p align="center">
4
+ <a href="https://www.linkedin.com/in/dmitry-kotovenko-dl/">Dmytro Kotovenko</a><sup>*</sup> ·
5
+ <a href="https://www.linkedin.com/in/grebenkovao/">Olga Grebenkova</a><sup>*</sup> ·
6
+ <a href="https://ommer-lab.com/people/ommer/">Björn Ommer</a>
7
+ </p>
8
+
9
+ <p align="center">CompVis @ LMU Munich · Munich Center for Machine Learning (MCML) </p>
10
+ <p align="center">* equal contribution </p>
11
+
12
+ <p align="center">
13
+ <a href="https://compvis.github.io/EDGS/"><img src="https://img.shields.io/badge/Project-Page-blue" alt="Project Page"></a>
14
+ <a href="https://arxiv.org/pdf/2504.13204"><img src="https://img.shields.io/badge/arXiv-PDF-b31b1b" alt="Paper"></a>
15
+ <a href="https://colab.research.google.com/github/CompVis/EDGS/blob/main/notebooks/fit_model_to_scene_full.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
16
+ <a href="https://huggingface.co/spaces/CompVis/EDGS"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue" alt="Hugging Face"></a>
17
+
18
+ </p>
19
+
20
+ <p align="center">
21
+ <img src="./assets/Teaser2.png" width="99%">
22
+ </p>
23
+
24
+ <p>
25
+ <strong>3DGS</strong> initializes with a sparse set of Gaussians and progressively adds more in under-reconstructed regions. In contrast, <strong>EDGS</strong> starts with
26
+ a dense initialization from triangulated 2D correspondences across training image pairs,
27
+ requiring only minimal refinement. This leads to <strong>faster convergence</strong> and <strong>higher rendering quality</strong>. Our method reaches the original 3DGS <strong>LPIPS score in just 25% of the training time</strong> and uses only <strong>60% of the splats</strong>.
28
+ Renderings become <strong>nearly indistinguishable from ground truth after only 3,000 steps — without any densification</strong>.
29
+ </p>
30
+
31
+ <h3 align="center">3D scene reconstruction using our method in 11 seconds.</h3>
32
+ <p align="center">
33
+ <img src="assets/video_fruits_our_optimization.gif" width="480" alt="3D Reconstruction Demo">
34
+ </p>
35
+
36
+
37
+
38
+ ## 📚 Table of Contents
39
+ - [🚀 Quickstart](#sec-quickstart)
40
+ - [🛠️ Installation](#sec-install)
41
+ - [📦 Data](#sec-data)
42
+
43
+ - [🏋️ Training](#sec-training)
44
+ - [🏗️ Reusing Our Model](#sec-reuse)
45
+ - [📄 Citation](#sec-citation)
46
+
47
+ <a id="sec-quickstart"></a>
48
+ ## 🚀 Quickstart
49
+ The fastest way to try our model is through the [Hugging Face demo](https://huggingface.co/spaces/magistrkoljan/EDGS), which lets you upload images or a video and interactively rotate the resulting 3D scene. For broad accessibility, we currently support only **forward-facing scenes**.
50
+ #### Steps:
51
+ 1. Upload a list of photos or a single video.
52
+ 2. Click **📸 Preprocess Input** to estimate 3D positions using COLMAP.
53
+ 3. Click **🚀 Start Reconstruction** to run the model.
54
+
55
+ You can also **explore the reconstructed scene in 3D** directly in the browser.
56
+
57
+ > ⚡ Runtime: EDGS typically takes just **10–20 seconds**, plus **5–10 seconds** for COLMAP processing. Additional time may be needed to save outputs (model, video, 3D preview).
58
+
59
+ You can also run the same app locally on your machine with the following command:
60
+ ```CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --port 7862 --no_share```
61
+ Without the `--no_share` flag, Gradio will also print a public address for the app that you can share with others, allowing them to process their data on your server.
62
+
63
+ Alternatively, check our [Colab notebook](https://colab.research.google.com/github/CompVis/EDGS/blob/main/notebooks/fit_model_to_scene_full.ipynb).
64
+
65
+ ###
66
+
67
+
68
+
69
+ <a id="sec-install"></a>
70
+ ## 🛠️ Installation
71
+
72
+ You can either run `install.sh` or manually install using the following:
73
+
74
+ ```bash
75
+ git clone [email protected]:CompVis/EDGS.git --recursive
76
+ cd EDGS
77
+ git submodule update --init --recursive
78
+
79
+ conda create -y -n edgs python=3.10 pip
80
+ conda activate edgs
81
+
82
+ # Set up path to your CUDA. In our experience similar versions like 12.2 also work well
83
+ export CUDA_HOME=/usr/local/cuda-12.1
84
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
85
+ export PATH=$CUDA_HOME/bin:$PATH
86
+
87
+ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
88
+ conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y
89
+
90
+ pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization
91
+ pip install -e submodules/gaussian-splatting/submodules/simple-knn
92
+
93
+ # For COLMAP and pycolmap
94
+ # Optionally install original colmap but probably pycolmap suffices
95
+ # conda install conda-forge/label/colmap_dev::colmap
96
+ pip install pycolmap
97
+
98
+
99
+ pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg
100
+ conda install numpy=1.26.4 -y -c conda-forge --override-channels
101
+
102
+ pip install -e submodules/RoMa
103
+ conda install anaconda::jupyter --yes
104
+
105
+ # Stuff necessary for gradio and visualizations
106
+ pip install gradio
107
+ pip install plotly scikit-learn moviepy==2.1.1 ffmpeg
108
+ pip install open3d
109
+ ```
110
+
111
+ <a id="sec-data"></a>
112
+ ## 📦 Data
113
+
114
+ We evaluated on the following datasets:
115
+
116
+ - **MipNeRF360** — download [here](https://jonbarron.info/mipnerf360/). Unzip "Dataset Pt. 1" and "Dataset Pt. 2", then merge scenes.
117
+ - **Tanks & Temples + Deep Blending** — from the [original 3DGS repo](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip).
118
+
119
+ ### Using Your Own Dataset
120
+
121
+ You can use the same data format as the [3DGS project](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes). Please follow their guide to prepare your scene.
122
+
123
+ Expected folder structure:
124
+ ```
125
+ scene_folder
126
+ |---images
127
+ | |---<image 0>
128
+ | |---<image 1>
129
+ | |---...
130
+ |---sparse
131
+ |---0
132
+ |---cameras.bin
133
+ |---images.bin
134
+ |---points3D.bin
135
+ ```
136
+
137
+ The NeRF synthetic format is also supported.
138
+
139
+ You can also use the functions provided in our code to convert a collection of images or a single video into the required format. However, this may require some tweaking, and processing can take a long time for large image collections with little overlap; a minimal sketch is shown below.
140
+
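+ For example, a minimal sketch of turning a single video into a COLMAP-style scene folder, using the preprocessing helpers from `source.utils_preprocess` (the same calls the Gradio demo uses); the paths are placeholders:
+
+ ```python
+ # Minimal sketch: video -> COLMAP-style scene folder, using the repo's preprocessing helpers.
+ # "my_video.mp4" and "my_scene" are placeholder paths.
+ from source.utils_preprocess import (
+     read_video_frames, preprocess_frames, select_optimal_frames,
+     save_frames_to_scene_dir, run_colmap_on_scene,
+ )
+
+ frames = read_video_frames(video_input="my_video.mp4", max_size=1024)  # decode and resize frames
+ scores = preprocess_frames(frames)                                     # score frames for selection
+ idxs = select_optimal_frames(scores=scores, k=16)                      # keep k reference frames
+ save_frames_to_scene_dir(frames=[frames[i] for i in idxs], scene_dir="my_scene")
+ run_colmap_on_scene("my_scene")                                        # estimate camera poses with COLMAP
+ ```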
141
+ <a id="sec-training"></a>
142
+ ## 🏋️ Training
143
+
144
+
145
+ To optimize a single scene in COLMAP format, use the following command:
146
+ ```bash
147
+ python train.py \
148
+ train.gs_epochs=30000 \
149
+ train.no_densify=True \
150
+ gs.dataset.source_path=<scene folder> \
151
+ gs.dataset.model_path=<output folder> \
152
+ init_wC.matches_per_ref=20000 \
153
+ init_wC.nns_per_ref=3 \
154
+ init_wC.num_refs=180
155
+ ```
156
+ <details>
157
+ <summary><span style="font-weight: bold;">Command Line Arguments for train.py</span></summary>
158
+
159
+ * `train.gs_epochs`
160
+ Number of training iterations (steps) for Gaussian Splatting.
161
+ * `train.no_densify`
162
+ Disables densification when set to `True` (the default in `configs/train.yaml` is `False`).
163
+ * `gs.dataset.source_path`
164
+ Path to your input dataset directory. This should follow the same format as the original 3DGS dataset structure.
165
+ * `gs.dataset.model_path`
166
+ Output directory where the trained model, logs, and renderings will be saved.
167
+ * `init_wC.matches_per_ref`
168
+ Number of 2D feature correspondences to extract per reference view for initialization. More matches lead to more Gaussians.
169
+ * `init_wC.nns_per_ref`
170
+ Number of nearest neighbor images used per reference during matching.
171
+ * `init_wC.num_refs`
172
+ Total number of reference views sampled.
173
+ * `wandb.mode`
174
+ Specifies how Weights & Biases (W&B) logging is handled.
175
+
176
+ - Default: `"disabled"`
177
+ - Options:
178
+ - `"online"` — log to the W&B server in real-time
179
+ - `"offline"` — save logs locally to sync later
180
+ - `"disabled"` — turn off W&B logging entirely
181
+
182
+ If you want to enable W&B logging, make sure to also configure:
183
+
184
+ - `wandb.project` — the name of your W&B project
185
+ - `wandb.entity` — your W&B username or team name
186
+
187
+ Example override:
188
+ ```bash
189
+ wandb.mode=online wandb.project=EDGS wandb.entity=your_username train.gs_epochs=15_000 init_wC.matches_per_ref=15_000
190
+ ```
191
+ </details>
192
+ <br>
193
+
194
+ To run full evaluation on all datasets:
195
+
196
+ ```bash
197
+ python full_eval.py -m360 <mipnerf360 folder> -tat <tanks and temples folder> -db <deep blending folder>
198
+ ```
199
+ <a id="sec-reuse"></a>
200
+ ## 🏗️ Reusing Our Model
201
+
202
+ Our model is essentially a better **initialization module** for Gaussian Splatting. You can integrate it into your pipeline by calling:
203
+
204
+ ```python
205
+ source.corr_init.init_gaussians_with_corr(...)
206
+ ```
207
+ ### Input arguments:
208
+ - A GaussianModel and Scene instance
209
+ - A configuration namespace `cfg.init_wC` to specify parameters like the number of matches, neighbors, and reference views
210
+ - A RoMA model (automatically instantiated if not provided)
211
+
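+ A minimal sketch of how such a call could be wired up; the argument order and names below are illustrative, not the exact signature, so check `source/corr_init.py` for the real interface:
+
+ ```python
+ # Sketch only: see source/corr_init.py for the exact interface.
+ from source.corr_init import init_gaussians_with_corr
+
+ # Assumes you already have a GaussianModel, a Scene, and a Hydra config with an
+ # init_wC block (matches_per_ref, nns_per_ref, num_refs, ...), e.g. configs/train.yaml.
+ init_gaussians_with_corr(
+     gaussians,        # your GaussianModel instance
+     scene,            # your Scene instance
+     cfg.init_wC,      # correspondence-initialization settings
+     roma_model=None,  # RoMa matcher; instantiated automatically if not provided
+ )
+ ```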
212
+
213
+
214
+ <a id="sec-citation"></a>
215
+ ## 📄 Citation
216
+ ```bibtex
217
+ @misc{kotovenko2025edgseliminatingdensificationefficient,
218
+ title={EDGS: Eliminating Densification for Efficient Convergence of 3DGS},
219
+ author={Dmytro Kotovenko and Olga Grebenkova and Björn Ommer},
220
+ year={2025},
221
+ eprint={2504.13204},
222
+ archivePrefix={arXiv},
223
+ primaryClass={cs.GR},
224
+ url={https://arxiv.org/abs/2504.13204},
225
+ }
226
+ ```
227
  ---
 
 
 
 
 
 
 
228
 
229
+ # TODO:
230
+ - [ ] Code for training and processing forward-facing scenes.
231
+ - [ ] More data examples
232
+
233
+
234
+
configs/gs/base.yaml ADDED
@@ -0,0 +1,51 @@
1
+ _target_: source.networks.Warper3DGS
2
+
3
+ verbose: True
4
+ viewpoint_stack: !!null
5
+ sh_degree: 3
6
+
7
+ opt:
8
+ iterations: 30000
9
+ position_lr_init: 0.00016
10
+ position_lr_final: 1.6e-06
11
+ position_lr_delay_mult: 0.01
12
+ position_lr_max_steps: 30000
13
+ feature_lr: 0.0025
14
+ opacity_lr: 0.025
15
+ scaling_lr: 0.005
16
+ rotation_lr: 0.001
17
+ percent_dense: 0.01
18
+ lambda_dssim: 0.2
19
+ densification_interval: 100
20
+ opacity_reset_interval: 30000
21
+ densify_from_iter: 500
22
+ densify_until_iter: 15000
23
+ densify_grad_threshold: 0.0002
24
+ random_background: false
25
+ save_iterations: [3000, 7000, 15000, 30000]
26
+ batch_size: 64
27
+ exposure_lr_init: 0.01
28
+ exposure_lr_final: 0.0001
29
+ exposure_lr_delay_steps: 0
30
+ exposure_lr_delay_mult: 0.0
31
+
32
+ TRAIN_CAM_IDX_TO_LOG: 50
33
+ TEST_CAM_IDX_TO_LOG: 10
34
+
35
+ pipe:
36
+ convert_SHs_python: False
37
+ compute_cov3D_python: False
38
+ debug: False
39
+ antialiasing: False
40
+
41
+ dataset:
42
+ densify_until_iter: 15000
43
+ source_path: '' #path to dataset
44
+ model_path: '' #path to logs
45
+ images: images
46
+ resolution: -1
47
+ white_background: false
48
+ data_device: cuda
49
+ eval: false
50
+ depths: ""
51
+ train_test_exp: False
configs/train.yaml ADDED
@@ -0,0 +1,38 @@
1
+ defaults:
2
+ - gs: base
3
+ - _self_
4
+
5
+ seed: 228
6
+
7
+ wandb:
8
+ mode: "online" # "disabled" for no logging
9
+ entity: "3dcorrespondence"
10
+ project: "Adv3DGS"
11
+ group: null
12
+ name: null
13
+ tag: "debug"
14
+
15
+ train:
16
+ gs_epochs: 0 # number of 3dgs iterations
17
+ reduce_opacity: True
18
+ no_densify: False # if True, the model will not be densified
19
+ max_lr: True
20
+
21
+ load:
22
+ gs: null #path to 3dgs checkpoint
23
+ gs_step: null #number of iterations, e.g. 7000
24
+
25
+ device: "cuda:0"
26
+ verbose: true
27
+
28
+ init_wC:
29
+ use: True # use EDGS
30
+ matches_per_ref: 15_000 # number of matches per reference
31
+ num_refs: 180 # number of reference images
32
+ nns_per_ref: 3 # number of nearest neighbors per reference
33
+ scaling_factor: 0.001
34
+ proj_err_tolerance: 0.01
35
+ roma_model: "outdoors" # you can change this to "indoors" or "outdoors"
36
+ add_SfM_init : False
37
+
38
+
extra/archive/rasterizer_impl.h ADDED
@@ -0,0 +1,75 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact [email protected]
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include <cstdint>
15
+ #include <iostream>
16
+ #include <vector>
17
+ #include "rasterizer.h"
18
+ #include <cuda_runtime_api.h>
19
+
20
+ namespace CudaRasterizer
21
+ {
22
+ template <typename T>
23
+ static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
24
+ {
25
+ std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
26
+ ptr = reinterpret_cast<T*>(offset);
27
+ chunk = reinterpret_cast<char*>(ptr + count);
28
+ }
29
+
30
+ struct GeometryState
31
+ {
32
+ size_t scan_size;
33
+ float* depths;
34
+ char* scanning_space;
35
+ bool* clamped;
36
+ int* internal_radii;
37
+ float2* means2D;
38
+ float* cov3D;
39
+ float4* conic_opacity;
40
+ float* rgb;
41
+ uint32_t* point_offsets;
42
+ uint32_t* tiles_touched;
43
+
44
+ static GeometryState fromChunk(char*& chunk, size_t P);
45
+ };
46
+
47
+ struct ImageState
48
+ {
49
+ uint2* ranges;
50
+ uint32_t* n_contrib;
51
+ float* accum_alpha;
52
+
53
+ static ImageState fromChunk(char*& chunk, size_t N);
54
+ };
55
+
56
+ struct BinningState
57
+ {
58
+ size_t sorting_size;
59
+ uint64_t* point_list_keys_unsorted;
60
+ uint64_t* point_list_keys;
61
+ uint32_t* point_list_unsorted;
62
+ uint32_t* point_list;
63
+ char* list_sorting_space;
64
+
65
+ static BinningState fromChunk(char*& chunk, size_t P);
66
+ };
67
+
68
+ template<typename T>
69
+ size_t required(size_t P)
70
+ {
71
+ char* size = nullptr;
72
+ T::fromChunk(size, P);
73
+ return ((size_t)size) + 128;
74
+ }
75
+ };
extra/archive/simple-knn.patch.txt ADDED
@@ -0,0 +1,13 @@
1
+ diff --git a/simple_knn.cu b/simple_knn.cu
2
+ index e72e4c9..b2deb1b 100644
3
+ --- a/simple_knn.cu
4
+ +++ b/simple_knn.cu
5
+ @@ -11,6 +11,8 @@
6
+
7
+ #define BOX_SIZE 1024
8
+
9
+ +#include <float.h>
10
+ +
11
+ #include "cuda_runtime.h"
12
+ #include "device_launch_parameters.h"
13
+ #include "simple_knn.h"
full_eval.py ADDED
@@ -0,0 +1,102 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import os
13
+ from argparse import ArgumentParser
14
+
15
+ mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
16
+ mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
17
+ tanks_and_temples_scenes = ["truck", "train"]
18
+ deep_blending_scenes = ["drjohnson", "playroom"]
19
+
20
+ parser = ArgumentParser(description="Full evaluation script parameters")
21
+ parser.add_argument("--skip_training", action="store_true")
22
+ parser.add_argument("--skip_rendering", action="store_true")
23
+ parser.add_argument("--skip_metrics", action="store_true")
24
+ parser.add_argument("--output_path", default="./eval")
25
+ args, _ = parser.parse_known_args()
26
+
27
+ all_scenes = []
28
+ all_scenes.extend(mipnerf360_outdoor_scenes)
29
+ all_scenes.extend(mipnerf360_indoor_scenes)
30
+ all_scenes.extend(tanks_and_temples_scenes)
31
+ all_scenes.extend(deep_blending_scenes)
32
+
33
+ if not args.skip_training or not args.skip_rendering:
34
+ parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
35
+ parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
36
+ parser.add_argument("--deepblending", "-db", required=True, type=str)
37
+ args = parser.parse_args()
38
+
39
+ if not args.skip_training:
40
+ name = "EDGS_"
41
+ common_args = " --quiet --eval --test_iterations -1 "
42
+ for scene in mipnerf360_outdoor_scenes:
43
+ source = args.mipnerf360 + "/" + scene
44
+ experiment = name + scene
45
+ os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_4 init_wC.num_refs=180 train.no_densify=True")
46
+ for scene in mipnerf360_indoor_scenes:
47
+ source = args.mipnerf360 + "/" + scene
48
+ experiment = name + scene
49
+ os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_2 init_wC.num_refs=180 train.no_densify=True")
50
+ for scene in tanks_and_temples_scenes:
51
+ source = args.tanksandtemples + "/" + scene
52
+ experiment = name + scene +"_tandt"
53
+ os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True")
54
+ for scene in deep_blending_scenes:
55
+ source = args.deepblending + "/" + scene
56
+ experiment = name + scene + "_db"
57
+ os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True")
58
+
59
+
60
+ if not args.skip_rendering:
61
+ all_sources = []
62
+ for scene in mipnerf360_outdoor_scenes:
63
+ all_sources.append(args.mipnerf360 + "/" + scene)
64
+ for scene in mipnerf360_indoor_scenes:
65
+ all_sources.append(args.mipnerf360 + "/" + scene)
66
+ for scene in tanks_and_temples_scenes:
67
+ all_sources.append(args.tanksandtemples + "/" + scene )
68
+ for scene in deep_blending_scenes:
69
+ all_sources.append(args.deepblending + "/" + scene)
70
+
71
+ all_outputs = []
72
+ for scene in mipnerf360_outdoor_scenes:
73
+ all_outputs.append(args.output_path + "/mipnerf/" + scene)
74
+ for scene in mipnerf360_indoor_scenes:
75
+ all_outputs.append(args.output_path + "/mipnerf/" + scene)
76
+ for scene in tanks_and_temples_scenes:
77
+ all_outputs.append(args.output_path + "/tandt/" + scene)
78
+ for scene in deep_blending_scenes:
79
+ all_outputs.append(args.output_path + "/db/" + scene)
80
+
81
+
82
+ common_args = " --quiet --eval --skip_train"
83
+ for scene, source, output in zip(all_scenes, all_sources, all_outputs):
84
+ os.system("python ./submodules/gaussian-splatting/render.py --iteration 7000 -s " + source + " -m " + output + common_args)
85
+ os.system("python ./submodules/gaussian-splatting/render.py --iteration 30000 -s " + source + " -m " + output + common_args)
86
+
87
+ if not args.skip_metrics:
88
+ all_outputs = []
89
+ for scene in mipnerf360_outdoor_scenes:
90
+ all_outputs.append(args.output_path + "/mipnerf/" + scene)
91
+ for scene in mipnerf360_indoor_scenes:
92
+ all_outputs.append(args.output_path + "/mipnerf/" + scene)
93
+ for scene in tanks_and_temples_scenes:
94
+ all_outputs.append(args.output_path + "/tandt/" + scene)
95
+ for scene in deep_blending_scenes:
96
+ all_outputs.append(args.output_path + "/db/" + scene)
97
+
98
+ scenes_string = ""
99
+ for scene, output in zip(all_scenes, all_outputs):
100
+ scenes_string += "\"" + output + "\" "
101
+
102
+ os.system("python metrics.py -m " + scenes_string)
gradio_demo.py ADDED
@@ -0,0 +1,424 @@
1
+ import torch
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import argparse
6
+ import gradio as gr
7
+ import sys
8
+ import io
9
+ from PIL import Image
10
+ import numpy as np
11
+ from source.utils_aux import set_seed
12
+ from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
13
+ from source.trainer import EDGSTrainer
14
+ from hydra import initialize, compose
15
+ import hydra
16
+ import time
17
+ from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
18
+ import contextlib
19
+ import base64
20
+
21
+
22
+ # Init RoMA model:
23
+ sys.path.append('../submodules/RoMa')
24
+ from romatch import roma_outdoor, roma_indoor
25
+
26
+ roma_model = roma_indoor(device="cuda:0")
27
+ roma_model.upsample_preds = False
28
+ roma_model.symmetric = False
29
+
30
+ STATIC_FILE_SERVING_FOLDER = "./served_files"
31
+ MODEL_PATH = None
32
+ os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)
33
+
34
+ trainer = None
35
+
36
+ class Tee(io.TextIOBase):
37
+ def __init__(self, *streams):
38
+ self.streams = streams
39
+
40
+ def write(self, data):
41
+ for stream in self.streams:
42
+ stream.write(data)
43
+ return len(data)
44
+
45
+ def flush(self):
46
+ for stream in self.streams:
47
+ stream.flush()
48
+
49
+ def capture_logs(func, *args, **kwargs):
50
+ log_capture_string = io.StringIO()
51
+ tee = Tee(sys.__stdout__, log_capture_string)
52
+ with contextlib.redirect_stdout(tee):
53
+ result = func(*args, **kwargs)
54
+ return result, log_capture_string.getvalue()
55
+
56
+ # Training Pipeline
57
+ def run_training_pipeline(scene_dir,
58
+ num_ref_views=16,
59
+ num_corrs_per_view=20000,
60
+ num_steps=1_000,
61
+ mode_toggle="Ours (EDGS)"):
62
+ with initialize(config_path="./configs", version_base="1.1"):
63
+ cfg = compose(config_name="train")
64
+
65
+ scene_name = os.path.basename(scene_dir)
66
+ model_output_dir = f"./outputs/{scene_name}_trained"
67
+
68
+ cfg.wandb.mode = "disabled"
69
+ cfg.gs.dataset.model_path = model_output_dir
70
+ cfg.gs.dataset.source_path = scene_dir
71
+ cfg.gs.dataset.images = "images"
72
+
73
+ cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
74
+ cfg.train.gs_epochs = 30000
75
+
76
+ if mode_toggle=="Ours (EDGS)":
77
+ cfg.gs.opt.opacity_reset_interval = 1_000_000
78
+ cfg.train.reduce_opacity = True
79
+ cfg.train.no_densify = True
80
+ cfg.train.max_lr = True
81
+
82
+ cfg.init_wC.use = True
83
+ cfg.init_wC.matches_per_ref = num_corrs_per_view
84
+ cfg.init_wC.nns_per_ref = 1
85
+ cfg.init_wC.num_refs = num_ref_views
86
+ cfg.init_wC.add_SfM_init = False
87
+ cfg.init_wC.scaling_factor = 0.00077 * 2.
88
+
89
+ set_seed(cfg.seed)
90
+ os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
91
+
92
+ global trainer
93
+ global MODEL_PATH
94
+ generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
95
+ trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled')
96
+
97
+ # Disable evaluation and saving
98
+ trainer.saving_iterations = []
99
+ trainer.evaluate_iterations = []
100
+
101
+ # Initialize
102
+ trainer.timer.start()
103
+ start_time = time.time()
104
+ trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
105
+ time_for_init = time.time()-start_time
106
+
107
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
108
+ path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
109
+ n_selected=6, # 8
110
+ n_points_per_segment=30, # 30
111
+ closed=False)
112
+ path_cameras = path_cameras + path_cameras[::-1]
113
+
114
+ path_renderings = []
115
+ idx = 0
116
+ # Visualize after init
117
+ for _ in range(120):
118
+ with torch.no_grad():
119
+ viewpoint_cam = path_cameras[idx]
120
+ idx = (idx + 1) % len(path_cameras)
121
+ render_pkg = trainer.GS(viewpoint_cam)
122
+ image = render_pkg["render"]
123
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
124
+ image_np = (image_np * 255).astype(np.uint8)
125
+ path_renderings.append(put_text_on_image(img=image_np,
126
+ text=f"Init stage.\nTime:{time_for_init:.3f}s. "))
127
+ path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")]*30
128
+
129
+ # Train and save visualizations during training.
130
+ start_time = time.time()
131
+ for _ in range(int(num_steps//10)):
132
+ with torch.no_grad():
133
+ viewpoint_cam = path_cameras[idx]
134
+ idx = (idx + 1) % len(path_cameras)
135
+ render_pkg = trainer.GS(viewpoint_cam)
136
+ image = render_pkg["render"]
137
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
138
+ image_np = (image_np * 255).astype(np.uint8)
139
+ path_renderings.append(put_text_on_image(
140
+ img=image_np,
141
+ text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. "))
142
+
143
+ cfg.train.gs_epochs = 10
144
+ trainer.train(cfg.train)
145
+ print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.")
146
+ # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60:
147
+ # break
148
+ final_time = time.time()
149
+
150
+ # Add static frame. To highlight we're done
151
+ path_renderings += [put_text_on_image(
152
+ img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s. ")]*30
153
+ # Final rendering at the end.
154
+ for _ in range(len(path_cameras)):
155
+ with torch.no_grad():
156
+ viewpoint_cam = path_cameras[idx]
157
+ idx = (idx + 1) % len(path_cameras)
158
+ render_pkg = trainer.GS(viewpoint_cam)
159
+ image = render_pkg["render"]
160
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
161
+ image_np = (image_np * 255).astype(np.uint8)
162
+ path_renderings.append(put_text_on_image(img=image_np,
163
+ text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s. "))
164
+
165
+ trainer.save_model()
166
+ final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
167
+ save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
168
+ MODEL_PATH = cfg.gs.dataset.model_path
169
+ ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
170
+ shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))
171
+
172
+ return final_video_path, ply_path
173
+
174
+ # Gradio Interface
175
+ def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
176
+ images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
177
+ shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped', dirs_exist_ok=True)
178
+ (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
179
+ scene_dir,
180
+ num_ref_views,
181
+ num_corrs,
182
+ num_steps)
183
+ images_rgb = [img[:, :, ::-1] for img in images]
184
+ return images_rgb, final_video_path, scene_dir, ply_path, log_output
185
+
186
+ # Dummy Render Functions
187
+ def render_all_views(scene_dir):
188
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
189
+ path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
190
+ n_selected=8,
191
+ n_points_per_segment=60,
192
+ closed=False)
193
+ path_cameras = path_cameras + path_cameras[::-1]
194
+
195
+ path_renderings = []
196
+ with torch.no_grad():
197
+ for viewpoint_cam in path_cameras:
198
+ render_pkg = trainer.GS(viewpoint_cam)
199
+ image = render_pkg["render"]
200
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
201
+ image_np = (image_np * 255).astype(np.uint8)
202
+ path_renderings.append(image_np)
203
+ save_numpy_frames_as_mp4(frames=path_renderings,
204
+ output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
205
+ fps=30,
206
+ center_crop=0.85)
207
+
208
+ return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")
209
+
210
+ def render_circular_path(scene_dir):
211
+ viewpoint_cams = trainer.GS.scene.getTrainCameras()
212
+ path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
213
+ N=240,
214
+ radius_scale=0.65,
215
+ d=0)
216
+
217
+ path_renderings = []
218
+ with torch.no_grad():
219
+ for viewpoint_cam in path_cameras:
220
+ render_pkg = trainer.GS(viewpoint_cam)
221
+ image = render_pkg["render"]
222
+ image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
223
+ image_np = (image_np * 255).astype(np.uint8)
224
+ path_renderings.append(image_np)
225
+ save_numpy_frames_as_mp4(frames=path_renderings,
226
+ output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
227
+ fps=30,
228
+ center_crop=0.85)
229
+
230
+ return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")
231
+
232
+ # Download Functions
233
+ def download_cameras():
234
+ path = os.path.join(MODEL_PATH, "cameras.json")
235
+ return f"[📥 Download Cameras.json](file={path})"
236
+
237
+ def download_model():
238
+ path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
239
+ return f"[📥 Download Pretrained Model (.ply)](file={path})"
240
+
241
+ # Full pipeline helpers
242
+ def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
243
+ tmpdirname = tempfile.mkdtemp()
244
+ scene_dir = os.path.join(tmpdirname, "scene")
245
+ os.makedirs(scene_dir, exist_ok=True)
246
+
247
+ selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
248
+ run_colmap_on_scene(scene_dir)
249
+
250
+ return selected_frames, scene_dir
251
+
252
+ # Preprocess Input
253
+ def process_input(input_path, num_ref_views, output_dir, max_size=1024):
254
+ if isinstance(input_path, (str, os.PathLike)):
255
+ if os.path.isdir(input_path):
256
+ frames = []
257
+ for img_file in sorted(os.listdir(input_path)):
258
+ if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
259
+ img = Image.open(os.path.join(input_path, img_file)).convert('RGB')  # read from the input folder, not the output scene dir
260
+ img.thumbnail((max_size, max_size))
261
+ frames.append(np.array(img))
262
+ else:
263
+ frames = read_video_frames(video_input=input_path, max_size=max_size)
264
+ else:
265
+ frames = read_video_frames(video_input=input_path, max_size=max_size)
266
+
267
+ frames_scores = preprocess_frames(frames)
268
+ selected_frames_indices = select_optimal_frames(scores=frames_scores,
269
+ k=min(num_ref_views, len(frames)))
270
+ selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
271
+
272
+ save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
273
+ return selected_frames
274
+
275
+ def preprocess_input(input_path, num_ref_views, max_size=1024):
276
+ tmpdirname = tempfile.mkdtemp()
277
+ scene_dir = os.path.join(tmpdirname, "scene")
278
+ os.makedirs(scene_dir, exist_ok=True)
279
+ selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
280
+ run_colmap_on_scene(scene_dir)
281
+ return selected_frames, scene_dir
282
+
283
+ def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
284
+ return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)
285
+
286
+
287
+ # Gradio App
288
+ with gr.Blocks() as demo:
289
+ with gr.Row():
290
+ with gr.Column(scale=6):
291
+ gr.Markdown("""
292
+ ## <span style='font-size: 20px;'>📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
293
+ 🔗 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
294
+ """, elem_id="header")
295
+
296
+ gr.Markdown("""
297
+ ### <span style='font-size: 22px;'>🛠️ How to Use This Demo</span>
298
+
299
+ 1. Upload a **front-facing video** or **a folder of images** of a **static** scene.
300
+ 2. Use the sliders to configure the number of reference views, correspondences, and optimization steps.
301
+ 3. First click **📸 Preprocess Input** to extract frames from the video (for video input) and run COLMAP.
302
+ 4. Then click **🚀 Start Reconstruction** to actually launch the reconstruction pipeline.
303
+ 5. Watch the training visualization and explore the 3D model.
304
+ ‼️ **If you see nothing in the 3D model viewer**, try rotating or zooming — sometimes the initial camera orientation is off.
305
+
306
+
307
+ ✅ Best for scenes with small camera motion.
308
+ ❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page).
309
+ """, elem_id="quickstart")
310
+
311
+
312
+ scene_dir_state = gr.State()
313
+ ply_model_state = gr.State()
314
+
315
+ with gr.Row():
316
+ with gr.Column(scale=2):
317
+ input_file = gr.File(label="Upload Video or Images",
318
+ file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
319
+ file_count="multiple")
320
+ gr.Examples(
321
+ examples = [
322
+ [["assets/examples/video_bakery.mp4"]],
323
+ [["assets/examples/video_flowers.mp4"]],
324
+ [["assets/examples/video_fruits.mp4"]],
325
+ [["assets/examples/video_plant.mp4"]],
326
+ [["assets/examples/video_salad.mp4"]],
327
+ [["assets/examples/video_tram.mp4"]],
328
+ [["assets/examples/video_tulips.mp4"]]
329
+ ],
330
+ inputs=[input_file],
331
+ label="🎞️ Alternatively, try an Example Video",
332
+ examples_per_page=4
333
+ )
334
+ ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
335
+ corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
336
+ fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
337
+ preprocess_button = gr.Button("📸 Preprocess Input")
338
+ start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
339
+ gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)
340
+
341
+ with gr.Column(scale=3):
342
+ gr.Markdown("### 🏋️ Training Visualization")
343
+ video_output = gr.Video(label="Training Video", autoplay=True)
344
+ render_all_views_button = gr.Button("🎥 Render All-Views Path")
345
+ render_circular_path_button = gr.Button("🎥 Render Circular Path")
346
+ rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)
347
+ with gr.Column(scale=5):
348
+ gr.Markdown("### 🌐 Final 3D Model")
349
+ model3d_viewer = gr.Model3D(label="3D Model Viewer")
350
+
351
+ gr.Markdown("### 📦 Output Files")
352
+ with gr.Row(height=50):
353
+ with gr.Column():
354
+ #gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
355
+ download_cameras_button = gr.Button("📥 Download Cameras.json")
356
+ download_cameras_file = gr.File(label="📄 Cameras.json")
357
+ with gr.Column():
358
+ download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
359
+ download_model_file = gr.File(label="📄 Pretrained Model (.ply)")
360
+
361
+ log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)
362
+
363
+ def on_preprocess_click(input_file, num_ref_views):
364
+ images, scene_dir = preprocess_input(input_file, num_ref_views)
365
+ return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True)
366
+
367
+ def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
368
+ (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
369
+ return video_path, ply_path, logs
370
+
371
+ preprocess_button.click(
372
+ fn=on_preprocess_click,
373
+ inputs=[input_file, ref_slider],
374
+ outputs=[gallery, scene_dir_state, start_button]
375
+ )
376
+
377
+ start_button.click(
378
+ fn=on_start_click,
379
+ inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
380
+ outputs=[video_output, model3d_viewer, log_output_box]
381
+ )
382
+
383
+ render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
384
+ render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])
385
+
386
+ download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
387
+ download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])
388
+
389
+
390
+ gr.Markdown("""
391
+ ---
392
+ ### <span style='font-size: 20px;'>📖 Detailed Overview</span>
393
+
394
+ If you uploaded a video, it will automatically be reduced to a smaller number of frames (default: 16).
395
+
396
+ The model pipeline:
397
+ 1. 🧠 Runs PyCOLMAP to estimate camera intrinsics & poses (~3–7 seconds for <16 images).
398
+ 2. 🔁 Computes 2D-2D correspondences between views. More correspondences generally improve quality.
399
+ 3. 🔧 Optimizes a 3D Gaussian Splatting model for several steps.
400
+
401
+ ### 🎥 Training Visualization
402
+ You will see a visualization of the entire training process in the "Training Video" pane.
403
+
404
+ ### 🌀 Rendering & 3D Model
405
+ - Render the scene from a circular path of novel views.
406
+ - Or from camera views close to the original input.
407
+
408
+ The 3D model is shown in the right viewer. You can explore it interactively:
409
+ - On PC: WASD keys, arrow keys, and mouse clicks
410
+ - On mobile: pan and pinch to zoom
411
+
412
+ 🕒 Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends.
413
+
414
+ ---
415
+ Preloaded models coming soon. (TODO)
416
+ """, elem_id="details")
417
+
418
+ if __name__ == "__main__":
419
+ parser = argparse.ArgumentParser(description="Launch Gradio demo for EDGS preprocessing and 3D viewing.")
420
+ parser.add_argument("--port", type=int, default=7860, help="Port to launch the Gradio app on.")
421
+ parser.add_argument("--no_share", action='store_true', help="Disable Gradio sharing and assume local access (default: share=True)")
422
+ args = parser.parse_args()
423
+
424
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=not args.no_share)
install.sh ADDED
@@ -0,0 +1,35 @@
1
+ #!/bin/bash
2
+ git clone [email protected]:CompVis/EDGS.git --recursive
3
+ cd EDGS
4
+ git submodule update --init --recursive
5
+
6
+ conda create -y -n edgs python=3.10 pip
7
+ conda activate edgs
8
+
9
+ # Optionally set path to CUDA
10
+ export CUDA_HOME=/usr/local/cuda-12.1
11
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
12
+ export PATH=$CUDA_HOME/bin:$PATH
13
+ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
14
+ conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y
15
+
16
+ pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization
17
+ pip install -e submodules/gaussian-splatting/submodules/simple-knn
18
+
19
+ # For COLMAP and pycolmap
20
+ # Optionally install original colmap but probably pycolmap suffices
21
+ # conda install conda-forge/label/colmap_dev::colmap
22
+ pip install pycolmap
23
+
24
+
25
+ pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg
26
+ conda install numpy=1.26.4 -y -c conda-forge --override-channels
27
+
28
+ pip install -e submodules/RoMa
29
+ conda install anaconda::jupyter --yes
30
+
31
+ # Stuff necessary for gradio and visualizations
32
+ pip install gradio
33
+ pip install plotly scikit-learn moviepy==2.1.1 ffmpeg
34
+ pip install open3d
35
+
main.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import uuid
6
+ import asyncio
7
+ import io
8
+ import time
9
+ import contextlib
10
+ import base64
11
+ from PIL import Image
12
+ import numpy as np
13
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Body
14
+ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
15
+ from pydantic import BaseModel, Field
16
+
17
+ try:
18
+ from source.utils_aux import set_seed
19
+ from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
20
+ from source.trainer import EDGSTrainer
21
+ from hydra import initialize, compose
22
+ import hydra
23
+ from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image
24
+ import sys
25
+ sys.path.append('../submodules/RoMa') # Adjust this path if necessary
26
+ from romatch import roma_indoor
27
+ except ImportError as e:
28
+ print(f"Error: could not import the EDGS project modules. Make sure the paths and the installation are correct. {e}")
29
+ sys.exit(1)
30
+
31
+ # --- Initial Setup ---
32
+ # 1. FastAPI app initialization
33
+ app = FastAPI(
34
+ title="EDGS Training API",
35
+ description="An API to preprocess videos and train 3DGS models with EDGS.",
36
+ version="1.0.0"
37
+ )
38
+
39
+ # 2. Global variables and state storage
40
+ # The model is loaded in the 'startup' event
41
+ roma_model = None
42
+
43
+ # In-memory store used to track task state across endpoints
44
+ tasks_db = {}
45
+
46
+ # 3. Pydantic models for data validation
47
+ class TrainParams(BaseModel):
48
+ num_corrs_per_view: int = Field(20000, gt=0, description="Correspondences per reference view.")
49
+ num_steps: int = Field(1000, gt=0, description="Number of optimization steps.")
50
+
51
+ class PreprocessResponse(BaseModel):
52
+ task_id: str
53
+ message: str
54
+ selected_frames_count: int
55
+ # Optional: you could return the images as base64 if the client needs to display them
56
+ # frames: list[str]
57
+
58
+ # --- Business Logic (adapted from the Gradio script) ---
59
+
60
+ # This function runs in a separate thread so it does not block the server
61
+ def run_preprocessing_sync(input_path: str, num_ref_views: int):
62
+ """
63
+ Runs preprocessing: frame selection and COLMAP execution.
64
+ """
65
+ tmpdirname = tempfile.mkdtemp()
66
+ scene_dir = os.path.join(tmpdirname, "scene")
67
+ os.makedirs(scene_dir, exist_ok=True)
68
+
69
+ # 1. Read the video and select the best frames
70
+ frames = read_video_frames(video_input=input_path, max_size=1024)
71
+ frames_scores = preprocess_frames(frames)
72
+ selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
73
+ selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
74
+
75
+ # 2. Save the frames and run COLMAP
76
+ save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
77
+ run_colmap_on_scene(scene_dir)
78
+
79
+ return scene_dir, selected_frames
80
+
81
+ async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
82
+ """
83
+ Un generador asíncrono que ejecuta el entrenamiento. Los logs detallados se muestran
84
+ en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
85
+ """
86
+ def training_pipeline():
87
+ try:
88
+ # Hydra initialization and configuration stay the same
89
+ with initialize(config_path="./configs", version_base="1.1"):
90
+ cfg = compose(config_name="train")
91
+
92
+ # --- FULL CONFIGURATION ---
93
+ scene_name = os.path.basename(scene_dir)
94
+ model_output_dir = f"./outputs/{scene_name}_trained"
95
+ cfg.wandb.mode = "disabled"
96
+ cfg.gs.dataset.model_path = model_output_dir
97
+ cfg.gs.dataset.source_path = scene_dir
98
+ cfg.gs.dataset.images = "images"
99
+ cfg.train.gs_epochs = 30000
100
+ cfg.gs.opt.opacity_reset_interval = 1_000_000
101
+ cfg.train.reduce_opacity = True
102
+ cfg.train.no_densify = True
103
+ cfg.train.max_lr = True
104
+ cfg.init_wC.use = True
105
+ cfg.init_wC.matches_per_ref = params.num_corrs_per_view
106
+ cfg.init_wC.nns_per_ref = 1
107
+ cfg.init_wC.num_refs = num_ref_views
108
+ cfg.init_wC.add_SfM_init = False
109
+ cfg.init_wC.scaling_factor = 0.00077 * 2.
110
+
111
+ set_seed(cfg.seed)
112
+ os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
113
+
114
+ device = cfg.device
115
+ generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
116
+ trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
117
+ trainer.saving_iterations = []
118
+ trainer.evaluate_iterations = []
119
+ trainer.timer.start()
120
+
121
+ # Progress message for the client before initialization
122
+ yield "data: Initializing model...\n\n"
123
+ trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
124
+
125
+ # The main training loop
126
+ for step in range(int(params.num_steps // 10)):
127
+ cfg.train.gs_epochs = 10
128
+ # trainer.train() now prints its detailed logs directly to the terminal
129
+ trainer.train(cfg.train)
130
+
131
+ # --- KEY CHANGE ---
132
+ # Send a simple progress message to the client instead of the captured logs.
133
+ yield f"data: Progress: {step*10+10}/{params.num_steps} steps completed.\n\n"
134
+
135
+ trainer.save_model()
136
+ ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
137
+
138
+ tasks_db[task_id]['result_ply_path'] = ply_path
139
+
140
+ final_message = "Training completed. The model is ready to download."
141
+ yield f"data: {final_message}\n\n"
142
+
143
+ except Exception as e:
144
+ yield f"data: ERROR: {repr(e)}\n\n"
145
+
146
+ # The loop that drives the pipeline stays the same
147
+ training_gen = training_pipeline()
148
+ for log_message in training_gen:
149
+ yield log_message
150
+ await asyncio.sleep(0.1)
151
+
152
+ # --- App Lifecycle Events ---
153
+
154
+ @app.on_event("startup")
155
+ async def startup_event():
156
+ """
157
+ Loads the RoMa model when the server starts.
158
+ """
159
+ global roma_model
160
+ print("🚀 Starting FastAPI server...")
161
+ if torch.cuda.is_available():
162
+ device = "cuda:0"
163
+ print("✅ GPU detected. Using CUDA.")
164
+ else:
165
+ device = "cpu"
166
+ print("⚠️ No GPU detected. Using CPU (this can be very slow).")
167
+
168
+ roma_model = roma_indoor(device=device)
169
+ roma_model.upsample_preds = False
170
+ roma_model.symmetric = False
171
+ print("🤖 RoMa model loaded and ready.")
172
+
173
+ # --- API Endpoints ---
174
+
175
+ @app.post("/preprocess", response_model=PreprocessResponse)
176
+ async def preprocess_video(
177
+ num_ref_views: int = Body(16, embed=True, description="Number of reference views to extract from the video."),
178
+ video: UploadFile = File(..., description="Video file to process (.mp4, .mov).")
179
+ ):
180
+ """
181
+ Receives a video, preprocesses it (frame extraction + COLMAP), and prepares it for training.
182
+ """
183
+ if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
184
+ raise HTTPException(status_code=400, detail="Unsupported file format. Use .mp4, .avi, or .mov.")
185
+
186
+ # Save the video to a temporary file so the library can process it
187
+ with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
188
+ shutil.copyfileobj(video.file, tmp_video)
189
+ tmp_video_path = tmp_video.name
190
+
191
+ try:
192
+ loop = asyncio.get_running_loop()
193
+ # Run the synchronous, blocking function in an executor so the server is not blocked
194
+ scene_dir, selected_frames = await loop.run_in_executor(
195
+ None, run_preprocessing_sync, tmp_video_path, num_ref_views
196
+ )
197
+
198
+ # Generate a unique ID for this task and store the scene path
199
+ task_id = str(uuid.uuid4())
200
+ tasks_db[task_id] = {
201
+ "scene_dir": scene_dir,
202
+ "num_ref_views": len(selected_frames),
203
+ "result_ply_path": None
204
+ }
205
+
206
+ return JSONResponse(
207
+ status_code=200,
208
+ content={
209
+ "task_id": task_id,
210
+ "message": f"Preprocesamiento completado. Se generó el directorio de la escena. Listo para entrenar.",
211
+ "selected_frames_count": len(selected_frames)
212
+ }
213
+ )
214
+ except Exception as e:
215
+ raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
216
+ finally:
217
+ os.unlink(tmp_video_path) # Clean up the temporary video file
218
+
219
+
220
+ @app.post("/train/{task_id}")
221
+ async def train_model(task_id: str, params: TrainParams):
222
+ """
223
+ Starts training for a preprocessed task.
224
+ Returns a stream of real-time logs.
225
+ """
226
+ if task_id not in tasks_db:
227
+ raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
228
+
229
+ task_info = tasks_db[task_id]
230
+ scene_dir = task_info["scene_dir"]
231
+ num_ref_views = task_info["num_ref_views"]
232
+
233
+ return StreamingResponse(
234
+ training_log_generator(scene_dir, num_ref_views, params, task_id),
235
+ media_type="text/event-stream"
236
+ )
237
+
238
+ @app.get("/download/{task_id}")
239
+ async def download_ply_file(task_id: str):
240
+ """
241
+ Allows downloading the .ply file produced by a completed training run.
242
+ """
243
+ if task_id not in tasks_db:
244
+ raise HTTPException(status_code=404, detail="Task ID no encontrado.")
245
+
246
+ task_info = tasks_db[task_id]
247
+ ply_path = task_info.get("result_ply_path")
248
+
249
+ if not ply_path:
250
+ raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo aún no está disponible.")
251
+
252
+ if not os.path.exists(ply_path):
253
+ raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")
254
+
255
+ # Generate a user-friendly file name
256
+ file_name = f"model_{task_id[:8]}.ply"
257
+
258
+ return FileResponse(
259
+ path=ply_path,
260
+ media_type='application/octet-stream',
261
+ filename=file_name
262
+ )
263
+
264
+ if __name__ == "__main__":
265
+ import uvicorn
266
+ # To run: uvicorn main:app --reload
267
+ # The --reload flag is for development. Remove it in production.
268
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
metrics.py ADDED
@@ -0,0 +1,115 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ from pathlib import Path
13
+ import os
14
+ import sys
15
+ from PIL import Image
16
+ import torch
17
+ import torchvision.transforms.functional as tf
18
+ sys.path.append('./submodules/gaussian-splatting/')
19
+ from utils.loss_utils import ssim
20
+ from lpipsPyTorch import lpips as lpips_3dgs
21
+ import json
22
+ from tqdm import tqdm
23
+ from utils.image_utils import psnr
24
+ from argparse import ArgumentParser
25
+
26
+ import lpips
27
+
28
+ def readImages(renders_dir, gt_dir):
29
+ renders = []
30
+ gts = []
31
+ image_names = []
32
+ for fname in os.listdir(renders_dir):
33
+ render = Image.open(renders_dir / fname)
34
+ gt = Image.open(gt_dir / fname)
35
+ renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
36
+ gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
37
+ image_names.append(fname)
38
+ return renders, gts, image_names
39
+
40
+ def evaluate(model_paths):
41
+
42
+ full_dict = {}
43
+ per_view_dict = {}
44
+ full_dict_polytopeonly = {}
45
+ per_view_dict_polytopeonly = {}
46
+ print("")
47
+
48
+ for scene_dir in model_paths:
49
+ #try:
50
+ print("Scene:", scene_dir)
51
+ full_dict[scene_dir] = {}
52
+ per_view_dict[scene_dir] = {}
53
+ full_dict_polytopeonly[scene_dir] = {}
54
+ per_view_dict_polytopeonly[scene_dir] = {}
55
+
56
+ test_dir = Path(scene_dir) / "test"
57
+
58
+ for method in os.listdir(test_dir):
59
+ print("Method:", method)
60
+
61
+ full_dict[scene_dir][method] = {}
62
+ per_view_dict[scene_dir][method] = {}
63
+ full_dict_polytopeonly[scene_dir][method] = {}
64
+ per_view_dict_polytopeonly[scene_dir][method] = {}
65
+
66
+ method_dir = test_dir / method
67
+ gt_dir = method_dir/ "gt"
68
+ renders_dir = method_dir / "renders"
69
+ renders, gts, image_names = readImages(renders_dir, gt_dir)
70
+
71
+ ssims = []
72
+ psnrs = []
73
+ lpipss = []
74
+ lpipss_3dgs = []
75
+ with torch.no_grad():
76
+ for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
77
+ ssims.append(ssim(renders[idx], gts[idx]))
78
+ psnrs.append(psnr(renders[idx], gts[idx]))
79
+ lpipss.append(lpips_fn(renders[idx], gts[idx]))
80
+ lpipss_3dgs.append(lpips_3dgs(renders[idx], gts[idx], net_type='vgg'))
81
+ torch.cuda.empty_cache()
82
+
83
+ print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
84
+ print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
85
+ print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
86
+ print(" LPIPS_3dgs: {:>12.7f}".format(torch.tensor(lpipss_3dgs).mean(), ".5"))
87
+ print("")
88
+
89
+ full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
90
+ "PSNR": torch.tensor(psnrs).mean().item(),
91
+ "LPIPS": torch.tensor(lpipss).mean().item(),
92
+ "LPIPS_3dgs": torch.tensor(lpipss_3dgs).mean().item(),
93
+ })
94
+ per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
95
+ "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
96
+ "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)},
97
+ "LPIPS_3dgs": {name: lp for lp, name in zip(torch.tensor(lpipss_3dgs).tolist(), image_names)},
98
+ })
99
+
100
+ with open(scene_dir + "/results.json", 'w') as fp:
101
+ json.dump(full_dict[scene_dir], fp, indent=True)
102
+ with open(scene_dir + "/per_view.json", 'w') as fp:
103
+ json.dump(per_view_dict[scene_dir], fp, indent=True)
104
+ #except:
105
+ # print("Unable to compute metrics for model", scene_dir)
106
+
107
+ if __name__ == "__main__":
108
+ device = torch.device("cuda:0")
109
+ torch.cuda.set_device(device)
110
+ lpips_fn = lpips.LPIPS(net='vgg').to(device)
111
+ # Set up command line argument parser
112
+ parser = ArgumentParser(description="Training script parameters")
113
+ parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
114
+ args = parser.parse_args()
115
+ evaluate(args.model_paths)
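For reference, the script is invoked as `python metrics.py -m <model_path> [<model_path> ...]` and writes `results.json` / `per_view.json` into each scene directory. A small sketch of reading the averaged metrics back (the path below is a placeholder):

```python
# Sketch: read back the averaged metrics written by metrics.py for one scene.
import json

scene_dir = "output/scene"  # placeholder; use the same path passed via --model_paths
with open(scene_dir + "/results.json") as fp:
    results = json.load(fp)

for method, scores in results.items():
    print(f"{method}: PSNR={scores['PSNR']:.3f}  SSIM={scores['SSIM']:.4f}  "
          f"LPIPS={scores['LPIPS']:.4f}  LPIPS_3dgs={scores['LPIPS_3dgs']:.4f}")
```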
requirements.txt ADDED
@@ -0,0 +1 @@
1
+ fastapi
script.bash ADDED
File without changes
source/EDGS.code-workspace ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "folders": [
3
+ {
4
+ "path": ".."
5
+ },
6
+ {
7
+ "path": "../../../../.."
8
+ }
9
+ ],
10
+ "settings": {}
11
+ }
source/__init__.py ADDED
File without changes
source/corr_init.py ADDED
@@ -0,0 +1,907 @@
1
+ import sys
2
+ sys.path.append('../')
3
+ sys.path.append("../submodules")
4
+ sys.path.append('../submodules/RoMa')
5
+
6
+ from matplotlib import pyplot as plt
7
+ from PIL import Image
8
+ import torch
9
+ import numpy as np
10
+
11
+ #from tqdm import tqdm_notebook as tqdm
12
+ from tqdm import tqdm
13
+ from scipy.cluster.vq import kmeans, vq
14
+ from scipy.spatial.distance import cdist
15
+
16
+ import torch.nn.functional as F
17
+ from romatch import roma_outdoor, roma_indoor
18
+ from utils.sh_utils import RGB2SH
19
+ from romatch.utils import get_tuple_transform_ops
20
+
21
+ import time
22
+ from collections import defaultdict
23
+ from tqdm import tqdm
24
+
25
+
26
+ def pairwise_distances(matrix):
27
+ """
28
+ Computes the pairwise Euclidean distances between all vectors in the input matrix.
29
+
30
+ Args:
31
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
32
+
33
+ Returns:
34
+ torch.Tensor: Pairwise distance matrix of shape [N, N].
35
+ """
36
+ # Compute pairwise Euclidean distances (torch.cdist with p=2 returns distances, not squared ones)
37
+ distances = torch.cdist(matrix, matrix, p=2)
38
+ return distances
39
+
40
+
41
+ def k_closest_vectors(matrix, k):
42
+ """
43
+ Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
44
+
45
+ Args:
46
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
47
+ k (int): Number of closest vectors to return for each vector.
48
+
49
+ Returns:
50
+ torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
51
+ """
52
+ # Compute pairwise distances
53
+ distances = pairwise_distances(matrix)
54
+
55
+ # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
56
+ # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
57
+ distances.fill_diagonal_(float('inf'))
58
+
59
+ # Get the indices of the k smallest distances (k-closest vectors)
60
+ _, indices = torch.topk(distances, k, largest=False, dim=1)
61
+
62
+ return indices
63
+
64
+
65
+ def select_cameras_kmeans(cameras, K):
66
+ """
67
+ Selects K cameras from a set using K-means clustering.
68
+
69
+ Args:
70
+ cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
71
+ K: Number of clusters (cameras to select).
72
+
73
+ Returns:
74
+ selected_indices: List of indices of the cameras closest to the cluster centers.
75
+ """
76
+ # Ensure input is a NumPy array
77
+ if not isinstance(cameras, np.ndarray):
78
+ cameras = np.asarray(cameras)
79
+
80
+ if cameras.shape[1] != 16:
81
+ raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")
82
+
83
+ # Perform K-means clustering
84
+ cluster_centers, _ = kmeans(cameras, K)
85
+
86
+ # Assign each camera to a cluster and find distances to cluster centers
87
+ cluster_assignments, _ = vq(cameras, cluster_centers)
88
+
89
+ # Find the camera nearest to each cluster center
90
+ selected_indices = []
91
+ for k in range(K):
92
+ cluster_members = cameras[cluster_assignments == k]
93
+ distances = cdist([cluster_centers[k]], cluster_members)[0]
94
+ nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
95
+ selected_indices.append(nearest_camera_idx)
96
+
97
+ return selected_indices
98
+
99
+
100
+ def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
101
+ """
102
+ Computes the warp and confidence between two viewpoint cameras using the roma_model.
103
+
104
+ Args:
105
+ viewpoint_cam1: Source viewpoint camera.
106
+ viewpoint_cam2: Target viewpoint camera.
107
+ roma_model: Pre-trained Roma model for correspondence matching.
108
+ device: Device to run the computation on.
109
+ verbose: If True, displays the images.
110
+
111
+ Returns:
112
+ certainty: Confidence tensor.
113
+ warp: Warp tensor.
114
+ imB: Processed image B as numpy array.
115
+ """
116
+ # Prepare images
117
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
118
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
119
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
120
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
121
+
122
+ if verbose:
123
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
124
+ cax1 = ax[0].imshow(imA)
125
+ ax[0].set_title("Image 1")
126
+ cax2 = ax[1].imshow(imB)
127
+ ax[1].set_title("Image 2")
128
+ fig.colorbar(cax1, ax=ax[0])
129
+ fig.colorbar(cax2, ax=ax[1])
130
+
131
+ for axis in ax:
132
+ axis.axis('off')
133
+ # Save the figure into the dictionary
134
+ output_dict[f'image_pair'] = fig
135
+
136
+ # Transform images
137
+ ws, hs = roma_model.w_resized, roma_model.h_resized
138
+ test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
139
+ im_A, im_B = test_transform((imA, imB))
140
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
141
+
142
+ # Forward pass through Roma model
143
+ corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
144
+ finest_scale = 1
145
+ hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)
146
+
147
+ # Process certainty and warp
148
+ certainty = corresps[finest_scale]["certainty"]
149
+ im_A_to_im_B = corresps[finest_scale]["flow"]
150
+ if roma_model.attenuate_cert:
151
+ low_res_certainty = F.interpolate(
152
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
153
+ )
154
+ certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)
155
+
156
+ # Upsample predictions if needed
157
+ if roma_model.upsample_preds:
158
+ im_A_to_im_B = F.interpolate(
159
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
160
+ )
161
+ certainty = F.interpolate(
162
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
163
+ )
164
+
165
+ # Convert predictions to final format
166
+ im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
167
+ im_A_coords = torch.stack(torch.meshgrid(
168
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
169
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
170
+ indexing='ij'
171
+ ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
172
+
173
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
174
+ certainty = certainty.sigmoid()
175
+
176
+ return certainty[0, 0], warp[0], np.array(imB)
177
+
178
+
179
+ def resize_batch(tensors_3d, tensors_4d, target_shape):
180
+ """
181
+ Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.
182
+
183
+ Args:
184
+ tensors_3d: Tensor of shape [B, H, W].
185
+ tensors_4d: Tensor of shape [B, H, W, 4].
186
+ target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.
187
+
188
+ Returns:
189
+ resized_tensors_3d: Tensor of shape [B, target_H, target_W].
190
+ resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
191
+ """
192
+ target_H, target_W = target_shape
193
+
194
+ # Resize [B, H, W] tensor
195
+ resized_tensors_3d = F.interpolate(
196
+ tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
197
+ ).squeeze(1)
198
+
199
+ # Resize [B, H, W, 4] tensor
200
+ B, _, _, C = tensors_4d.shape
201
+ resized_tensors_4d = F.interpolate(
202
+ tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
203
+ ).permute(0, 2, 3, 1)
204
+
205
+ return resized_tensors_3d, resized_tensors_4d
206
+
207
+
208
+ def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
209
+ """
210
+ Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.
211
+
212
+ Args:
213
+ viewpoint_stack: Stack of viewpoint cameras.
214
+ closest_indices: Indices of the nearest neighbors for each viewpoint.
215
+ roma_model: Pre-trained Roma model.
216
+ source_idx: Index of the source viewpoint.
217
+ verbose: If True, displays intermediate results.
218
+
219
+ Returns:
220
+ certainties_max: Aggregated maximum confidences.
221
+ warps_max: Aggregated warps corresponding to maximum confidences.
222
+ certainties_max_idcs: Pixel-wise index of the image from which we take the best match.
223
+ imB_compound: List of the neighboring images.
224
+ """
225
+ certainties_all, warps_all, imB_compound = [], [], []
226
+
227
+ for nn in tqdm(closest_indices[source_idx]):
228
+
229
+ viewpoint_cam1 = viewpoint_stack[source_idx]
230
+ viewpoint_cam2 = viewpoint_stack[nn]
231
+
232
+ certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
233
+ certainties_all.append(certainty)
234
+ warps_all.append(warp)
235
+ imB_compound.append(imB)
236
+
237
+ certainties_all = torch.stack(certainties_all, dim=0)
238
+ target_shape = imB_compound[0].shape[:2]
239
+ if verbose:
240
+ print("certainties_all.shape:", certainties_all.shape)
241
+ print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
242
+ print("target_shape:", target_shape)
243
+
244
+ certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
245
+ torch.stack(warps_all, dim=0),
246
+ target_shape
247
+ )
248
+
249
+ if verbose:
250
+ print("warps_all_resized.shape:", warps_all_resized.shape)
251
+ for n, cert in enumerate(certainties_all):
252
+ fig, ax = plt.subplots()
253
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
254
+ fig.colorbar(cax, ax=ax)
255
+ ax.set_title("Pixel-wise Confidence")
256
+ output_dict[f'certainty_{n}'] = fig
257
+
258
+ for n, warp in enumerate(warps_all):
259
+ fig, ax = plt.subplots()
260
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
261
+ fig.colorbar(cax, ax=ax)
262
+ ax.set_title("Pixel-wise warp")
263
+ output_dict[f'warp_resized_{n}'] = fig
264
+
265
+ for n, cert in enumerate(certainties_all_resized):
266
+ fig, ax = plt.subplots()
267
+ cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
268
+ fig.colorbar(cax, ax=ax)
269
+ ax.set_title("Pixel-wise Confidence resized")
270
+ output_dict[f'certainty_resized_{n}'] = fig
271
+
272
+ for n, warp in enumerate(warps_all_resized):
273
+ fig, ax = plt.subplots()
274
+ cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
275
+ fig.colorbar(cax, ax=ax)
276
+ ax.set_title("Pixel-wise warp resized")
277
+ output_dict[f'warp_resized_{n}'] = fig
278
+
279
+ certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
280
+ H, W = certainties_max.shape
281
+
282
+ warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]
283
+
284
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
285
+ imA = np.clip(imA * 255, 0, 255).astype(np.uint8)
286
+
287
+ return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized
288
+
289
+
290
+
291
+ def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
292
+ verbose=False, output_dict={}):
293
+ """
294
+ Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).
295
+
296
+ Args:
297
+ imA: Source image as a NumPy array (H_A, W_A, C).
298
+ imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
299
+ certainties_max: Tensor of pixel-wise maximum confidences.
300
+ certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
301
+ matches: Matches in normalized coordinates.
302
+ roma_model: Roma model instance for keypoint operations.
303
+ verbose: whether to show intermediate outputs and visualize results
304
+
305
+ Returns:
306
+ kptsA_np: Keypoints in imA in normalized coordinates.
307
+ kptsB_np: Keypoints in imB in normalized coordinates.
308
+ kptsA_color: Colors of keypoints in imA.
309
+ kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
310
+ """
311
+ H_A, W_A, _ = imA.shape
312
+ H, W = certainties_max.shape
313
+
314
+ # Convert matches to pixel coordinates
315
+ kptsA, kptsB = roma_model.to_pixel_coordinates(
316
+ matches, W_A, H_A, H, W # W, H
317
+ )
318
+
319
+ kptsA_np = kptsA.detach().cpu().numpy()
320
+ kptsB_np = kptsB.detach().cpu().numpy()
321
+ kptsA_np = kptsA_np[:, [1, 0]]
322
+
323
+ if verbose:
324
+ fig, ax = plt.subplots(figsize=(12, 6))
325
+ cax = ax.imshow(imA)
326
+ ax.set_title("Reference image, imA")
327
+ output_dict[f'reference_image'] = fig
328
+
329
+ fig, ax = plt.subplots(figsize=(12, 6))
330
+ cax = ax.imshow(imB_compound[0])
331
+ ax.set_title("Image to compare to image, imB_compound")
332
+ output_dict[f'imB_compound'] = fig
333
+
334
+ fig, ax = plt.subplots(figsize=(12, 6))
335
+ cax = ax.imshow(np.flipud(imA))
336
+ cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
337
+ ax.set_title("Keypoints in imA")
338
+ ax.set_xlim(0, W_A)
339
+ ax.set_ylim(0, H_A)
340
+ output_dict[f'kptsA'] = fig
341
+
342
+ fig, ax = plt.subplots(figsize=(12, 6))
343
+ cax = ax.imshow(np.flipud(imB_compound[0]))
344
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
345
+ ax.set_title("Keypoints in imB")
346
+ ax.set_xlim(0, W_A)
347
+ ax.set_ylim(0, H_A)
348
+ output_dict[f'kptsB'] = fig
349
+
350
+ # Keypoints are in (row, column) format, so the first value is always in range [0, height] and the second in range [0, width]
351
+
352
+ kptsA_np = kptsA.detach().cpu().numpy()
353
+ kptsB_np = kptsB.detach().cpu().numpy()
354
+
355
+ # Extract colors for keypoints in imA (vectorized)
356
+ # New experimental version
357
+ kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
358
+ kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
359
+ kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
360
+
361
+ # Create a composite image from imB_compound
362
+ imB_compound_np = np.stack(imB_compound, axis=0)
363
+ H_B, W_B, _ = imB_compound[0].shape
364
+
365
+ # Extract colors for keypoints in imB using certainties_max_idcs
366
+ imB_np = imB_compound_np[
367
+ certainties_max_idcs.detach().cpu().numpy(),
368
+ np.arange(H).reshape(-1, 1),
369
+ np.arange(W)
370
+ ]
371
+
372
+ if verbose:
373
+ print("imB_np.shape:", imB_np.shape)
374
+ print("imB_np:", imB_np)
375
+ fig, ax = plt.subplots(figsize=(12, 6))
376
+ cax = ax.imshow(np.flipud(imB_np))
377
+ cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
378
+ ax.set_title("np.flipud(imB_np[0]")
379
+ ax.set_xlim(0, W_A)
380
+ ax.set_ylim(0, H_A)
381
+ output_dict[f'np.flipud(imB_np[0]'] = fig
382
+
383
+
384
+ kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
385
+ kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
386
+
387
+ certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
388
+ kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
389
+ kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]
390
+
391
+ # Normalize keypoints in both images
392
+ kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
393
+ kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
394
+ kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
395
+ kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0
396
+
397
+ return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
398
+
399
+ def prepare_tensor(input_array, device):
400
+ """
401
+ Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
402
+ Args:
403
+ input_array (array-like): The input array to convert.
404
+ device (str or torch.device): The device to move the tensor to.
405
+ Returns:
406
+ torch.Tensor: A detached tensor clone of the input array on the specified device.
407
+ """
408
+ if not isinstance(input_array, torch.Tensor):
409
+ return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
410
+ return input_array.clone().detach().to(device).to(torch.float32)
411
+
412
+ def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
413
+ """
414
+ Solves for a batch of 3D points given batches of projection matrices and corresponding image points.
415
+
416
+ Parameters:
417
+ - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
418
+ - k1_x, k1_y: Tensors of shape (batch_size,)
419
+ - k2_x, k2_y: Tensors of shape (batch_size,)
420
+
421
+ Returns:
422
+ - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
423
+ """
424
+ EPS = 1e-4
425
+ # Ensure inputs are tensors
426
+
427
+ P1 = prepare_tensor(P1, device)
428
+ P2 = prepare_tensor(P2, device)
429
+ k1_x = prepare_tensor(k1_x, device)
430
+ k1_y = prepare_tensor(k1_y, device)
431
+ k2_x = prepare_tensor(k2_x, device)
432
+ k2_y = prepare_tensor(k2_y, device)
433
+ batch_size = k1_x.shape[0]
434
+
435
+ # Expand P1 and P2 if they are not batched
436
+ if P1.ndim == 2:
437
+ P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
438
+ if P2.ndim == 2:
439
+ P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)
440
+
441
+ # Extract columns from P1 and P2
442
+ P1_0 = P1[:, :, 0] # Shape: (batch_size, 4)
443
+ P1_1 = P1[:, :, 1]
444
+ P1_2 = P1[:, :, 2]
445
+
446
+ P2_0 = P2[:, :, 0]
447
+ P2_1 = P2[:, :, 1]
448
+ P2_2 = P2[:, :, 2]
449
+
450
+ # Reshape kx and ky to (batch_size, 1)
451
+ k1_x = k1_x.view(-1, 1)
452
+ k1_y = k1_y.view(-1, 1)
453
+ k2_x = k2_x.view(-1, 1)
454
+ k2_y = k2_y.view(-1, 1)
455
+
456
+ # Construct the equations for each batch
457
+ # For camera 1
458
+ A1 = P1_0 - k1_x * P1_2 # Shape: (batch_size, 4)
459
+ A2 = P1_1 - k1_y * P1_2
460
+ # For camera 2
461
+ A3 = P2_0 - k2_x * P2_2
462
+ A4 = P2_1 - k2_y * P2_2
463
+
464
+ # Stack the equations
465
+ A = torch.stack([A1, A2, A3, A4], dim=1) # Shape: (batch_size, 4, 4)
466
+
467
+ # Right-hand side (constants)
468
+ b = -A[:, :, 3] # Shape: (batch_size, 4)
469
+ A_reduced = A[:, :, :3] # Coefficients of x, y, z
470
+
471
+ # Solve using torch.linalg.lstsq (supports batching)
472
+ X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2) # Shape: (batch_size, 3)
473
+
474
+ # Append 1 to get homogeneous coordinates
475
+ ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
476
+ X = torch.cat([X_xyz, ones], dim=1) # Shape: (batch_size, 4)
477
+
478
+ # Now compute the errors of projections.
479
+ seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
480
+ seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
481
+ seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
482
+ seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
483
+ proj1_target = torch.concat([k1_x, k1_y], dim=1)
484
+ proj2_target = torch.concat([k2_x, k2_y], dim=1)
485
+ errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
486
+ errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()
487
+
488
+ return X, errors_proj1, errors_proj2
489
+
490
+
491
+
492
+ def select_best_keypoints(
493
+ NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
494
+ """
495
+ From all candidate points triangulated for each keypoint (one per neighboring view), selects the point with the lowest reprojection error.
496
+
497
+ Args:
498
+ NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
499
+ usually 3 or 4(for homogeneous representation).
500
+ NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
501
+ NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
502
+ Returns:
503
+ selected_keypoints: keypoints with the best score.
504
+ """
505
+
506
+ NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
507
+
508
+ # Convert indices to PyTorch tensor
509
+ indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)
510
+
511
+ # Create index tensor for the second dimension
512
+ n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)
513
+
514
+ # Use advanced indexing to select elements
515
+ NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :] # Shape: [N, k]
516
+
517
+ return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
518
+
519
+
520
+
521
+ def init_gaussians_with_corr(gaussians, scene, cfg, device, verbose = False, roma_model=None):
522
+ """
523
+ For given input Gaussians and a scene, we instantiate a RoMa model (switch to the indoor model if necessary) and process the scene's
524
+ training frames to extract correspondences. Those are used to initialize the Gaussians.
525
+ Args:
526
+ gaussians: object gaussians of the class GaussianModel that we need to enrich with gaussians.
527
+ scene: object of the Scene class.
528
+ cfg: configuration. Use init_wC
529
+ Returns:
530
+ gaussians: inplace transforms object gaussians of the class GaussianModel.
531
+
532
+ """
533
+ if roma_model is None:
534
+ if cfg.roma_model == "indoors":
535
+ roma_model = roma_indoor(device=device)
536
+ else:
537
+ roma_model = roma_outdoor(device=device)
538
+ roma_model.upsample_preds = False
539
+ roma_model.symmetric = False
540
+ M = cfg.matches_per_ref
541
+ upper_thresh = roma_model.sample_thresh
542
+ scaling_factor = cfg.scaling_factor
543
+ expansion_factor = 1
544
+ keypoint_fit_error_tolerance = cfg.proj_err_tolerance
545
+ visualizations = {}
546
+ viewpoint_stack = scene.getTrainCameras().copy()
547
+ NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
548
+ NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref , len(viewpoint_stack))
549
+ # Select cameras using K-means
550
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
551
+
552
+ selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
553
+ selected_indices = sorted(selected_indices)
554
+
555
+
556
+ # Find the k-closest vectors for each vector
557
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
558
+ closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
559
+ if verbose: print("Indices of k-closest vectors for each vector:\n", closest_indices)
560
+
561
+ closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
562
+
563
+ all_new_xyz = []
564
+ all_new_features_dc = []
565
+ all_new_features_rest = []
566
+ all_new_opacities = []
567
+ all_new_scaling = []
568
+ all_new_rotation = []
569
+
570
+ # Run roma_model.match once to warm up and initialize the model
571
+ with torch.no_grad():
572
+ viewpoint_cam1 = viewpoint_stack[0]
573
+ viewpoint_cam2 = viewpoint_stack[1]
574
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
575
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
576
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
577
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
578
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
579
+ print("Once run full roma_model.match warp.shape:", warp.shape)
580
+ print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
581
+ del warp, certainty_warp
582
+ torch.cuda.empty_cache()
583
+
584
+ for source_idx in tqdm(sorted(selected_indices)):
585
+ # 1. Compute keypoints and warping for all the neigboring views
586
+ with torch.no_grad():
587
+ # Call the aggregation function to get imA and imB_compound
588
+ certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
589
+ viewpoint_stack=viewpoint_stack,
590
+ closest_indices=closest_indices_selected,
591
+ roma_model=roma_model,
592
+ source_idx=source_idx,
593
+ verbose=verbose, output_dict=visualizations
594
+ )
595
+
596
+
597
+ # Triangulate keypoints
598
+ with torch.no_grad():
599
+ matches = warps_max
600
+ certainty = certainties_max
601
+ certainty = certainty.clone()
602
+ certainty[certainty > upper_thresh] = 1
603
+ matches, certainty = (
604
+ matches.reshape(-1, 4),
605
+ certainty.reshape(-1),
606
+ )
607
+
608
+ # Select based on certainty elements with high confidence. These are basically all of
609
+ # kptsA_np.
610
+ good_samples = torch.multinomial(certainty,
611
+ num_samples=min(expansion_factor * M, len(certainty)),
612
+ replacement=False)
613
+
614
+ certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all
615
+ reference_image_dict = {
616
+ "ref_image": imA,
617
+ "NNs_images": imB_compound,
618
+ "certainties_all": certainties_all,
619
+ "warps_all": warps_all,
620
+ "triangulated_points": [],
621
+ "triangulated_points_errors_proj1": [],
622
+ "triangulated_points_errors_proj2": []
623
+
624
+ }
625
+ with torch.no_grad():
626
+ for NN_idx in tqdm(range(len(warps_all))):
627
+ matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
628
+
629
+ # Extract keypoints and colors
630
+ kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
631
+ imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
632
+ )
633
+
634
+ proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
635
+ proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
636
+ triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
637
+ P1=torch.stack([proj_matrices_A] * M, axis=0),
638
+ P2=torch.stack([proj_matrices_B] * M, axis=0),
639
+ k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
640
+ k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
641
+
642
+ reference_image_dict["triangulated_points"].append(triangulated_points)
643
+ reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
644
+ reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
645
+
646
+ with torch.no_grad():
647
+ NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
648
+ NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
649
+ NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
650
+ NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
651
+
652
+ # 4. Save as gaussians
653
+ viewpoint_cam1 = viewpoint_stack[source_idx]
654
+ N = len(NNs_triangulated_points_selected)
655
+ with torch.no_grad():
656
+ new_xyz = NNs_triangulated_points_selected[:, :-1]
657
+ all_new_xyz.append(new_xyz) # seeked_splats
658
+ all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
659
+ all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
660
+ # new version that sets points with large error invisible
661
+ # TODO: remove those points instead. However it doesn't affect the performance.
662
+ mask_bad_points = torch.tensor(
663
+ NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
664
+ dtype=torch.float32).unsqueeze(1).to(device)
665
+ all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
666
+
667
+ dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz,
668
+ dim=1, ord=2)
669
+ #all_new_scaling.append(torch.log(((dist_points_to_cam1) / 1. * scaling_factor).unsqueeze(1).repeat(1, 3)))
670
+ all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
671
+ all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
672
+
673
+ all_new_xyz = torch.cat(all_new_xyz, dim=0)
674
+ all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
675
+ new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
676
+ prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
677
+
678
+ gaussians.densification_postfix(all_new_xyz[prune_mask].to(device),
679
+ all_new_features_dc[prune_mask].to(device),
680
+ torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
681
+ torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
682
+ torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
683
+ torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
684
+ new_tmp_radii[prune_mask].to(device))
685
+
686
+ return viewpoint_stack, closest_indices_selected, visualizations
687
+
688
+
689
+
690
+ def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict={}):
691
+ """
692
+ Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).
693
+
694
+ Args:
695
+ imA: Source image as a NumPy array (H_A, W_A, C).
696
+ imB: Target image as a NumPy array (H_B, W_B, C).
697
+ matches: Matches in normalized coordinates (torch.Tensor).
698
+ roma_model: Roma model instance for keypoint operations.
699
+ verbose: If True, outputs intermediate visualizations.
700
+ Returns:
701
+ kptsA_np: Keypoints in imA (normalized).
702
+ kptsB_np: Keypoints in imB (normalized).
703
+ kptsA_color: Colors of keypoints in imA.
704
+ kptsB_color: Colors of keypoints in imB.
705
+ """
706
+ H_A, W_A, _ = imA.shape
707
+ H_B, W_B, _ = imB.shape
708
+
709
+ # Convert matches to pixel coordinates
710
+ # Matches format: (B, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
711
+ kptsA = matches[:, :2] # [N, 2]
712
+ kptsB = matches[:, 2:] # [N, 2]
713
+
714
+ # Scale normalized coordinates [-1,1] to pixel coordinates
715
+ kptsA_pix = torch.zeros_like(kptsA)
716
+ kptsB_pix = torch.zeros_like(kptsB)
717
+
718
+ # Important! [Normalized to pixel space]
719
+ kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
720
+ kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2
721
+
722
+ kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
723
+ kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2
724
+
725
+ kptsA_np = kptsA_pix.detach().cpu().numpy()
726
+ kptsB_np = kptsB_pix.detach().cpu().numpy()
727
+
728
+ # Extract colors
729
+ kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
730
+ kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
731
+ kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
732
+ kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
733
+
734
+ kptsA_color = imA[np.clip(kptsA_y, 0, H_A-1), np.clip(kptsA_x, 0, W_A-1)]
735
+ kptsB_color = imB[np.clip(kptsB_y, 0, H_B-1), np.clip(kptsB_x, 0, W_B-1)]
736
+
737
+ # Normalize keypoints into [-1, 1] for downstream triangulation
738
+ kptsA_np_norm = np.zeros_like(kptsA_np)
739
+ kptsB_np_norm = np.zeros_like(kptsB_np)
740
+
741
+ kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
742
+ kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0
743
+
744
+ kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
745
+ kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0
746
+
747
+ return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color
748
+
749
+
750
+
751
+ def init_gaussians_with_corr_fast(gaussians, scene, cfg, device, verbose=False, roma_model=None):
752
+ timings = defaultdict(list)
753
+
754
+ if roma_model is None:
755
+ if cfg.roma_model == "indoors":
756
+ roma_model = roma_indoor(device=device)
757
+ else:
758
+ roma_model = roma_outdoor(device=device)
759
+ roma_model.upsample_preds = False
760
+ roma_model.symmetric = False
761
+
762
+ M = cfg.matches_per_ref
763
+ upper_thresh = roma_model.sample_thresh
764
+ scaling_factor = cfg.scaling_factor
765
+ expansion_factor = 1
766
+ keypoint_fit_error_tolerance = cfg.proj_err_tolerance
767
+ visualizations = {}
768
+ viewpoint_stack = scene.getTrainCameras().copy()
769
+ NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
770
+ NUM_NNS_PER_REFERENCE = 1 # Only ONE neighbor now!
771
+
772
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
773
+
774
+ selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
775
+ selected_indices = sorted(selected_indices)
776
+
777
+ viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
778
+ closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
779
+ closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
780
+
781
+ all_new_xyz = []
782
+ all_new_features_dc = []
783
+ all_new_features_rest = []
784
+ all_new_opacities = []
785
+ all_new_scaling = []
786
+ all_new_rotation = []
787
+
788
+ # Dummy first pass to initialize model
789
+ with torch.no_grad():
790
+ viewpoint_cam1 = viewpoint_stack[0]
791
+ viewpoint_cam2 = viewpoint_stack[1]
792
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
793
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
794
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
795
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
796
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
797
+ del warp, certainty_warp
798
+ torch.cuda.empty_cache()
799
+
800
+ # Main Loop over source_idx
801
+ for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
802
+
803
+ # =================== Step 1: Compute Warp and Certainty ===================
804
+ start = time.time()
805
+ viewpoint_cam1 = viewpoint_stack[source_idx]
806
+ NNs=closest_indices_selected.shape[1]
807
+ viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
808
+ imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
809
+ imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
810
+ imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
811
+ imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
812
+ warp, certainty_warp = roma_model.match(imA, imB, device=device)
813
+
814
+ certainties_max = certainty_warp # New manual sampling
815
+ timings['aggregation_warp_certainty'].append(time.time() - start)
816
+
817
+ # =================== Step 2: Good Samples Selection ===================
818
+ start = time.time()
819
+ certainty = certainties_max.reshape(-1).clone()
820
+ certainty[certainty > upper_thresh] = 1
821
+ good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
822
+ timings['good_samples_selection'].append(time.time() - start)
823
+
824
+ # =================== Step 3: Triangulate Keypoints ===================
825
+ reference_image_dict = {
826
+ "triangulated_points": [],
827
+ "triangulated_points_errors_proj1": [],
828
+ "triangulated_points_errors_proj2": []
829
+ }
830
+
831
+ start = time.time()
832
+ matches_NN = warp.reshape(-1, 4)[good_samples]
833
+
834
+ # Convert matches to pixel coordinates
835
+ kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
836
+ np.array(imA).astype(np.uint8),
837
+ np.array(imB).astype(np.uint8),
838
+ matches_NN,
839
+ roma_model
840
+ )
841
+
842
+ proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
843
+ proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform
844
+
845
+ triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
846
+ P1=torch.stack([proj_matrices_A] * M, axis=0),
847
+ P2=torch.stack([proj_matrices_B] * M, axis=0),
848
+ k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
849
+ k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
850
+
851
+ reference_image_dict["triangulated_points"].append(triangulated_points)
852
+ reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
853
+ reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
854
+ timings['triangulation_per_NN'].append(time.time() - start)
855
+
856
+ # =================== Step 4: Select Best Triangulated Points ===================
857
+ start = time.time()
858
+ NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
859
+ NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
860
+ NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
861
+ NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
862
+ timings['select_best_keypoints'].append(time.time() - start)
863
+
864
+ # =================== Step 5: Create New Gaussians ===================
865
+ start = time.time()
866
+ viewpoint_cam1 = viewpoint_stack[source_idx]
867
+ N = len(NNs_triangulated_points_selected)
868
+ new_xyz = NNs_triangulated_points_selected[:, :-1]
869
+ all_new_xyz.append(new_xyz)
870
+ all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
871
+ all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
872
+
873
+ mask_bad_points = torch.tensor(
874
+ NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
875
+ dtype=torch.float32).unsqueeze(1).to(device)
876
+
877
+ all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
878
+
879
+ dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
880
+ all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
881
+ all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
882
+ timings['save_gaussians'].append(time.time() - start)
883
+
884
+ # =================== Final Densification Postfix ===================
885
+ start = time.time()
886
+ all_new_xyz = torch.cat(all_new_xyz, dim=0)
887
+ all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
888
+ new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
889
+ prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
890
+
891
+ gaussians.densification_postfix(
892
+ all_new_xyz[prune_mask].to(device),
893
+ all_new_features_dc[prune_mask].to(device),
894
+ torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
895
+ torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
896
+ torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
897
+ torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
898
+ new_tmp_radii[prune_mask].to(device)
899
+ )
900
+ timings['final_densification_postfix'].append(time.time() - start)
901
+
902
+ # =================== Print Profiling Results ===================
903
+ print("\n=== Profiling Summary (average per frame) ===")
904
+ for key, times in timings.items():
905
+ print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
906
+
907
+ return viewpoint_stack, closest_indices_selected, visualizations
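Most of this file depends on a scene and a RoMa checkpoint, but the reference-frame selection helpers at the top (`select_cameras_kmeans`, `k_closest_vectors`) can be exercised in isolation. A minimal sketch with toy cameras, assuming the repository root (and its submodules, as the sys.path manipulation at the top of the file expects) is importable:

```python
# Toy sketch of the camera-selection helpers from source/corr_init.py.
import torch
from source.corr_init import select_cameras_kmeans, k_closest_vectors

cams = torch.randn(20, 16)                           # 20 fake cameras, flattened 4x4 matrices
ref_idx = select_cameras_kmeans(cams.numpy(), K=4)   # indices of 4 K-means reference frames
nn_idx = k_closest_vectors(cams, k=3)                # [20, 3] nearest-neighbour indices per camera
print(sorted(ref_idx), nn_idx.shape)
```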
source/data_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ def scene_cameras_train_test_split(scene, verbose=False):
2
+ """
3
+ Iterate over resolutions in the scene. For each resolution, check whether it has test_cameras;
4
+ if it doesn't, move every 8th camera from the train set to the test set. This follows the
5
+ evaluation protocol suggested by Kerbl et al. in the seminal work on 3DGS. All changes to the input
6
+ object scene are inplace changes.
7
+ :param scene: Scene Class object from the gaussian-splatting.scene module
8
+ :param verbose: Print initial and final stage of the function
9
+ :return: None
10
+
11
+ """
12
+ if verbose: print("Preparing train and test sets split...")
13
+ for resolution in scene.train_cameras.keys():
14
+ if len(scene.test_cameras[resolution]) == 0:
15
+ if verbose:
16
+ print(f"Found no test_cameras for resolution {resolution}. Move every 8th camera out ouf total "+\
17
+ f"{len(scene.train_cameras[resolution])} train cameras to the test set now")
18
+ N = len(scene.train_cameras[resolution])
19
+ scene.test_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
20
+ if idx % 8 == 0]
21
+ scene.train_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
22
+ if idx % 8 != 0]
23
+ if verbose:
24
+ print(f"Done. Now train and test sets contain each {len(scene.train_cameras[resolution])} and " + \
25
+ f"{len(scene.test_cameras[resolution])} cameras respectively.")
26
+
27
+
28
+ return
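The rule implemented above is simply "every 8th training camera becomes a test camera" (the protocol from the original 3DGS evaluation). A pure-Python illustration of the resulting index split:

```python
# Illustration of the every-8th-camera split used by scene_cameras_train_test_split.
N = 24  # pretend a resolution has 24 train cameras and no test cameras
test_idx = [i for i in range(N) if i % 8 == 0]    # -> [0, 8, 16]
train_idx = [i for i in range(N) if i % 8 != 0]   # the remaining 21 cameras
print(len(train_idx), len(test_idx))              # 21 3
```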
source/losses.py ADDED
@@ -0,0 +1,100 @@
1
+ # Code is copied from the gaussian-splatting/utils/loss_utils.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ from math import exp
7
+
8
+ def l1_loss(network_output, gt, mean=True):
9
+ return torch.abs((network_output - gt)).mean() if mean else torch.abs((network_output - gt))
10
+
11
+ def l2_loss(network_output, gt):
12
+ return ((network_output - gt) ** 2).mean()
13
+
14
+ def gaussian(window_size, sigma):
15
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
16
+ return gauss / gauss.sum()
17
+
18
+ def create_window(window_size, channel):
19
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
20
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
21
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
22
+ return window
23
+
24
+ def ssim(img1, img2, window_size=11, size_average=True, mask = None):
25
+ channel = img1.size(-3)
26
+ window = create_window(window_size, channel)
27
+
28
+ if img1.is_cuda:
29
+ window = window.cuda(img1.get_device())
30
+ window = window.type_as(img1)
31
+
32
+ return _ssim(img1, img2, window, window_size, channel, size_average, mask)
33
+
34
+ def _ssim(img1, img2, window, window_size, channel, size_average=True, mask = None):
35
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
36
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
37
+
38
+ mu1_sq = mu1.pow(2)
39
+ mu2_sq = mu2.pow(2)
40
+ mu1_mu2 = mu1 * mu2
41
+
42
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
43
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
44
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
45
+
46
+ C1 = 0.01 ** 2
47
+ C2 = 0.03 ** 2
48
+
49
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
50
+
51
+ if mask is not None:
52
+ ssim_map = ssim_map * mask
53
+
54
+ if size_average:
55
+ return ssim_map.mean()
56
+ else:
57
+ return ssim_map.mean(1).mean(1).mean(1)
58
+
59
+
60
+ def mse(img1, img2):
61
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
62
+
63
+ def psnr(img1, img2):
64
+ """
65
+ Computes the Peak Signal-to-Noise Ratio (PSNR) between two single images. NOT BATCHED!
66
+ Args:
67
+ img1 (torch.Tensor): The first image tensor, with pixel values scaled between 0 and 1.
68
+ Shape should be (channels, height, width).
69
+ img2 (torch.Tensor): The second image tensor with the same shape as img1, used for comparison.
70
+
71
+ Returns:
72
+ torch.Tensor: A scalar tensor containing the PSNR value in decibels (dB).
73
+ """
74
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
75
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
76
+
77
+
78
+ def tv_loss(image):
79
+ """
80
+ Computes the total variation (TV) loss for an image of shape [3, H, W].
81
+
82
+ Args:
83
+ image (torch.Tensor): Input image of shape [3, H, W]
84
+
85
+ Returns:
86
+ torch.Tensor: Scalar value representing the total variation loss.
87
+ """
88
+ # Ensure the image has the correct dimensions
89
+ assert image.ndim == 3 and image.shape[0] == 3, "Input must be of shape [3, H, W]"
90
+
91
+ # Compute the difference between adjacent pixels in the x-direction (width)
92
+ diff_x = torch.abs(image[:, :, 1:] - image[:, :, :-1])
93
+
94
+ # Compute the difference between adjacent pixels in the y-direction (height)
95
+ diff_y = torch.abs(image[:, 1:, :] - image[:, :-1, :])
96
+
97
+ # Sum the total variation in both directions
98
+ tv_loss_value = torch.mean(diff_x) + torch.mean(diff_y)
99
+
100
+ return tv_loss_value
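These helpers mirror `utils/loss_utils.py` from the gaussian-splatting codebase. A quick self-contained smoke test on random tensors, run from the repository root (shapes follow the docstrings above; a batch dimension is added explicitly for `ssim`):

```python
# Smoke test for source/losses.py on random images.
import torch
from source.losses import l1_loss, ssim, psnr, tv_loss

img1 = torch.rand(3, 64, 64)   # [C, H, W], values in [0, 1]
img2 = torch.rand(3, 64, 64)

print(l1_loss(img1, img2))                          # mean absolute error
print(ssim(img1.unsqueeze(0), img2.unsqueeze(0)))   # SSIM on a batch of one
print(psnr(img1, img2))                             # PSNR in dB (per channel; not batched)
print(tv_loss(img1))                                # total variation of a [3, H, W] image
```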
source/networks.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+ import sys
4
+ sys.path.append('./submodules/gaussian-splatting/')
5
+
6
+ from random import randint
7
+ from scene import Scene, GaussianModel
8
+ from gaussian_renderer import render
9
+ from source.data_utils import scene_cameras_train_test_split
10
+
11
+ class Warper3DGS(torch.nn.Module):
12
+ def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
13
+ do_train_test_split=True):
14
+ super(Warper3DGS, self).__init__()
15
+ """
16
+ Init Warper using all the objects necessary for rendering gaussian splats.
17
+ Here we merely link class attributes to the objects instantiated outside the class.
18
+ """
19
+ self.gaussians = GaussianModel(sh_degree)
20
+ self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
21
+ self.render = render
22
+ self.gs_config_opt = opt
23
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
24
+ self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
25
+ self.pipe = pipe
26
+ self.scene = Scene(dataset, self.gaussians, shuffle=False)
27
+ if do_train_test_split:
28
+ scene_cameras_train_test_split(self.scene, verbose=verbose)
29
+
30
+ self.gaussians.training_setup(opt)
31
+ self.viewpoint_stack = viewpoint_stack
32
+ if not self.viewpoint_stack:
33
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
34
+
35
+ def forward(self, viewpoint_cam=None):
36
+ """
37
+ For a provided camera viewpoint_cam we render gaussians from this viewpoint.
38
+ If no camera provided then we use the self.viewpoint_stack (list of cameras).
39
+ If the latter is empty we reinitialize it using the self.scene object.
40
+ """
41
+ if not viewpoint_cam:
42
+ if not self.viewpoint_stack:
43
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
44
+ viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]
45
+
46
+ render_pkg = self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)
47
+ return render_pkg
48
+
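Usage note (not part of the diff): a rough construction sketch for Warper3DGS, assuming the gaussian-splatting submodule is built and a COLMAP-formatted scene is available; the ModelParams/OptimizationParams/PipelineParams argument groups come from that submodule, and the paths below are placeholders.

    import sys
    sys.path.append('./submodules/gaussian-splatting/')
    from argparse import ArgumentParser
    from arguments import ModelParams, OptimizationParams, PipelineParams  # 3DGS submodule
    from source.networks import Warper3DGS

    parser = ArgumentParser()
    lp, op, pp = ModelParams(parser), OptimizationParams(parser), PipelineParams(parser)
    args = parser.parse_args(["--source_path", "path/to/scene", "--model_path", "path/to/output"])

    warper = Warper3DGS(sh_degree=3, opt=op.extract(args), pipe=pp.extract(args),
                        dataset=lp.extract(args), viewpoint_stack=None, verbose=True)
    render_pkg = warper()         # no camera given: a random training view is rendered
    image = render_pkg["render"]  # (3, H, W) image tensor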
source/timer.py ADDED
@@ -0,0 +1,24 @@
1
+ import time
2
+ class Timer:
3
+ def __init__(self):
4
+ self.start_time = None
5
+ self.elapsed = 0
6
+ self.paused = False
7
+
8
+ def start(self):
9
+ if self.start_time is None:
10
+ self.start_time = time.time()
11
+ elif self.paused:
12
+ self.start_time = time.time() - self.elapsed
13
+ self.paused = False
14
+
15
+ def pause(self):
16
+ if not self.paused:
17
+ self.elapsed = time.time() - self.start_time
18
+ self.paused = True
19
+
20
+ def get_elapsed_time(self):
21
+ if self.paused:
22
+ return self.elapsed
23
+ else:
24
+ return time.time() - self.start_time
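Usage note (not part of the diff): Timer is a simple pausable wall clock (the trainer pauses it around logging and evaluation so they do not count toward training time); a minimal sketch, assuming source.timer is importable.

    import time
    from source.timer import Timer

    t = Timer()
    t.start()
    time.sleep(0.2)   # timed work
    t.pause()         # excluded section (e.g. evaluation / logging)
    time.sleep(0.2)
    t.start()         # resume; the paused interval is not counted
    time.sleep(0.2)
    print(f"elapsed: {t.get_elapsed_time():.2f}s")  # ~0.4s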
source/trainer.py ADDED
@@ -0,0 +1,265 @@
1
+ import torch
2
+ from random import randint
3
+ from tqdm.rich import trange
4
+ from tqdm import tqdm as tqdm
5
+ from source.networks import Warper3DGS
6
+ import wandb
7
+ import sys
8
+
9
+ sys.path.append('./submodules/gaussian-splatting/')
10
+ import lpips
11
+ from source.losses import ssim, l1_loss, psnr
12
+ from rich.console import Console
13
+ from rich.theme import Theme
14
+
15
+ custom_theme = Theme({
16
+ "info": "dim cyan",
17
+ "warning": "magenta",
18
+ "danger": "bold red"
19
+ })
20
+
21
+ from source.corr_init import init_gaussians_with_corr, init_gaussians_with_corr_fast
22
+ from source.utils_aux import log_samples
23
+
24
+ from source.timer import Timer
25
+
26
+ class EDGSTrainer:
27
+ def __init__(self,
28
+ GS: Warper3DGS,
29
+ training_config,
30
+ dataset_white_background=False,
31
+ device=torch.device('cuda'),
32
+ log_wandb=True,
33
+ ):
34
+ self.GS = GS
35
+ self.scene = GS.scene
36
+ self.viewpoint_stack = GS.viewpoint_stack
37
+ self.gaussians = GS.gaussians
38
+
39
+ self.training_config = training_config
40
+ self.GS_optimizer = GS.gaussians.optimizer
41
+ self.dataset_white_background = dataset_white_background
42
+
43
+ self.training_step = 1
44
+ self.gs_step = 0
45
+ self.CONSOLE = Console(width=120, theme=custom_theme)
46
+ self.saving_iterations = training_config.save_iterations
47
+ self.evaluate_iterations = None
48
+ self.batch_size = training_config.batch_size
49
+ self.ema_loss_for_log = 0.0
50
+
51
+ # Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
52
+ self.logs_losses = {}
53
+ self.lpips = lpips.LPIPS(net='vgg').to(device)
54
+ self.device = device
55
+ self.timer = Timer()
56
+ self.log_wandb = log_wandb
57
+
58
+ def load_checkpoints(self, load_cfg):
59
+ # Load 3DGS checkpoint
60
+ if load_cfg.gs:
61
+ self.GS.gaussians.restore(
62
+ torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")[0],
63
+ self.training_config)
64
+ self.GS_optimizer = self.GS.gaussians.optimizer
65
+ self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
66
+ style="info")
67
+ self.training_step += load_cfg.gs_step
68
+ self.gs_step += load_cfg.gs_step
69
+
70
+ def train(self, train_cfg):
71
+ # 3DGS training
72
+ self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
73
+ with trange(self.training_step, self.training_step + train_cfg.gs_epochs, desc="[green]Train gaussians") as progress_bar:
74
+ for self.training_step in progress_bar:
75
+ radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
76
+ with torch.no_grad():
77
+ if train_cfg.no_densify:
78
+ self.prune(radii)
79
+ else:
80
+ self.densify_and_prune(radii)
81
+ if train_cfg.reduce_opacity:
82
+ # Slightly reduce opacity every few steps:
83
+ if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
84
+ opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
85
+ self.GS.gaussians._opacity.data = opacities_new
86
+ self.timer.pause()
87
+ # Progress bar
88
+ if self.training_step % 10 == 0:
89
+ progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
90
+ # Log and save
91
+ if self.training_step in self.saving_iterations:
92
+ self.save_model()
93
+ if self.evaluate_iterations is not None:
94
+ if self.training_step in self.evaluate_iterations:
95
+ self.evaluate()
96
+ else:
97
+ if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
98
+ (self.training_step > 3000 and self.training_step % 1000 == 228) :
99
+ self.evaluate()
100
+
101
+ self.timer.start()
102
+
103
+
104
+ def evaluate(self):
105
+ torch.cuda.empty_cache()
106
+ log_gen_images, log_real_images = [], []
107
+ validation_configs = ({'name': 'test', 'cameras': self.scene.getTestCameras(), 'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
108
+ {'name': 'train',
109
+ 'cameras': [self.scene.getTrainCameras()[idx % len(self.scene.getTrainCameras())] for idx in
110
+ range(0, 150, 5)], 'cam_idx': 10})
111
+ if self.log_wandb:
112
+ wandb.log({f"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
113
+ for config in validation_configs:
114
+ if config['cameras'] and len(config['cameras']) > 0:
115
+ l1_test = 0.0
116
+ psnr_test = 0.0
117
+ ssim_test = 0.0
118
+ lpips_splat_test = 0.0
119
+ for idx, viewpoint in enumerate(config['cameras']):
120
+ image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
121
+ gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
122
+ l1_test += l1_loss(image, gt_image).double()
123
+ psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
124
+ ssim_test += ssim(image, gt_image).double()
125
+ lpips_splat_test += self.lpips(image, gt_image).detach().double()
126
+ if idx in [config['cam_idx']]:
127
+ log_gen_images.append(image)
128
+ log_real_images.append(gt_image)
129
+ psnr_test /= len(config['cameras'])
130
+ l1_test /= len(config['cameras'])
131
+ ssim_test /= len(config['cameras'])
132
+ lpips_splat_test /= len(config['cameras'])
133
+ if self.log_wandb:
134
+ wandb.log({f"{config['name']}/L1": l1_test.item(), f"{config['name']}/PSNR": psnr_test.item(), \
135
+ f"{config['name']}/SSIM": ssim_test.item(), f"{config['name']}/LPIPS_splat": lpips_splat_test.item()}, step = self.training_step)
136
+ self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
137
+ self.training_step, len(self.GS.gaussians._xyz), config['name'], l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
138
+ if self.log_wandb:
139
+ with torch.no_grad():
140
+ log_samples(torch.stack((log_real_images[0],log_gen_images[0])) , [], self.training_step, caption="Real and Generated Samples")
141
+ wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
142
+ torch.cuda.empty_cache()
143
+
144
+ def train_step_gs(self, max_lr = False, no_densify = False):
145
+ self.gs_step += 1
146
+ if max_lr:
147
+ self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
148
+ else:
149
+ self.GS.gaussians.update_learning_rate(self.gs_step)
150
+ # Every 1000 its we increase the levels of SH up to a maximum degree
151
+ if self.gs_step % 1000 == 0:
152
+ self.GS.gaussians.oneupSHdegree()
153
+
154
+ # Pick a random Camera
155
+ if not self.viewpoint_stack:
156
+ self.viewpoint_stack = self.scene.getTrainCameras().copy()
157
+ viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
158
+
159
+ render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
160
+ image = render_pkg["render"]
161
+ # Loss
162
+ gt_image = viewpoint_cam.original_image.to(self.device)
163
+ L1_loss = l1_loss(image, gt_image)
164
+
165
+ ssim_loss = (1.0 - ssim(image, gt_image))
166
+ loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
167
+ self.training_config.lambda_dssim * ssim_loss
168
+ self.timer.pause()
169
+ self.logs_losses[self.training_step] = {"loss": loss.item(),
170
+ "L1_loss": L1_loss.item(),
171
+ "ssim_loss": ssim_loss.item()}
172
+
173
+ if self.log_wandb:
174
+ for k, v in self.logs_losses[self.training_step].items():
175
+ wandb.log({f"train/{k}": v}, step=self.training_step)
176
+ self.ema_loss_for_log = 0.4 * self.logs_losses[self.training_step]["loss"] + 0.6 * self.ema_loss_for_log
177
+ self.timer.start()
178
+ self.GS_optimizer.zero_grad(set_to_none=True)
179
+ loss.backward()
180
+ with torch.no_grad():
181
+ if self.gs_step < self.training_config.densify_until_iter and not no_densify:
182
+ self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]] = torch.max(
183
+ self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]],
184
+ render_pkg["radii"][render_pkg["visibility_filter"]])
185
+ self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"],
186
+ render_pkg["visibility_filter"])
187
+
188
+ # Optimizer step
189
+ self.GS_optimizer.step()
190
+ self.GS_optimizer.zero_grad(set_to_none=True)
191
+ return render_pkg["radii"]
192
+
193
+ def densify_and_prune(self, radii = None):
194
+ # Densification or pruning
195
+ if self.gs_step < self.training_config.densify_until_iter:
196
+ if (self.gs_step > self.training_config.densify_from_iter) and \
197
+ (self.gs_step % self.training_config.densification_interval == 0):
198
+ size_threshold = 20 if self.gs_step > self.training_config.opacity_reset_interval else None
199
+ self.GS.gaussians.densify_and_prune(self.training_config.densify_grad_threshold,
200
+ 0.005,
201
+ self.GS.scene.cameras_extent,
202
+ size_threshold, radii)
203
+ if self.gs_step % self.training_config.opacity_reset_interval == 0 or (
204
+ self.dataset_white_background and self.gs_step == self.training_config.densify_from_iter):
205
+ self.GS.gaussians.reset_opacity()
206
+
207
+
208
+
209
+ def save_model(self):
210
+ print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
211
+ self.scene.save(self.gs_step)
212
+ print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
213
+ torch.save((self.GS.gaussians.capture(), self.gs_step),
214
+ self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")
215
+
216
+
217
+ def init_with_corr(self, cfg, verbose=False, roma_model=None):
218
+ """
219
+ Initializes gaussians from correspondence matches. Also removes the SfM init points unless cfg.add_SfM_init is set.
220
+ Args:
221
+ cfg: configuration section named init_wC. See train.yaml.
222
+ verbose: whether to print intermediate results. Useful for debugging.
223
+ roma_model: optionally, a pre-initialized RoMa model to avoid re-initializing
224
+ it every time.
225
+ """
226
+ if not cfg.use:
227
+ return None
228
+ N_splats_at_init = len(self.GS.gaussians._xyz)
229
+ print("N_splats_at_init:", N_splats_at_init)
230
+ if cfg.nns_per_ref == 1:
231
+ init_fn = init_gaussians_with_corr_fast
232
+ else:
233
+ init_fn = init_gaussians_with_corr
234
+ camera_set, selected_indices, visualization_dict = init_fn(
235
+ self.GS.gaussians,
236
+ self.scene,
237
+ cfg,
238
+ self.device,
239
+ verbose=verbose,
240
+ roma_model=roma_model)
241
+
242
+ # Remove SfM points and leave only matchings inits
243
+ if not cfg.add_SfM_init:
244
+ with torch.no_grad():
245
+ N_splats_after_init = len(self.GS.gaussians._xyz)
246
+ print("N_splats_after_init:", N_splats_after_init)
247
+ self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
248
+ mask = torch.concat([torch.ones(N_splats_at_init, dtype=torch.bool),
249
+ torch.zeros(N_splats_after_init-N_splats_at_init, dtype=torch.bool)],
250
+ axis=0)
251
+ self.GS.gaussians.prune_points(mask)
252
+ with torch.no_grad():
253
+ gaussians = self.gaussians
254
+ gaussians._scaling = gaussians.scaling_inverse_activation(gaussians.scaling_activation(gaussians._scaling)*0.5)
255
+ return visualization_dict
256
+
257
+
258
+ def prune(self, radii, min_opacity=0.005):
259
+ self.GS.gaussians.tmp_radii = radii
260
+ if self.gs_step < self.training_config.densify_until_iter:
261
+ prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
262
+ self.GS.gaussians.prune_points(prune_mask)
263
+ torch.cuda.empty_cache()
264
+ self.GS.gaussians.tmp_radii = None
265
+
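Usage note (not part of the diff): train_step_gs combines L1 and DSSIM exactly as in vanilla 3DGS, loss = (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM), and smooths it with an exponential moving average for the progress bar; a minimal numeric sketch, assuming source.losses is importable (lambda_dssim = 0.2 is the usual 3DGS default, the real value comes from training_config).

    import torch
    from source.losses import l1_loss, ssim

    lambda_dssim = 0.2
    img = torch.rand(3, 64, 64)
    gt = torch.rand(3, 64, 64)

    L1 = l1_loss(img, gt)
    ssim_loss = 1.0 - ssim(img, gt)
    loss = (1.0 - lambda_dssim) * L1 + lambda_dssim * ssim_loss

    ema = 0.0
    ema = 0.4 * loss.item() + 0.6 * ema  # same smoothing as ema_loss_for_log above
    print(loss.item(), ema)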
source/utils_aux.py ADDED
@@ -0,0 +1,92 @@
1
+ # Perlin noise code taken from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
2
+ from types import SimpleNamespace
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import wandb
8
+ import random
9
+ import torchvision.transforms as T
10
+ import torchvision.transforms.functional as F
11
+ import torch
12
+ from PIL import Image
13
+
14
+ def parse_dict_to_namespace(dict_nested):
15
+ """Turns nested dictionary into nested namespaces"""
16
+ if type(dict_nested) != dict and type(dict_nested) != list: return dict_nested
17
+ x = SimpleNamespace()
18
+ for key, val in dict_nested.items():
19
+ if type(val) == dict:
20
+ setattr(x, key, parse_dict_to_namespace(val))
21
+ elif type(val) == list:
22
+ setattr(x, key, [parse_dict_to_namespace(v) for v in val])
23
+ else:
24
+ setattr(x, key, val)
25
+ return x
26
+
27
+ def set_seed(seed=42, cuda=True):
28
+ random.seed(seed)
29
+ np.random.seed(seed)
30
+ torch.manual_seed(seed)
31
+ if cuda:
32
+ torch.cuda.manual_seed_all(seed)
33
+
34
+
35
+
36
+ def log_samples(samples, scores, iteration, caption="Real Samples"):
37
+ # Create a grid of images
38
+ grid = torchvision.utils.make_grid(samples)
39
+
40
+ # Log the images and scores to wandb
41
+ wandb.log({
42
+ f"{caption}_images": [wandb.Image(grid, caption=f"{caption}: {scores}")],
43
+ }, step = iteration)
44
+
45
+
46
+
47
+ def pairwise_distances(matrix):
48
+ """
49
+ Computes the pairwise Euclidean distances between all vectors in the input matrix.
50
+
51
+ Args:
52
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
53
+
54
+ Returns:
55
+ torch.Tensor: Pairwise distance matrix of shape [N, N].
56
+ """
57
+ # Compute pairwise Euclidean distances (torch.cdist with p=2 returns distances, not squared distances)
58
+ distances = torch.cdist(matrix, matrix, p=2)
59
+ return distances
60
+
61
+ def k_closest_vectors(matrix, k):
62
+ """
63
+ Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.
64
+
65
+ Args:
66
+ matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
67
+ k (int): Number of closest vectors to return for each vector.
68
+
69
+ Returns:
70
+ torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
71
+ """
72
+ # Compute pairwise distances
73
+ distances = pairwise_distances(matrix)
74
+
75
+ # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
76
+ # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
77
+ distances.fill_diagonal_(float('inf'))
78
+
79
+ # Get the indices of the k smallest distances (k-closest vectors)
80
+ _, indices = torch.topk(distances, k, largest=False, dim=1)
81
+
82
+ return indices
83
+
84
+ def process_image(image_tensor):
85
+ image_np = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
86
+ return Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))
87
+
88
+
89
+ def normalize_keypoints(kpts_np, width, height):
90
+ kpts_np[:, 0] = kpts_np[:, 0] / width * 2. - 1.
91
+ kpts_np[:, 1] = kpts_np[:, 1] / height * 2. - 1.
92
+ return kpts_np
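Usage note (not part of the diff): k_closest_vectors boils down to a cdist + topk pattern with the diagonal masked out; a self-contained sketch of the same technique on random points.

    import torch

    pts = torch.rand(100, 3)                  # N vectors of dimension D=3
    dist = torch.cdist(pts, pts, p=2)         # (N, N) pairwise Euclidean distances
    dist.fill_diagonal_(float('inf'))         # exclude each vector itself
    _, nn_idx = torch.topk(dist, k=4, largest=False, dim=1)
    print(nn_idx.shape)                       # torch.Size([100, 4])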
source/utils_preprocess.py ADDED
@@ -0,0 +1,334 @@
1
+ # This file contains functions for video or image collection preprocessing.
2
+ # For videos we run preprocessing and select the k sharpest frames.
3
+ # Afterwards the scene is reconstructed with COLMAP.
4
+ import cv2
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import pycolmap
8
+ import os
9
+ import time
10
+ import tempfile
11
+ from moviepy import VideoFileClip
12
+ from matplotlib import pyplot as plt
13
+ from PIL import Image
14
+ import cv2
15
+ from tqdm import tqdm
16
+
17
+ WORKDIR = "../outputs/"
18
+
19
+
20
+ def get_rotation_moviepy(video_path):
21
+ clip = VideoFileClip(video_path)
22
+ rotation = 0
23
+
24
+ try:
25
+ displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '')
26
+ if 'rotation of' in displaymatrix:
27
+ angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0])
28
+ rotation = int(angle) % 360
29
+
30
+ except Exception as e:
31
+ print(f"No displaymatrix rotation found: {e}")
32
+
33
+ clip.reader.close()
34
+ #if clip.audio:
35
+ # clip.audio.reader.close_proc()
36
+
37
+ return rotation
38
+
39
+ def resize_max_side(frame, max_size):
40
+ h, w = frame.shape[:2]
41
+ scale = max_size / max(h, w)
42
+ if scale < 1:
43
+ frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
44
+ return frame
45
+
46
+ def read_video_frames(video_input, k=1, max_size=1024):
47
+ """
48
+ Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list.
49
+
50
+ Parameters:
51
+ video_input (str, file-like, or list): Path to video file, file-like object, or list of image files.
52
+ k (int): Interval for frame extraction (every k-th frame).
53
+ max_size (int): Maximum size for width or height after resizing.
54
+
55
+ Returns:
56
+ frames (list): List of resized frames (numpy arrays).
57
+ """
58
+ # Handle list of image files (not single video in a list)
59
+ if isinstance(video_input, list):
60
+ # If it's a single video in a list, treat it as video
61
+ if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')):
62
+ video_input = video_input[0] # unwrap single video file
63
+ else:
64
+ # Treat as list of images
65
+ frames = []
66
+ for img_file in video_input:
67
+ img = Image.open(img_file.name).convert("RGB")
68
+ img.thumbnail((max_size, max_size))
69
+ frames.append(np.array(img)[...,::-1])
70
+ return frames
71
+
72
+ # Handle file-like or path
73
+ if hasattr(video_input, 'name'):
74
+ video_path = video_input.name
75
+ elif isinstance(video_input, (str, os.PathLike)):
76
+ video_path = str(video_input)
77
+ else:
78
+ raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.")
79
+
80
+
81
+ cap = cv2.VideoCapture(video_path)
82
+ if not cap.isOpened():
83
+ raise ValueError(f"Error: Could not open video {video_path}.")
84
+
85
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
86
+ frame_count = 0
87
+ frames = []
88
+
89
+ with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar:
90
+ while True:
91
+ ret, frame = cap.read()
92
+ if not ret:
93
+ break
94
+ if frame_count % k == 0:
95
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
96
+ h, w = frame.shape[:2]
97
+ scale = max(h, w) / max_size
98
+ if scale > 1:
99
+ frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
100
+ frames.append(frame[...,[2,1,0]])
101
+ pbar.update(1)
102
+ frame_count += 1
103
+
104
+ cap.release()
105
+ return frames
106
+
107
+ def resize_max_side(frame, max_size):
108
+ """
109
+ Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
110
+ """
111
+ height, width = frame.shape[:2]
112
+ max_dim = max(height, width)
113
+
114
+ if max_dim <= max_size:
115
+ return frame # No need to resize
116
+
117
+ scale = max_size / max_dim
118
+ new_width = int(width * scale)
119
+ new_height = int(height * scale)
120
+
121
+ resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
122
+ return resized_frame
123
+
124
+
125
+
126
+ def variance_of_laplacian(image):
127
+ # compute the Laplacian of the image and then return the focus
128
+ # measure, which is simply the variance of the Laplacian
129
+ return cv2.Laplacian(image, cv2.CV_64F).var()
130
+
131
+ def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames',
132
+ to_visualize=False,
133
+ save_images=True):
134
+ dict_scores = {}
135
+ for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):
136
+
137
+ img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:]
138
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
139
+ fm = variance_of_laplacian(gray) + \
140
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \
141
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \
142
+ variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25))
143
+ if to_visualize:
144
+ plt.figure()
145
+ plt.title(f"Laplacian score: {fm:.2f}")
146
+ plt.imshow(img[..., [2,1,0]])
147
+ plt.show()
148
+ dict_scores[idx] = {"idx" : idx,
149
+ "img_name" : img_name,
150
+ "score" : fm}
151
+ if save_images:
152
+ dict_scores[idx]["img"] = img
153
+
154
+ return dict_scores
155
+
156
+ def select_optimal_frames(scores, k):
157
+ """
158
+ Selects a minimal subset of frames while ensuring no gaps exceed k.
159
+
160
+ Args:
161
+ scores (list of float): List of scores where index represents frame number.
162
+ k (int): Maximum allowed gap between selected frames.
163
+
164
+ Returns:
165
+ list of int: Indices of selected frames.
166
+ """
167
+ n = len(scores)
168
+ selected = [0, n-1]
169
+ i = 0 # Start at the first frame
170
+
171
+ while i < n:
172
+ # Find the best frame to select within the next k frames
173
+ best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None)
174
+
175
+ if best_idx is None:
176
+ break # No more frames left
177
+
178
+ selected.append(best_idx)
179
+ i = best_idx + k + 1 # Move forward, ensuring gaps stay within k
180
+
181
+ return sorted(selected)
182
+
183
+
184
+ def variance_of_laplacian(image):
185
+ """
186
+ Compute the variance of Laplacian as a focus measure.
187
+ """
188
+ return cv2.Laplacian(image, cv2.CV_64F).var()
189
+
190
+ def preprocess_frames(frames, verbose=False):
191
+ """
192
+ Compute sharpness scores for a list of frames using multi-scale Laplacian variance.
193
+
194
+ Args:
195
+ frames (list of np.ndarray): List of frames (BGR images).
196
+ verbose (bool): If True, print scores.
197
+
198
+ Returns:
199
+ list of float: Sharpness scores for each frame.
200
+ """
201
+ scores = []
202
+
203
+ for idx, frame in enumerate(tqdm(frames, desc="Scoring frames")):
204
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
205
+
206
+ fm = (
207
+ variance_of_laplacian(gray) +
208
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) +
209
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) +
210
+ variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
211
+ )
212
+
213
+ if verbose:
214
+ print(f"Frame {idx}: Sharpness Score = {fm:.2f}")
215
+
216
+ scores.append(fm)
217
+
218
+ return scores
219
+
220
+ def select_optimal_frames(scores, k):
221
+ """
222
+ Selects k frames by splitting into k segments and picking the sharpest frame from each.
223
+
224
+ Args:
225
+ scores (list of float): List of sharpness scores.
226
+ k (int): Number of frames to select.
227
+
228
+ Returns:
229
+ list of int: Indices of selected frames.
230
+ """
231
+ n = len(scores)
232
+ selected_indices = []
233
+ segment_size = n // k
234
+
235
+ for i in range(k):
236
+ start = i * segment_size
237
+ end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger
238
+ segment_scores = scores[start:end]
239
+
240
+ if len(segment_scores) == 0:
241
+ continue # Safety check if some segment is empty
242
+
243
+ best_in_segment = start + np.argmax(segment_scores)
244
+ selected_indices.append(best_in_segment)
245
+
246
+ return sorted(selected_indices)
247
+
248
+ def save_frames_to_scene_dir(frames, scene_dir):
249
+ """
250
+ Saves a list of frames into the target scene directory under 'images/' subfolder.
251
+
252
+ Args:
253
+ frames (list of np.ndarray): List of frames (BGR images) to save.
254
+ scene_dir (str): Target path where 'images/' subfolder will be created.
255
+ """
256
+ images_dir = os.path.join(scene_dir, "images")
257
+ os.makedirs(images_dir, exist_ok=True)
258
+
259
+ for idx, frame in enumerate(frames):
260
+ filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc.
261
+ cv2.imwrite(filename, frame)
262
+
263
+ print(f"Saved {len(frames)} frames to {images_dir}")
264
+
265
+
266
+ def run_colmap_on_scene(scene_dir):
267
+ """
268
+ Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap.
269
+
270
+ Args:
271
+ scene_dir (str): Path to scene directory containing 'images' folder.
272
+
273
+ TODO: if the function fails to match all the frames, either increase the image size,
274
+ increase the number of features, or remove the unmatched frames from scene_dir/images.
275
+ """
276
+ start_time = time.time()
277
+ print(f"Running COLMAP pipeline on all images inside {scene_dir}")
278
+
279
+ # Setup paths
280
+ database_path = os.path.join(scene_dir, "database.db")
281
+ sparse_path = os.path.join(scene_dir, "sparse")
282
+ image_dir = os.path.join(scene_dir, "images")
283
+
284
+ # Make sure output directories exist
285
+ os.makedirs(sparse_path, exist_ok=True)
286
+
287
+ # Step 1: Feature Extraction
288
+ pycolmap.extract_features(
289
+ database_path,
290
+ image_dir,
291
+ sift_options={
292
+ "max_num_features": 512 * 2,
293
+ "max_image_size": 512 * 1,
294
+ }
295
+ )
296
+ print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.")
297
+
298
+ # Step 2: Feature Matching
299
+ pycolmap.match_exhaustive(database_path)
300
+ print(f"Finished feature matching in {(time.time() - start_time):.2f}s.")
301
+
302
+ # Step 3: Mapping
303
+ pipeline_options = pycolmap.IncrementalPipelineOptions()
304
+ pipeline_options.min_num_matches = 15
305
+ pipeline_options.multiple_models = True
306
+ pipeline_options.max_num_models = 50
307
+ pipeline_options.max_model_overlap = 20
308
+ pipeline_options.min_model_size = 10
309
+ pipeline_options.extract_colors = True
310
+ pipeline_options.num_threads = 8
311
+ pipeline_options.mapper.init_min_num_inliers = 30
312
+ pipeline_options.mapper.init_max_error = 8.0
313
+ pipeline_options.mapper.init_min_tri_angle = 5.0
314
+
315
+ reconstruction = pycolmap.incremental_mapping(
316
+ database_path=database_path,
317
+ image_path=image_dir,
318
+ output_path=sparse_path,
319
+ options=pipeline_options,
320
+ )
321
+ print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.")
322
+
323
+ # Step 4: Post-process Cameras to SIMPLE_PINHOLE
324
+ recon_path = os.path.join(sparse_path, "0")
325
+ reconstruction = pycolmap.Reconstruction(recon_path)
326
+
327
+ for cam in reconstruction.cameras.values():
328
+ cam.model = 'SIMPLE_PINHOLE'
329
+ cam.params = cam.params[:3] # Keep only [f, cx, cy]
330
+
331
+ reconstruction.write(recon_path)
332
+
333
+ print(f"Total pipeline time: {(time.time() - start_time):.2f}s.")
334
+
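Usage note (not part of the diff): frame selection scores every frame with a multi-scale variance of the Laplacian and keeps the sharpest frame per segment; a self-contained sketch of that scoring/selection step on synthetic frames (the real pipeline chains read_video_frames -> preprocess_frames -> select_optimal_frames -> save_frames_to_scene_dir -> run_colmap_on_scene).

    import cv2
    import numpy as np

    rng = np.random.default_rng(0)
    frames = [rng.integers(0, 255, (120, 160, 3), dtype=np.uint8) for _ in range(30)]  # stand-in frames

    scores = []
    for frame in frames:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        fm = sum(cv2.Laplacian(cv2.resize(gray, (0, 0), fx=s, fy=s), cv2.CV_64F).var()
                 for s in (1.0, 0.75, 0.5, 0.25))
        scores.append(fm)

    k = 8
    seg = len(scores) // k
    selected = []
    for i in range(k):
        start = i * seg
        end = (i + 1) * seg if i < k - 1 else len(scores)
        selected.append(start + int(np.argmax(scores[start:end])))
    print(selected)  # indices of the k sharpest frames, one per segment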
source/visualization.py ADDED
@@ -0,0 +1,1072 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+
5
+ import numpy as np
6
+ from typing import List
7
+ import sys
8
+ sys.path.append('./submodules/gaussian-splatting/')
9
+ from scene.cameras import Camera
10
+ from PIL import Image
11
+ import imageio
12
+ from scipy.interpolate import splprep, splev
13
+
14
+ import cv2
15
+ import numpy as np
16
+ import plotly.graph_objects as go
17
+ import numpy as np
18
+ from scipy.spatial.transform import Rotation as R, Slerp
19
+ from scipy.spatial import distance_matrix
20
+ from sklearn.decomposition import PCA
21
+ from scipy.interpolate import splprep, splev
22
+ from typing import List
23
+ from sklearn.mixture import GaussianMixture
24
+
25
+ def render_gaussians_rgb(generator3DGS, viewpoint_cam, visualize=False):
26
+ """
27
+ Simply render gaussians from the generator3DGS from the viewpoint_cam.
28
+ Args:
29
+ generator3DGS : instance of the Generator3DGS class from the networks.py file
30
+ viewpoint_cam : camera instance
31
+ visualize : boolean flag. If True, will call pyplot function and render image inplace
32
+ Returns:
33
+ uint8 numpy array with shape (H, W, 3) representing the image
34
+ """
35
+ with torch.no_grad():
36
+ render_pkg = generator3DGS(viewpoint_cam)
37
+ image = render_pkg["render"]
38
+ image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
39
+
40
+ # Clip values to be in the range [0, 1]
41
+ image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
42
+ if visualize:
43
+ plt.figure(figsize=(12, 8))
44
+ plt.imshow(image_np)
45
+ plt.show()
46
+
47
+ return image_np
48
+
49
+ def render_gaussians_D_scores(generator3DGS, viewpoint_cam, mask=None, mask_channel=0, visualize=False):
50
+ """
51
+ Simply render D_scores of gaussians from the generator3DGS from the viewpoint_cam.
52
+ Args:
53
+ generator3DGS : instance of the Generator3DGS class from the networks.py file
54
+ viewpoint_cam : camera instance
55
+ visualize : boolean flag. If True, will call pyplot function and render image inplace
56
+ mask : optional mask to highlight specific gaussians. Must be of shape (N) where N is the number
57
+ of gaussians in generator3DGS.gaussians. Must be a torch tensor of floats, please scale according
58
+ to how much color you want to have. Recommended mask value is 10.
59
+ mask_channel: to which color channel should we add mask
60
+ Returns:
61
+ uint8 numpy array with shape (H, W, 3) representing the generator3DGS.gaussians.D_scores rendered as colors
62
+ """
63
+ with torch.no_grad():
64
+ # Visualize D_scores
65
+ generator3DGS.gaussians._features_dc = generator3DGS.gaussians._features_dc * 1e-4 + \
66
+ torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)
67
+ generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e-4
68
+ if mask is not None:
69
+ generator3DGS.gaussians._features_dc[..., mask_channel] += mask.unsqueeze(-1)
70
+ render_pkg = generator3DGS(viewpoint_cam)
71
+ image = render_pkg["render"]
72
+ image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
73
+
74
+ # Clip values to be in the range [0, 1]
75
+ image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
76
+ if visualize:
77
+ plt.figure(figsize=(12, 8))
78
+ plt.imshow(image_np)
79
+ plt.show()
80
+
81
+ if mask is not None:
82
+ generator3DGS.gaussians._features_dc[..., mask_channel] -= mask.unsqueeze(-1)
83
+
84
+ generator3DGS.gaussians._features_dc = (generator3DGS.gaussians._features_dc - \
85
+ torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)) * 1e4
86
+ generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e4
87
+
88
+ return image_np
89
+
90
+
91
+
92
+ def normalize(v):
93
+ """
94
+ Normalize a vector to unit length.
95
+
96
+ Parameters:
97
+ v (np.ndarray): Input vector.
98
+
99
+ Returns:
100
+ np.ndarray: Unit vector in the same direction as `v`.
101
+ """
102
+ return v / np.linalg.norm(v)
103
+
104
+ def look_at_rotation(camera_position: np.ndarray, target: np.ndarray, world_up=np.array([0, 1, 0])):
105
+ """
106
+ Compute a rotation matrix for a camera looking at a target point.
107
+
108
+ Parameters:
109
+ camera_position (np.ndarray): The 3D position of the camera.
110
+ target (np.ndarray): The point the camera should look at.
111
+ world_up (np.ndarray): A vector that defines the global 'up' direction.
112
+
113
+ Returns:
114
+ np.ndarray: A 3x3 rotation matrix (camera-to-world) with columns [right, up, forward].
115
+ """
116
+ z_axis = normalize(target - camera_position) # Forward direction
117
+ x_axis = normalize(np.cross(world_up, z_axis)) # Right direction
118
+ y_axis = np.cross(z_axis, x_axis) # Recomputed up
119
+ return np.stack([x_axis, y_axis, z_axis], axis=1)
120
+
121
+
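Note (not part of the diff): look_at_rotation returns a camera-to-world rotation whose columns are [right, up, forward], so its third column is the unit vector from the camera toward the target; a small sanity-check sketch, assuming normalize and look_at_rotation defined above are in scope.

    import numpy as np

    cam_pos = np.array([2.0, 1.0, -3.0])
    target = np.array([0.0, 0.0, 0.0])
    Rcw = look_at_rotation(cam_pos, target)

    assert np.allclose(Rcw.T @ Rcw, np.eye(3), atol=1e-6)                  # orthonormal basis
    assert np.allclose(Rcw[:, 2], normalize(target - cam_pos), atol=1e-6)  # forward column looks at the target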
122
+ def generate_circular_camera_path(existing_cameras: List[Camera], N: int = 12, radius_scale: float = 1.0, d: float = 2.0) -> List[Camera]:
123
+ """
124
+ Generate a circular path of cameras around an existing camera group,
125
+ with each new camera oriented to look at the average viewing direction.
126
+
127
+ Parameters:
128
+ existing_cameras (List[Camera]): List of existing camera objects to estimate average orientation and layout.
129
+ N (int): Number of new cameras to generate along the circular path.
130
+ radius_scale (float): Scale factor to adjust the radius of the circle.
131
+ d (float): Distance ahead of each camera used to estimate its look-at point.
132
+
133
+ Returns:
134
+ List[Camera]: A list of newly generated Camera objects forming a circular path and oriented toward a shared view center.
135
+ """
136
+ # Step 1: Compute average camera position
137
+ center = np.mean([cam.T for cam in existing_cameras], axis=0)
138
+
139
+ # Estimate where each camera is looking
140
+ # d denotes how far ahead each camera sees — you can scale this
141
+ look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
142
+ center_of_view = np.mean(look_targets, axis=0)
143
+
144
+ # Step 2: Define circular plane basis using fixed up vector
145
+ avg_forward = normalize(np.mean([cam.R[:, 2] for cam in existing_cameras], axis=0))
146
+ up_guess = np.array([0, 1, 0])
147
+ right = normalize(np.cross(avg_forward, up_guess))
148
+ up = normalize(np.cross(right, avg_forward))
149
+
150
+ # Step 3: Estimate radius
151
+ avg_radius = np.mean([np.linalg.norm(cam.T - center) for cam in existing_cameras]) * radius_scale
152
+
153
+ # Step 4: Create cameras on a circular path
154
+ angles = np.linspace(0, 2 * np.pi, N, endpoint=False)
155
+ reference_cam = existing_cameras[0]
156
+ new_cameras = []
157
+
158
+
159
+ for i, a in enumerate(angles):
160
+ position = center + avg_radius * (np.cos(a) * right + np.sin(a) * up)
161
+
162
+ if d < 1e-5 or radius_scale < 1e-5:
163
+ # Use same orientation as the first camera
164
+ R = reference_cam.R.copy()
165
+ else:
166
+ # Change orientation
167
+ R = look_at_rotation(position, center_of_view)
168
+ new_cameras.append(Camera(
169
+ R=R,
170
+ T=position, # New position
171
+ FoVx=reference_cam.FoVx,
172
+ FoVy=reference_cam.FoVy,
173
+ resolution=(reference_cam.image_width, reference_cam.image_height),
174
+ colmap_id=-1,
175
+ depth_params=None,
176
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
177
+ invdepthmap=None,
178
+ image_name=f"circular_a={a:.3f}",
179
+ uid=i
180
+ ))
181
+
182
+ return new_cameras
183
+
184
+
185
+ def save_numpy_frames_as_gif(frames, output_path="animation.gif", duration=100):
186
+ """
187
+ Save a list of RGB NumPy frames as a looping GIF animation.
188
+
189
+ Parameters:
190
+ frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
191
+ output_path (str): Path to save the output GIF.
192
+ duration (int): Duration per frame in milliseconds.
193
+
194
+ Returns:
195
+ None
196
+ """
197
+ pil_frames = [Image.fromarray(f) for f in frames]
198
+ pil_frames[0].save(
199
+ output_path,
200
+ save_all=True,
201
+ append_images=pil_frames[1:],
202
+ duration=duration, # duration per frame in ms
203
+ loop=0
204
+ )
205
+ print(f"GIF saved to: {output_path}")
206
+
207
+ def center_crop_frame(frame: np.ndarray, crop_fraction: float) -> np.ndarray:
208
+ """
209
+ Crop the central region of the frame by the given fraction.
210
+
211
+ Parameters:
212
+ frame (np.ndarray): Input RGB image (H, W, 3).
213
+ crop_fraction (float): Fraction of the original size to retain (e.g., 0.8 keeps 80%).
214
+
215
+ Returns:
216
+ np.ndarray: Cropped RGB image.
217
+ """
218
+ if crop_fraction >= 1.0:
219
+ return frame
220
+
221
+ h, w, _ = frame.shape
222
+ new_h, new_w = int(h * crop_fraction), int(w * crop_fraction)
223
+ start_y = (h - new_h) // 2
224
+ start_x = (w - new_w) // 2
225
+ return frame[start_y:start_y + new_h, start_x:start_x + new_w, :]
226
+
227
+
228
+
229
+ def generate_smooth_closed_camera_path(existing_cameras: List[Camera], N: int = 120, d: float = 2.0, s=.25) -> List[Camera]:
230
+ """
231
+ Generate a smooth, closed path interpolating the positions of existing cameras.
232
+
233
+ Parameters:
234
+ existing_cameras (List[Camera]): List of existing cameras.
235
+ N (int): Number of points (cameras) to sample along the smooth path.
236
+ d (float): Distance ahead for estimating the center of view.
237
+
238
+ Returns:
239
+ List[Camera]: A list of smoothly moving Camera objects along a closed loop.
240
+ """
241
+ # Step 1: Extract camera positions
242
+ positions = np.array([cam.T for cam in existing_cameras])
243
+
244
+ # Step 2: Estimate center of view
245
+ look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
246
+ center_of_view = np.mean(look_targets, axis=0)
247
+
248
+ # Step 3: Fit a smooth closed spline through the positions
249
+ positions = np.vstack([positions, positions[0]]) # close the loop
250
+ tck, u = splprep(positions.T, s=s, per=True) # periodic=True for closed loop
251
+
252
+ # Step 4: Sample points along the spline
253
+ u_fine = np.linspace(0, 1, N)
254
+ smooth_path = np.stack(splev(u_fine, tck), axis=-1)
255
+
256
+ # Step 5: Generate cameras along the smooth path
257
+ reference_cam = existing_cameras[0]
258
+ new_cameras = []
259
+
260
+ for i, pos in enumerate(smooth_path):
261
+ R = look_at_rotation(pos, center_of_view)
262
+ new_cameras.append(Camera(
263
+ R=R,
264
+ T=pos,
265
+ FoVx=reference_cam.FoVx,
266
+ FoVy=reference_cam.FoVy,
267
+ resolution=(reference_cam.image_width, reference_cam.image_height),
268
+ colmap_id=-1,
269
+ depth_params=None,
270
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
271
+ invdepthmap=None,
272
+ image_name=f"smooth_path_i={i}",
273
+ uid=i
274
+ ))
275
+
276
+ return new_cameras
277
+
278
+
279
+ def save_numpy_frames_as_mp4(frames, output_path="animation.mp4", fps=10, center_crop: float = 1.0):
280
+ """
281
+ Save a list of RGB NumPy frames as an MP4 video with optional center cropping.
282
+
283
+ Parameters:
284
+ frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
285
+ output_path (str): Path to save the output MP4.
286
+ fps (int): Frames per second for playback speed.
287
+ center_crop (float): Fraction (0 < center_crop <= 1.0) of central region to retain.
288
+ Use 1.0 for no cropping; 0.8 to crop to 80% center region.
289
+
290
+ Returns:
291
+ None
292
+ """
293
+ with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
294
+ for frame in frames:
295
+ cropped = center_crop_frame(frame, center_crop)
296
+ writer.append_data(cropped)
297
+ print(f"MP4 saved to: {output_path}")
298
+
299
+
300
+
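Note (not part of the diff): a rough end-to-end sketch of how these helpers are meant to be combined into a fly-around video, assuming a trained Warper3DGS instance generator3DGS is available and the functions above are in scope.

    cameras = generator3DGS.scene.getTrainCameras()
    path = generate_circular_camera_path(cameras, N=120, radius_scale=1.0, d=2.0)
    frames = [render_gaussians_rgb(generator3DGS, cam) for cam in path]
    save_numpy_frames_as_mp4(frames, "flyaround.mp4", fps=30, center_crop=0.9)
    save_numpy_frames_as_gif(frames, "flyaround.gif", duration=50)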
301
+ def put_text_on_image(img: np.ndarray, text: str) -> np.ndarray:
302
+ """
303
+ Draws multiline white text on a copy of the input image, positioned near the bottom
304
+ and around 80% of the image width. Handles '\n' characters to split text into multiple lines.
305
+
306
+ Args:
307
+ img (np.ndarray): Input image as a (H, W, 3) uint8 numpy array.
308
+ text (str): Text string to draw on the image. Newlines '\n' are treated as line breaks.
309
+
310
+ Returns:
311
+ np.ndarray: The output image with the text drawn on it.
312
+
313
+ Notes:
314
+ - The function automatically adjusts line spacing and prevents text from going outside the image.
315
+ - Text is drawn in white with small font size (0.5) for minimal visual impact.
316
+ """
317
+ img = img.copy()
318
+ height, width, _ = img.shape
319
+
320
+ font = cv2.FONT_HERSHEY_SIMPLEX
321
+ font_scale = 1.
322
+ color = (255, 255, 255)
323
+ thickness = 2
324
+ line_spacing = 5 # extra pixels between lines
325
+
326
+ lines = text.split('\n')
327
+
328
+ # Precompute the maximum text width to adjust starting x
329
+ max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in lines)
330
+
331
+ x = int(0.8 * width)
332
+ x = min(x, width - max_text_width - 30) # margin on right
333
+ #x = int(0.03 * width)
334
+
335
+ # Start near the bottom, but move up depending on number of lines
336
+ total_text_height = len(lines) * (cv2.getTextSize('A', font, font_scale, thickness)[0][1] + line_spacing)
337
+ y_start = int(height*0.9) - total_text_height # 30 pixels from bottom
338
+
339
+ for i, line in enumerate(lines):
340
+ y = y_start + i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing)
341
+ cv2.putText(img, line, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
342
+
343
+ return img
344
+
345
+
346
+
347
+
348
+ def catmull_rom_spline(P0, P1, P2, P3, n_points=20):
349
+ """
350
+ Compute Catmull-Rom spline segment between P1 and P2.
351
+ """
352
+ t = np.linspace(0, 1, n_points)[:, None]
353
+
354
+ M = 0.5 * np.array([
355
+ [-1, 3, -3, 1],
356
+ [ 2, -5, 4, -1],
357
+ [-1, 0, 1, 0],
358
+ [ 0, 2, 0, 0]
359
+ ])
360
+
361
+ G = np.stack([P0, P1, P2, P3], axis=0)
362
+ T = np.concatenate([t**3, t**2, t, np.ones_like(t)], axis=1)
363
+
364
+ return T @ M @ G
365
+
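Note (not part of the diff): a Catmull-Rom segment interpolates its two middle control points (P1 at t=0, P2 at t=1), which is what makes the stitched path pass through the sampled camera positions; a small check, assuming catmull_rom_spline defined above is in scope.

    import numpy as np

    P0, P1, P2, P3 = (np.array(p, dtype=float) for p in ([0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 1, 0]))
    seg = catmull_rom_spline(P0, P1, P2, P3, n_points=10)

    assert seg.shape == (10, 3)
    assert np.allclose(seg[0], P1) and np.allclose(seg[-1], P2)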
366
+ def sort_cameras_pca(existing_cameras: List[Camera]):
367
+ """
368
+ Sort cameras along the main PCA axis.
369
+ """
370
+ positions = np.array([cam.T for cam in existing_cameras])
371
+ pca = PCA(n_components=1)
372
+ scores = pca.fit_transform(positions)
373
+ sorted_indices = np.argsort(scores[:, 0])
374
+ return sorted_indices
375
+
376
+ def generate_fully_smooth_cameras(existing_cameras: List[Camera],
377
+ n_selected: int = 30,
378
+ n_points_per_segment: int = 20,
379
+ d: float = 2.0,
380
+ closed: bool = False) -> List[Camera]:
381
+ """
382
+ Generate a fully smooth camera path using PCA ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
383
+
384
+ Args:
385
+ existing_cameras (List[Camera]): List of input cameras.
386
+ n_selected (int): Number of cameras to select after sorting.
387
+ n_points_per_segment (int): Number of interpolated points per spline segment.
388
+ d (float): Distance ahead for estimating center of view.
389
+ closed (bool): Whether to close the path.
390
+
391
+ Returns:
392
+ List[Camera]: List of smoothly moving Camera objects.
393
+ """
394
+ # 1. Sort cameras along PCA axis
395
+ sorted_indices = sort_cameras_pca(existing_cameras)
396
+ sorted_cameras = [existing_cameras[i] for i in sorted_indices]
397
+ positions = np.array([cam.T for cam in sorted_cameras])
398
+
399
+ # 2. Subsample uniformly
400
+ idx = np.linspace(0, len(positions) - 1, n_selected).astype(int)
401
+ sampled_positions = positions[idx]
402
+ sampled_cameras = [sorted_cameras[i] for i in idx]
403
+
404
+ # 3. Prepare for Catmull-Rom
405
+ if closed:
406
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
407
+ else:
408
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
409
+
410
+ # 4. Generate smooth path positions
411
+ path_positions = []
412
+ for i in range(1, len(sampled_positions) - 2):
413
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
414
+ path_positions.append(segment)
415
+ path_positions = np.concatenate(path_positions, axis=0)
416
+
417
+ # 5. Global SLERP for rotations
418
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
419
+ key_times = np.linspace(0, 1, len(rotations))
420
+ slerp = Slerp(key_times, rotations)
421
+
422
+ query_times = np.linspace(0, 1, len(path_positions))
423
+ interpolated_rotations = slerp(query_times)
424
+
425
+ # 6. Generate Camera objects
426
+ reference_cam = existing_cameras[0]
427
+ smooth_cameras = []
428
+
429
+ for i, pos in enumerate(path_positions):
430
+ R_interp = interpolated_rotations[i].as_matrix()
431
+
432
+ smooth_cameras.append(Camera(
433
+ R=R_interp,
434
+ T=pos,
435
+ FoVx=reference_cam.FoVx,
436
+ FoVy=reference_cam.FoVy,
437
+ resolution=(reference_cam.image_width, reference_cam.image_height),
438
+ colmap_id=-1,
439
+ depth_params=None,
440
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
441
+ invdepthmap=None,
442
+ image_name=f"fully_smooth_path_i={i}",
443
+ uid=i
444
+ ))
445
+
446
+ return smooth_cameras
447
+
448
+
449
+ def plot_cameras_and_smooth_path_with_orientation(existing_cameras: List[Camera], smooth_cameras: List[Camera], scale: float = 0.1):
450
+ """
451
+ Plot input cameras and smooth path cameras with their orientations in 3D.
452
+
453
+ Args:
454
+ existing_cameras (List[Camera]): List of original input cameras.
455
+ smooth_cameras (List[Camera]): List of smooth path cameras.
456
+ scale (float): Length of orientation arrows.
457
+
458
+ Returns:
459
+ None
460
+ """
461
+ # Input cameras
462
+ input_positions = np.array([cam.T for cam in existing_cameras])
463
+
464
+ # Smooth cameras
465
+ smooth_positions = np.array([cam.T for cam in smooth_cameras])
466
+
467
+ fig = go.Figure()
468
+
469
+ # Plot input camera positions
470
+ fig.add_trace(go.Scatter3d(
471
+ x=input_positions[:, 0], y=input_positions[:, 1], z=input_positions[:, 2],
472
+ mode='markers',
473
+ marker=dict(size=4, color='blue'),
474
+ name='Input Cameras'
475
+ ))
476
+
477
+ # Plot smooth path positions
478
+ fig.add_trace(go.Scatter3d(
479
+ x=smooth_positions[:, 0], y=smooth_positions[:, 1], z=smooth_positions[:, 2],
480
+ mode='lines+markers',
481
+ line=dict(color='red', width=3),
482
+ marker=dict(size=2, color='red'),
483
+ name='Smooth Path Cameras'
484
+ ))
485
+
486
+ # Plot input camera orientations
487
+ for cam in existing_cameras:
488
+ origin = cam.T
489
+ forward = cam.R[:, 2] # Forward direction
490
+
491
+ fig.add_trace(go.Cone(
492
+ x=[origin[0]], y=[origin[1]], z=[origin[2]],
493
+ u=[forward[0]], v=[forward[1]], w=[forward[2]],
494
+ colorscale=[[0, 'blue'], [1, 'blue']],
495
+ sizemode="absolute",
496
+ sizeref=scale,
497
+ anchor="tail",
498
+ showscale=False,
499
+ name='Input Camera Direction'
500
+ ))
501
+
502
+ # Plot smooth camera orientations
503
+ for cam in smooth_cameras:
504
+ origin = cam.T
505
+ forward = cam.R[:, 2] # Forward direction
506
+
507
+ fig.add_trace(go.Cone(
508
+ x=[origin[0]], y=[origin[1]], z=[origin[2]],
509
+ u=[forward[0]], v=[forward[1]], w=[forward[2]],
510
+ colorscale=[[0, 'red'], [1, 'red']],
511
+ sizemode="absolute",
512
+ sizeref=scale,
513
+ anchor="tail",
514
+ showscale=False,
515
+ name='Smooth Camera Direction'
516
+ ))
517
+
518
+ fig.update_layout(
519
+ scene=dict(
520
+ xaxis_title='X',
521
+ yaxis_title='Y',
522
+ zaxis_title='Z',
523
+ aspectmode='data'
524
+ ),
525
+ title="Input Cameras and Smooth Path with Orientations",
526
+ margin=dict(l=0, r=0, b=0, t=30)
527
+ )
528
+
529
+ fig.show()
530
+
531
+
532
+ def solve_tsp_nearest_neighbor(points: np.ndarray):
533
+ """
534
+ Solve TSP approximately using nearest neighbor heuristic.
535
+
536
+ Args:
537
+ points (np.ndarray): (N, 3) array of points.
538
+
539
+ Returns:
540
+ List[int]: Optimal visiting order of points.
541
+ """
542
+ N = points.shape[0]
543
+ dist = distance_matrix(points, points)
544
+ visited = [0]
545
+ unvisited = set(range(1, N))
546
+
547
+ while unvisited:
548
+ last = visited[-1]
549
+ next_city = min(unvisited, key=lambda city: dist[last, city])
550
+ visited.append(next_city)
551
+ unvisited.remove(next_city)
552
+
553
+ return visited
554
+
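Note (not part of the diff): the nearest-neighbour heuristic returns a visiting order, i.e. a permutation of the input indices starting at point 0; a quick sketch, assuming solve_tsp_nearest_neighbor defined above is in scope.

    import numpy as np

    pts = np.random.default_rng(0).random((20, 3))
    order = solve_tsp_nearest_neighbor(pts)

    assert sorted(order) == list(range(len(pts)))  # every point visited exactly once
    tour_len = sum(np.linalg.norm(pts[a] - pts[b]) for a, b in zip(order[:-1], order[1:]))
    print(order, round(tour_len, 3))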
555
+ def solve_tsp_2opt(points: np.ndarray, n_iter: int = 1000) -> np.ndarray:
556
+ """
557
+ Solve TSP approximately using Nearest Neighbor + 2-Opt.
558
+
559
+ Args:
560
+ points (np.ndarray): Array of shape (N, D) with points.
561
+ n_iter (int): Number of 2-opt iterations.
562
+
563
+ Returns:
564
+ np.ndarray: Ordered list of indices.
565
+ """
566
+ n_points = points.shape[0]
567
+
568
+ # === 1. Start with Nearest Neighbor
569
+ unvisited = list(range(n_points))
570
+ current = unvisited.pop(0)
571
+ path = [current]
572
+
573
+ while unvisited:
574
+ dists = np.linalg.norm(points[unvisited] - points[current], axis=1)
575
+ next_idx = unvisited[np.argmin(dists)]
576
+ unvisited.remove(next_idx)
577
+ path.append(next_idx)
578
+ current = next_idx
579
+
580
+ # === 2. Apply 2-Opt improvements
581
+ def path_length(path):
582
+ return sum(np.linalg.norm(points[path[i]] - points[path[i+1]]) for i in range(len(path)-1))
583
+
584
+ best_length = path_length(path)
585
+ improved = True
586
+
587
+ for _ in range(n_iter):
588
+ if not improved:
589
+ break
590
+ improved = False
591
+ for i in range(1, n_points - 2):
592
+ for j in range(i + 1, n_points):
593
+ if j - i == 1: continue
594
+ new_path = path[:i] + path[i:j][::-1] + path[j:]
595
+ new_length = path_length(new_path)
596
+ if new_length < best_length:
597
+ path = new_path
598
+ best_length = new_length
599
+ improved = True
600
+ break
601
+ if improved:
602
+ break
603
+
604
+ return np.array(path)
605
+
606
+ def generate_fully_smooth_cameras_with_tsp(existing_cameras: List[Camera],
607
+ n_selected: int = 30,
608
+ n_points_per_segment: int = 20,
609
+ d: float = 2.0,
610
+ closed: bool = False) -> List[Camera]:
611
+ """
612
+ Generate a fully smooth camera path using TSP ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
613
+
614
+ Args:
615
+ existing_cameras (List[Camera]): List of input cameras.
616
+ n_selected (int): Number of cameras to select after ordering.
617
+ n_points_per_segment (int): Number of interpolated points per spline segment.
618
+ d (float): Distance ahead for estimating center of view.
619
+ closed (bool): Whether to close the path.
620
+
621
+ Returns:
622
+ List[Camera]: List of smoothly moving Camera objects.
623
+ """
624
+ positions = np.array([cam.T for cam in existing_cameras])
625
+
626
+ # 1. Solve approximate TSP
627
+ order = solve_tsp_nearest_neighbor(positions)
628
+ ordered_cameras = [existing_cameras[i] for i in order]
629
+ ordered_positions = positions[order]
630
+
631
+ # 2. Subsample uniformly
632
+ idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
633
+ sampled_positions = ordered_positions[idx]
634
+ sampled_cameras = [ordered_cameras[i] for i in idx]
635
+
636
+ # 3. Prepare for Catmull-Rom
637
+ if closed:
638
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
639
+ else:
640
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
641
+
642
+ # 4. Generate smooth path positions
643
+ path_positions = []
644
+ for i in range(1, len(sampled_positions) - 2):
645
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
646
+ path_positions.append(segment)
647
+ path_positions = np.concatenate(path_positions, axis=0)
648
+
649
+ # 5. Global SLERP for rotations
650
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
651
+ key_times = np.linspace(0, 1, len(rotations))
652
+ slerp = Slerp(key_times, rotations)
653
+
654
+ query_times = np.linspace(0, 1, len(path_positions))
655
+ interpolated_rotations = slerp(query_times)
656
+
657
+ # 6. Generate Camera objects
658
+ reference_cam = existing_cameras[0]
659
+ smooth_cameras = []
660
+
661
+ for i, pos in enumerate(path_positions):
662
+ R_interp = interpolated_rotations[i].as_matrix()
663
+
664
+ smooth_cameras.append(Camera(
665
+ R=R_interp,
666
+ T=pos,
667
+ FoVx=reference_cam.FoVx,
668
+ FoVy=reference_cam.FoVy,
669
+ resolution=(reference_cam.image_width, reference_cam.image_height),
670
+ colmap_id=-1,
671
+ depth_params=None,
672
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
673
+ invdepthmap=None,
674
+ image_name=f"fully_smooth_path_i={i}",
675
+ uid=i
676
+ ))
677
+
678
+ return smooth_cameras
679
+
680
+ from typing import List
681
+ import numpy as np
682
+ from sklearn.mixture import GaussianMixture
683
+ from scipy.spatial.transform import Rotation as R, Slerp
684
+ from PIL import Image
685
+
686
+ def generate_clustered_smooth_cameras_with_tsp(existing_cameras: List[Camera],
687
+ n_selected: int = 30,
688
+ n_points_per_segment: int = 20,
689
+ d: float = 2.0,
690
+ n_clusters: int = 5,
691
+ closed: bool = False) -> List[Camera]:
692
+ """
693
+ Generate a fully smooth camera path using clustering + TSP between nearest cluster centers + TSP inside clusters.
694
+ Positions are normalized before clustering and denormalized before generating final cameras.
695
+
696
+ Args:
697
+ existing_cameras (List[Camera]): List of input cameras.
698
+ n_selected (int): Number of cameras to select after ordering.
699
+ n_points_per_segment (int): Number of interpolated points per spline segment.
700
+ d (float): Distance ahead for estimating center of view.
701
+ n_clusters (int): Number of GMM clusters.
702
+ closed (bool): Whether to close the path.
703
+
704
+ Returns:
705
+ List[Camera]: Smooth path of Camera objects.
706
+ """
707
+ # Extract positions and rotations
708
+ positions = np.array([cam.T for cam in existing_cameras])
709
+ rotations = np.array([R.from_matrix(cam.R).as_quat() for cam in existing_cameras])
710
+
711
+ # === Normalize positions
712
+ mean_pos = np.mean(positions, axis=0)
713
+ scale_pos = np.std(positions, axis=0)
714
+ scale_pos[scale_pos == 0] = 1.0 # avoid division by zero
715
+
716
+ positions_normalized = (positions - mean_pos) / scale_pos
717
+
718
+ # === Features for clustering (only positions, not rotations)
719
+ features = positions_normalized
720
+
721
+ # === 1. GMM clustering
722
+ gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
723
+ cluster_labels = gmm.fit_predict(features)
724
+
725
+ clusters = {}
726
+ cluster_centers = []
727
+
728
+ for cluster_id in range(n_clusters):
729
+ cluster_indices = np.where(cluster_labels == cluster_id)[0]
730
+ if len(cluster_indices) == 0:
731
+ continue
732
+ clusters[cluster_id] = cluster_indices
733
+ cluster_center = np.mean(features[cluster_indices], axis=0)
734
+ cluster_centers.append(cluster_center)
735
+
736
+ cluster_centers = np.stack(cluster_centers)
737
+
738
+ # === 2. Remap cluster centers to nearest existing cameras
739
+ if False:
740
+ mapped_centers = []
741
+ for center in cluster_centers:
742
+ dists = np.linalg.norm(features - center, axis=1)
743
+ nearest_idx = np.argmin(dists)
744
+ mapped_centers.append(features[nearest_idx])
745
+ mapped_centers = np.stack(mapped_centers)
746
+ cluster_centers = mapped_centers
747
+ # === 3. Solve TSP between mapped cluster centers
748
+ cluster_order = solve_tsp_2opt(cluster_centers)
749
+
750
+ # === 4. For each cluster, solve TSP inside cluster
751
+ final_indices = []
752
+ for cluster_id in cluster_order:
753
+ cluster_indices = clusters[cluster_id]
754
+ cluster_positions = features[cluster_indices]
755
+
756
+ if len(cluster_positions) == 1:
757
+ final_indices.append(cluster_indices[0])
758
+ continue
759
+
760
+ local_order = solve_tsp_nearest_neighbor(cluster_positions)
761
+ ordered_cluster_indices = cluster_indices[local_order]
762
+ final_indices.extend(ordered_cluster_indices)
763
+
764
+ ordered_cameras = [existing_cameras[i] for i in final_indices]
765
+ ordered_positions = positions_normalized[final_indices]
766
+
767
+ # === 5. Subsample uniformly
768
+ idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
769
+ sampled_positions = ordered_positions[idx]
770
+ sampled_cameras = [ordered_cameras[i] for i in idx]
771
+
772
+ # === 6. Prepare for Catmull-Rom spline
773
+ if closed:
774
+ sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
775
+ else:
776
+ sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
777
+
778
+ # === 7. Smooth path positions
779
+ path_positions = []
780
+ for i in range(1, len(sampled_positions) - 2):
781
+ segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
782
+ path_positions.append(segment)
783
+ path_positions = np.concatenate(path_positions, axis=0)
784
+
785
+ # === 8. Denormalize
786
+ path_positions = path_positions * scale_pos + mean_pos
787
+
788
+ # === 9. SLERP for rotations
789
+ rotations = R.from_matrix([cam.R for cam in sampled_cameras])
790
+ key_times = np.linspace(0, 1, len(rotations))
791
+ slerp = Slerp(key_times, rotations)
792
+
793
+ query_times = np.linspace(0, 1, len(path_positions))
794
+ interpolated_rotations = slerp(query_times)
795
+
796
+ # === 10. Generate Camera objects
797
+ reference_cam = existing_cameras[0]
798
+ smooth_cameras = []
799
+
800
+ for i, pos in enumerate(path_positions):
801
+ R_interp = interpolated_rotations[i].as_matrix()
802
+
803
+ smooth_cameras.append(Camera(
804
+ R=R_interp,
805
+ T=pos,
806
+ FoVx=reference_cam.FoVx,
807
+ FoVy=reference_cam.FoVy,
808
+ resolution=(reference_cam.image_width, reference_cam.image_height),
809
+ colmap_id=-1,
810
+ depth_params=None,
811
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
812
+ invdepthmap=None,
813
+ image_name=f"clustered_smooth_path_i={i}",
814
+ uid=i
815
+ ))
816
+
817
+ return smooth_cameras
818
+
819
+
820
+ # def generate_clustered_path(existing_cameras: List[Camera],
821
+ # n_points_per_segment: int = 20,
822
+ # d: float = 2.0,
823
+ # n_clusters: int = 5,
824
+ # closed: bool = False) -> List[Camera]:
825
+ # """
826
+ # Generate a smooth camera path using GMM clustering and TSP on cluster centers.
827
+
828
+ # Args:
829
+ # existing_cameras (List[Camera]): List of input cameras.
830
+ # n_points_per_segment (int): Number of interpolated points per spline segment.
831
+ # d (float): Distance ahead for estimating center of view.
832
+ # n_clusters (int): Number of GMM clusters (zones).
833
+ # closed (bool): Whether to close the path.
834
+
835
+ # Returns:
836
+ # List[Camera]: Smooth path of Camera objects.
837
+ # """
838
+ # # Extract positions and rotations
839
+ # positions = np.array([cam.T for cam in existing_cameras])
840
+
841
+ # # === Normalize positions
842
+ # mean_pos = np.mean(positions, axis=0)
843
+ # scale_pos = np.std(positions, axis=0)
844
+ # scale_pos[scale_pos == 0] = 1.0
845
+
846
+ # positions_normalized = (positions - mean_pos) / scale_pos
847
+
848
+ # # === 1. GMM clustering (only positions)
849
+ # gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
850
+ # cluster_labels = gmm.fit_predict(positions_normalized)
851
+
852
+ # cluster_centers = []
853
+ # for cluster_id in range(n_clusters):
854
+ # cluster_indices = np.where(cluster_labels == cluster_id)[0]
855
+ # if len(cluster_indices) == 0:
856
+ # continue
857
+ # cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
858
+ # cluster_centers.append(cluster_center)
859
+
860
+ # cluster_centers = np.stack(cluster_centers)
861
+
862
+ # # === 2. Solve TSP between cluster centers
863
+ # cluster_order = solve_tsp_2opt(cluster_centers)
864
+
865
+ # # === 3. Reorder cluster centers
866
+ # ordered_centers = cluster_centers[cluster_order]
867
+
868
+ # # === 4. Prepare Catmull-Rom spline
869
+ # if closed:
870
+ # ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
871
+ # else:
872
+ # ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
873
+
874
+ # # === 5. Generate smooth path positions
875
+ # path_positions = []
876
+ # for i in range(1, len(ordered_centers) - 2):
877
+ # segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
878
+ # path_positions.append(segment)
879
+ # path_positions = np.concatenate(path_positions, axis=0)
880
+
881
+ # # === 6. Denormalize back
882
+ # path_positions = path_positions * scale_pos + mean_pos
883
+
884
+ # # === 7. Generate dummy rotations (constant forward facing)
885
+ # reference_cam = existing_cameras[0]
886
+ # default_rotation = R.from_matrix(reference_cam.R)
887
+
888
+ # # For simplicity, fixed rotation for all
889
+ # smooth_cameras = []
890
+
891
+ # for i, pos in enumerate(path_positions):
892
+ # R_interp = default_rotation.as_matrix()
893
+
894
+ # smooth_cameras.append(Camera(
895
+ # R=R_interp,
896
+ # T=pos,
897
+ # FoVx=reference_cam.FoVx,
898
+ # FoVy=reference_cam.FoVy,
899
+ # resolution=(reference_cam.image_width, reference_cam.image_height),
900
+ # colmap_id=-1,
901
+ # depth_params=None,
902
+ # image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
903
+ # invdepthmap=None,
904
+ # image_name=f"cluster_path_i={i}",
905
+ # uid=i
906
+ # ))
907
+
908
+ # return smooth_cameras
909
+
910
+ from typing import List
911
+ import numpy as np
912
+ from sklearn.cluster import KMeans
913
+ from scipy.spatial.transform import Rotation as R, Slerp
914
+ from PIL import Image
915
+
916
+ def generate_clustered_path(existing_cameras: List[Camera],
917
+ n_points_per_segment: int = 20,
918
+ d: float = 2.0,
919
+ n_clusters: int = 5,
920
+ closed: bool = False) -> List[Camera]:
921
+ """
922
+ Generate a smooth camera path using K-Means clustering and TSP on cluster centers.
923
+
924
+ Args:
925
+ existing_cameras (List[Camera]): List of input cameras.
926
+ n_points_per_segment (int): Number of interpolated points per spline segment.
927
+ d (float): Distance ahead for estimating the center of view (currently unused in this function).
928
+ n_clusters (int): Number of KMeans clusters (zones).
929
+ closed (bool): Whether to close the path.
930
+
931
+ Returns:
932
+ List[Camera]: Smooth path of Camera objects.
933
+ """
934
+ # Extract positions
935
+ positions = np.array([cam.T for cam in existing_cameras])
936
+
937
+ # === Normalize positions
938
+ mean_pos = np.mean(positions, axis=0)
939
+ scale_pos = np.std(positions, axis=0)
940
+ scale_pos[scale_pos == 0] = 1.0
941
+
942
+ positions_normalized = (positions - mean_pos) / scale_pos
943
+
944
+ # === 1. K-Means clustering (only positions)
945
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
946
+ cluster_labels = kmeans.fit_predict(positions_normalized)
947
+
948
+ cluster_centers = []
949
+ for cluster_id in range(n_clusters):
950
+ cluster_indices = np.where(cluster_labels == cluster_id)[0]
951
+ if len(cluster_indices) == 0:
952
+ continue
953
+ cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
954
+ cluster_centers.append(cluster_center)
955
+
956
+ cluster_centers = np.stack(cluster_centers)
957
+
958
+ # === 2. Solve TSP between cluster centers
959
+ cluster_order = solve_tsp_2opt(cluster_centers)
960
+
961
+ # === 3. Reorder cluster centers
962
+ ordered_centers = cluster_centers[cluster_order]
963
+
964
+ # === 4. Prepare Catmull-Rom spline
965
+ if closed:
966
+ ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
967
+ else:
968
+ ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
969
+
970
+ # === 5. Generate smooth path positions
971
+ path_positions = []
972
+ for i in range(1, len(ordered_centers) - 2):
973
+ segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
974
+ path_positions.append(segment)
975
+ path_positions = np.concatenate(path_positions, axis=0)
976
+
977
+ # === 6. Denormalize back
978
+ path_positions = path_positions * scale_pos + mean_pos
979
+
980
+ # === 7. Generate dummy rotations (constant forward facing)
981
+ reference_cam = existing_cameras[0]
982
+ default_rotation = R.from_matrix(reference_cam.R)
983
+
984
+ # For simplicity, fixed rotation for all
985
+ smooth_cameras = []
986
+
987
+ for i, pos in enumerate(path_positions):
988
+ R_interp = default_rotation.as_matrix()
989
+
990
+ smooth_cameras.append(Camera(
991
+ R=R_interp,
992
+ T=pos,
993
+ FoVx=reference_cam.FoVx,
994
+ FoVy=reference_cam.FoVy,
995
+ resolution=(reference_cam.image_width, reference_cam.image_height),
996
+ colmap_id=-1,
997
+ depth_params=None,
998
+ image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
999
+ invdepthmap=None,
1000
+ image_name=f"cluster_path_i={i}",
1001
+ uid=i
1002
+ ))
1003
+
1004
+ return smooth_cameras
1005
+
1006
+
1007
+
1008
+
1009
+ def visualize_image_with_points(image, points):
1010
+ """
1011
+ Visualize an image with points overlaid on top. This is useful for correspondence visualizations.
1012
+
1013
+ Parameters:
1014
+ - image: PIL Image object
1015
+ - points: Numpy array of shape [N, 2] containing (x, y) coordinates of points
1016
+
1017
+ Returns:
1018
+ - None (displays the visualization)
1019
+ """
1020
+
1021
+ # Convert PIL image to numpy array
1022
+ img_array = np.array(image)
1023
+
1024
+ # Create a figure and axis
1025
+ fig, ax = plt.subplots(figsize=(7,7))
1026
+
1027
+ # Display the image
1028
+ ax.imshow(img_array)
1029
+
1030
+ # Scatter plot the points on top of the image
1031
+ ax.scatter(points[:, 0], points[:, 1], color='red', marker='o', s=1)
1032
+
1033
+ # Show the plot
1034
+ plt.show()
1035
+
1036
+
1037
+ def visualize_correspondences(image1, points1, image2, points2):
1038
+ """
1039
+ Visualize two images concatenated horizontally with key points and correspondences.
1040
+
1041
+ Parameters:
1042
+ - image1: PIL Image object (left image)
1043
+ - points1: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image1
1044
+ - image2: PIL Image object (right image)
1045
+ - points2: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image2
1046
+
1047
+ Returns:
1048
+ - None (displays the visualization)
1049
+ """
1050
+
1051
+ # Concatenate images horizontally
1052
+ concatenated_image = np.concatenate((np.array(image1), np.array(image2)), axis=1)
1053
+
1054
+ # Create a figure and axis
1055
+ fig, ax = plt.subplots(figsize=(10,10))
1056
+
1057
+ # Display the concatenated image
1058
+ ax.imshow(concatenated_image)
1059
+
1060
+ # Plot key points on the left image
1061
+ ax.scatter(points1[:, 0], points1[:, 1], color='red', marker='o', s=10)
1062
+
1063
+ # Plot key points on the right image
1064
+ ax.scatter(points2[:, 0] + image1.width, points2[:, 1], color='blue', marker='o', s=10)
1065
+
1066
+ # Draw lines connecting corresponding key points
1067
+ for i in range(len(points1)):
1068
+ ax.plot([points1[i, 0], points2[i, 0] + image1.width], [points1[i, 1], points2[i, 1]])#, color='green')
1069
+
1070
+ # Show the plot
1071
+ plt.show()
1072
+
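The camera-path helpers added above are easiest to follow end to end with a small driver. The sketch below is illustrative only: it assumes `source.visualization` exposes the functions and the `Camera` class shown in this diff, and `train_cameras` is a hypothetical list of `Camera` objects loaded elsewhere.

```python
# Minimal usage sketch for the camera-path helpers above.
# Assumptions: source.visualization exposes the functions shown in this diff,
# and `train_cameras` is a list of Camera objects you already have (hypothetical).
from source.visualization import (
    generate_clustered_smooth_cameras_with_tsp,
    generate_clustered_path,
)

def build_flythrough(train_cameras):
    # Cluster-aware ordering, Catmull-Rom positions, SLERP rotations.
    smooth_path = generate_clustered_smooth_cameras_with_tsp(
        train_cameras,
        n_selected=30,            # key cameras kept after TSP ordering
        n_points_per_segment=20,  # spline density between key cameras
        n_clusters=5,
        closed=False,
    )
    # Coarser alternative: one spline through the K-Means cluster centers only,
    # with a fixed rotation taken from the first camera.
    coarse_path = generate_clustered_path(train_cameras, n_clusters=5)
    return smooth_path, coarse_path
```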
submodules/RoMa/.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ *.egg-info*
2
+ *.vscode*
3
+ *__pycache__*
4
+ vis*
5
+ workspace*
6
+ .venv
7
+ .DS_Store
8
+ jobs/*
9
+ *ignore_me*
10
+ *.pth
11
+ wandb*
submodules/RoMa/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Johan Edstedt
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
submodules/RoMa/README.md ADDED
@@ -0,0 +1,123 @@
1
+ #
2
+ <p align="center">
3
+ <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
4
+ <p align="center">
5
+ <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
6
+ ·
7
+ <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
8
+ ·
9
+ <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
10
+ ·
11
+ <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
12
+ ·
13
+ <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
14
+ </p>
15
+ <h2 align="center"><p>
16
+ <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
17
+ <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
18
+ </p></h2>
19
+ <div align="center"></div>
20
+ </p>
21
+ <br/>
22
+ <p align="center">
23
+ <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
24
+ <br>
25
+ <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
26
+ </p>
27
+
28
+ ## Setup/Install
29
+ In your python environment (tested on Linux python 3.10), run:
30
+ ```bash
31
+ pip install -e .
32
+ ```
33
+ ## Demo / How to Use
34
+ We provide several demos in the [demos folder](demo).
35
+ Here's the gist of it:
36
+ ```python
37
+ from romatch import roma_outdoor
38
+ roma_model = roma_outdoor(device=device)
39
+ # Match
40
+ warp, certainty = roma_model.match(imA_path, imB_path, device=device)
41
+ # Sample matches for estimation
42
+ matches, certainty = roma_model.sample(warp, certainty)
43
+ # Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
44
+ kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
45
+ # Find a fundamental matrix (or anything else of interest)
46
+ F, mask = cv2.findFundamentalMat(
47
+ kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
48
+ )
49
+ ```
50
+
51
+ **New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
52
+
53
+ ## Settings
54
+
55
+ ### Resolution
56
+ By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
57
+ You can change this at construction (see roma_outdoor kwargs).
58
+ You can also change this later by setting roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
59
+
60
+ ### Sampling
61
+ roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
62
+
63
+
64
+ ## Reproducing Results
65
+ The experiments in the paper are provided in the [experiments folder](experiments).
66
+
67
+ ### Training
68
+ 1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
69
+ 2. Run the relevant experiment, e.g.,
70
+ ```bash
71
+ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
72
+ ```
73
+ ### Testing
74
+ ```bash
75
+ python experiments/roma_outdoor.py --only_test --benchmark mega-1500
76
+ ```
77
+ ## License
78
+ All our code except DINOv2 is MIT license.
79
+ DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
80
+
81
+ ## Acknowledgement
82
+ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
83
+
84
+ ## Tiny RoMa
85
+ If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
86
+ ```python
87
+ from romatch import tiny_roma_v1_outdoor
88
+ tiny_roma_model = tiny_roma_v1_outdoor(device=device)
89
+ ```
90
+ Mega1500:
91
+ | | AUC@5 | AUC@10 | AUC@20 |
92
+ |----------|----------|----------|----------|
93
+ | XFeat | 46.4 | 58.9 | 69.2 |
94
+ | XFeat* | 51.9 | 67.2 | 78.9 |
95
+ | Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
96
+ | RoMa | - | - | - |
97
+
98
+ Mega-8-Scenes (See DKM):
99
+ | | AUC@5 | AUC@10 | AUC@20 |
100
+ |----------|----------|----------|----------|
101
+ | XFeat | - | - | - |
102
+ | XFeat* | 50.1 | 64.4 | 75.2 |
103
+ | Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
104
+ | RoMa | - | - | - |
105
+
106
+ IMC22 :'):
107
+ | | mAA@10 |
108
+ |----------|----------|
109
+ | XFeat | 42.1 |
110
+ | XFeat* | - |
111
+ | Tiny RoMa v1 | 42.2 |
112
+ | RoMa | - |
113
+
114
+ ## BibTeX
115
+ If you find our models useful, please consider citing our paper!
116
+ ```
117
+ @article{edstedt2024roma,
118
+ title={{RoMa: Robust Dense Feature Matching}},
119
+ author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
120
+ journal={IEEE Conference on Computer Vision and Pattern Recognition},
121
+ year={2024}
122
+ }
123
+ ```
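As a complement to the Settings section of the RoMa README above, the sketch below shows one way the documented knobs could be set. The values are arbitrary examples; only the attribute names mentioned in the README (`w_resized`, `h_resized`, `upsample_res`, `sample_thresh`) and the constructor kwargs used in the bundled demos are assumed.

```python
# Sketch of the Settings knobs described in the RoMa README above (example values).
import torch
from romatch import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Resolutions can be set at construction, as in the bundled demos...
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 864))

# ...or adjusted afterwards via the attributes named in the README.
roma_model.h_resized, roma_model.w_resized = 560, 560
roma_model.upsample_res = (864, 864)

# Threshold used when sampling matches for estimation (see "Sampling" above).
roma_model.sample_thresh = 0.05
```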
submodules/RoMa/data/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *
2
+ !.gitignore
submodules/RoMa/demo/demo_3D_effect.py ADDED
@@ -0,0 +1,47 @@
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from romatch.utils.utils import tensor_to_pil
6
+
7
+ from romatch import roma_outdoor
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+ if torch.backends.mps.is_available():
11
+ device = torch.device('mps')
12
+
13
+ if __name__ == "__main__":
14
+ from argparse import ArgumentParser
15
+ parser = ArgumentParser()
16
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
17
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
18
+ parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
19
+
20
+ args, _ = parser.parse_known_args()
21
+ im1_path = args.im_A_path
22
+ im2_path = args.im_B_path
23
+ save_path = args.save_path
24
+
25
+ # Create model
26
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
27
+ roma_model.symmetric = False
28
+
29
+ H, W = roma_model.get_output_resolution()
30
+
31
+ im1 = Image.open(im1_path).resize((W, H))
32
+ im2 = Image.open(im2_path).resize((W, H))
33
+
34
+ # Match
35
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
36
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
37
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
38
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
39
+
40
+ coords_A, coords_B = warp[...,:2], warp[...,2:]
41
+ for i, x in enumerate(np.linspace(0,2*np.pi,200)):
42
+ t = (1 + np.cos(x))/2
43
+ interp_warp = (1-t)*coords_A + t*coords_B
44
+ im2_transfer_rgb = F.grid_sample(
45
+ x2[None], interp_warp[None], mode="bilinear", align_corners=False
46
+ )[0]
47
+ tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
submodules/RoMa/demo/demo_fundamental.py ADDED
@@ -0,0 +1,34 @@
1
+ from PIL import Image
2
+ import torch
3
+ import cv2
4
+ from romatch import roma_outdoor
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+ if torch.backends.mps.is_available():
8
+ device = torch.device('mps')
9
+
10
+ if __name__ == "__main__":
11
+ from argparse import ArgumentParser
12
+ parser = ArgumentParser()
13
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
14
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
15
+
16
+ args, _ = parser.parse_known_args()
17
+ im1_path = args.im_A_path
18
+ im2_path = args.im_B_path
19
+
20
+ # Create model
21
+ roma_model = roma_outdoor(device=device)
22
+
23
+
24
+ W_A, H_A = Image.open(im1_path).size
25
+ W_B, H_B = Image.open(im2_path).size
26
+
27
+ # Match
28
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
29
+ # Sample matches for estimation
30
+ matches, certainty = roma_model.sample(warp, certainty)
31
+ kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
32
+ F, mask = cv2.findFundamentalMat(
33
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
34
+ )
submodules/RoMa/demo/demo_match.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
3
+ import torch
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from romatch.utils.utils import tensor_to_pil
8
+
9
+ from romatch import roma_outdoor
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if torch.backends.mps.is_available():
13
+ device = torch.device('mps')
14
+
15
+ if __name__ == "__main__":
16
+ from argparse import ArgumentParser
17
+ parser = ArgumentParser()
18
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
19
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
20
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
21
+
22
+ args, _ = parser.parse_known_args()
23
+ im1_path = args.im_A_path
24
+ im2_path = args.im_B_path
25
+ save_path = args.save_path
26
+
27
+ # Create model
28
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
29
+
30
+ H, W = roma_model.get_output_resolution()
31
+
32
+ im1 = Image.open(im1_path).resize((W, H))
33
+ im2 = Image.open(im2_path).resize((W, H))
34
+
35
+ # Match
36
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
37
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
38
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
39
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
40
+
41
+ im2_transfer_rgb = F.grid_sample(
42
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
43
+ )[0]
44
+ im1_transfer_rgb = F.grid_sample(
45
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
46
+ )[0]
47
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
48
+ white_im = torch.ones((H,2*W),device=device)
49
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
50
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
submodules/RoMa/demo/demo_match_opencv_sift.py ADDED
@@ -0,0 +1,43 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import numpy as np
5
+ import cv2 as cv
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+
10
+ if __name__ == "__main__":
11
+ from argparse import ArgumentParser
12
+ parser = ArgumentParser()
13
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
14
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
15
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
16
+
17
+ args, _ = parser.parse_known_args()
18
+ im1_path = args.im_A_path
19
+ im2_path = args.im_B_path
20
+ save_path = args.save_path
21
+
22
+ img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
23
+ img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
24
+ # Initiate SIFT detector
25
+ sift = cv.SIFT_create()
26
+ # find the keypoints and descriptors with SIFT
27
+ kp1, des1 = sift.detectAndCompute(img1,None)
28
+ kp2, des2 = sift.detectAndCompute(img2,None)
29
+ # BFMatcher with default params
30
+ bf = cv.BFMatcher()
31
+ matches = bf.knnMatch(des1,des2,k=2)
32
+ # Apply ratio test
33
+ good = []
34
+ for m,n in matches:
35
+ if m.distance < 0.75*n.distance:
36
+ good.append([m])
37
+ # cv.drawMatchesKnn expects list of lists as matches.
38
+ draw_params = dict(matchColor = (255,0,0), # draw matches in red color
39
+ singlePointColor = None,
40
+ flags = 2)
41
+
42
+ img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
43
+ Image.fromarray(img3).save("demo/sift_matches.png")
submodules/RoMa/demo/demo_match_tiny.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
3
+ import torch
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from romatch.utils.utils import tensor_to_pil
8
+
9
+ from romatch import tiny_roma_v1_outdoor
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if torch.backends.mps.is_available():
13
+ device = torch.device('mps')
14
+
15
+ if __name__ == "__main__":
16
+ from argparse import ArgumentParser
17
+ parser = ArgumentParser()
18
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
19
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
20
+ parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
21
+ parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
22
+
23
+ args, _ = parser.parse_known_args()
24
+ im1_path = args.im_A_path
25
+ im2_path = args.im_B_path
26
+
27
+ # Create model
28
+ roma_model = tiny_roma_v1_outdoor(device=device)
29
+
30
+ # Match
31
+ warp, certainty1 = roma_model.match(im1_path, im2_path)
32
+
33
+ h1, w1 = warp.shape[:2]
34
+
35
+ # maybe im1.size != im2.size
36
+ im1 = Image.open(im1_path).resize((w1, h1))
37
+ im2 = Image.open(im2_path)
38
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
39
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
40
+
41
+ h2, w2 = x2.shape[1:]
42
+ g1_p2x = w2 / 2 * (warp[..., 2] + 1)
43
+ g1_p2y = h2 / 2 * (warp[..., 3] + 1)
44
+ g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
45
+ g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
46
+
47
+ x, y = torch.meshgrid(
48
+ torch.arange(w1, device=device),
49
+ torch.arange(h1, device=device),
50
+ indexing="xy",
51
+ )
52
+ g2x = torch.round(g1_p2x[y, x]).long()
53
+ g2y = torch.round(g1_p2y[y, x]).long()
54
+ idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
55
+ idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
56
+ idx = torch.bitwise_and(idx_x, idx_y)
57
+ g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
58
+ g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
59
+
60
+ certainty2 = F.grid_sample(
61
+ certainty1[None][None],
62
+ torch.stack([g2_p1x, g2_p1y], dim=2)[None],
63
+ mode="bilinear",
64
+ align_corners=False,
65
+ )[0]
66
+
67
+ white_im1 = torch.ones((h1, w1), device = device)
68
+ white_im2 = torch.ones((h2, w2), device = device)
69
+
70
+ certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
71
+ certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
72
+
73
+ vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
74
+ vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
75
+
76
+ tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
77
+ tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
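Most of the index arithmetic in the tiny demo above converts between RoMa's normalized [-1, 1] warp coordinates and pixel coordinates. The helper pair below is a standalone sketch that mirrors the demo's own arithmetic (e.g. `w2 / 2 * (warp[..., 2] + 1)` and `x * 2 / w1 - 1`); it is not part of the repository.

```python
# Standalone sketch: the two coordinate conversions used inline in demo_match_tiny.py.
import torch

def norm_to_pixel(coords: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Map normalized [-1, 1] (x, y) coordinates to pixel coordinates."""
    x = w / 2 * (coords[..., 0] + 1)
    y = h / 2 * (coords[..., 1] + 1)
    return torch.stack([x, y], dim=-1)

def pixel_to_norm(coords: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Map pixel (x, y) coordinates back to the normalized [-1, 1] range."""
    x = coords[..., 0] * 2 / w - 1
    y = coords[..., 1] * 2 / h - 1
    return torch.stack([x, y], dim=-1)

# Round trip on a dummy grid of pixel coordinates.
h, w = 480, 640
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
pix = torch.stack([xs, ys], dim=-1).float()
assert torch.allclose(norm_to_pixel(pixel_to_norm(pix, h, w), h, w), pix, atol=1e-4)
```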
submodules/RoMa/demo/gif/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *
2
+ !.gitignore
submodules/RoMa/experiments/eval_roma_outdoor.py ADDED
@@ -0,0 +1,57 @@
1
+ import json
2
+
3
+ from romatch.benchmarks import MegadepthDenseBenchmark
4
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
5
+ from romatch.benchmarks import Mega1500PoseLibBenchmark
6
+
7
+ def test_mega_8_scenes(model, name):
8
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
9
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
10
+ 'mega_8_scenes_0025_0.1_0.3.npz',
11
+ 'mega_8_scenes_0021_0.1_0.3.npz',
12
+ 'mega_8_scenes_0008_0.1_0.3.npz',
13
+ 'mega_8_scenes_0032_0.1_0.3.npz',
14
+ 'mega_8_scenes_1589_0.1_0.3.npz',
15
+ 'mega_8_scenes_0063_0.1_0.3.npz',
16
+ 'mega_8_scenes_0024_0.1_0.3.npz',
17
+ 'mega_8_scenes_0019_0.3_0.5.npz',
18
+ 'mega_8_scenes_0025_0.3_0.5.npz',
19
+ 'mega_8_scenes_0021_0.3_0.5.npz',
20
+ 'mega_8_scenes_0008_0.3_0.5.npz',
21
+ 'mega_8_scenes_0032_0.3_0.5.npz',
22
+ 'mega_8_scenes_1589_0.3_0.5.npz',
23
+ 'mega_8_scenes_0063_0.3_0.5.npz',
24
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
25
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
26
+ print(mega_8_scenes_results)
27
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
28
+
29
+ def test_mega1500(model, name):
30
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
31
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
32
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
33
+
34
+ def test_mega1500_poselib(model, name):
35
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
36
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
37
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
38
+
39
+ def test_mega_dense(model, name):
40
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
41
+ megadense_results = megadense_benchmark.benchmark(model)
42
+ json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
43
+
44
+ def test_hpatches(model, name):
45
+ hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
46
+ hpatches_results = hpatches_benchmark.benchmark(model)
47
+ json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
48
+
49
+
50
+ if __name__ == "__main__":
51
+ from romatch import roma_outdoor
52
+ device = "cuda"
53
+ model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
54
+ experiment_name = "roma_latest"
55
+ test_mega1500(model, experiment_name)
56
+ #test_mega1500_poselib(model, experiment_name)
57
+
submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import os
3
+ from pathlib import Path
4
+ import json
5
+ from romatch.benchmarks import ScanNetBenchmark
6
+ from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
7
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
8
+
9
+ def test_mega_8_scenes(model, name):
10
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
11
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
12
+ 'mega_8_scenes_0025_0.1_0.3.npz',
13
+ 'mega_8_scenes_0021_0.1_0.3.npz',
14
+ 'mega_8_scenes_0008_0.1_0.3.npz',
15
+ 'mega_8_scenes_0032_0.1_0.3.npz',
16
+ 'mega_8_scenes_1589_0.1_0.3.npz',
17
+ 'mega_8_scenes_0063_0.1_0.3.npz',
18
+ 'mega_8_scenes_0024_0.1_0.3.npz',
19
+ 'mega_8_scenes_0019_0.3_0.5.npz',
20
+ 'mega_8_scenes_0025_0.3_0.5.npz',
21
+ 'mega_8_scenes_0021_0.3_0.5.npz',
22
+ 'mega_8_scenes_0008_0.3_0.5.npz',
23
+ 'mega_8_scenes_0032_0.3_0.5.npz',
24
+ 'mega_8_scenes_1589_0.3_0.5.npz',
25
+ 'mega_8_scenes_0063_0.3_0.5.npz',
26
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
27
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
28
+ print(mega_8_scenes_results)
29
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
30
+
31
+ def test_mega1500(model, name):
32
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
33
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
34
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
35
+
36
+ def test_mega1500_poselib(model, name):
37
+ #model.exact_softmax = True
38
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
39
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
40
+ json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
41
+
42
+ def test_mega_8_scenes_poselib(model, name):
43
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
44
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
45
+ 'mega_8_scenes_0025_0.1_0.3.npz',
46
+ 'mega_8_scenes_0021_0.1_0.3.npz',
47
+ 'mega_8_scenes_0008_0.1_0.3.npz',
48
+ 'mega_8_scenes_0032_0.1_0.3.npz',
49
+ 'mega_8_scenes_1589_0.1_0.3.npz',
50
+ 'mega_8_scenes_0063_0.1_0.3.npz',
51
+ 'mega_8_scenes_0024_0.1_0.3.npz',
52
+ 'mega_8_scenes_0019_0.3_0.5.npz',
53
+ 'mega_8_scenes_0025_0.3_0.5.npz',
54
+ 'mega_8_scenes_0021_0.3_0.5.npz',
55
+ 'mega_8_scenes_0008_0.3_0.5.npz',
56
+ 'mega_8_scenes_0032_0.3_0.5.npz',
57
+ 'mega_8_scenes_1589_0.3_0.5.npz',
58
+ 'mega_8_scenes_0063_0.3_0.5.npz',
59
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
60
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
61
+ json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
62
+
63
+ def test_scannet_poselib(model, name):
64
+ scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
65
+ scannet_results = scannet_benchmark.benchmark(model)
66
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
67
+
68
+ def test_scannet(model, name):
69
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
70
+ scannet_results = scannet_benchmark.benchmark(model)
71
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
72
+
73
+ if __name__ == "__main__":
74
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
75
+ os.environ["OMP_NUM_THREADS"] = "16"
76
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
77
+ from romatch import tiny_roma_v1_outdoor
78
+
79
+ experiment_name = Path(__file__).stem
80
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ model = tiny_roma_v1_outdoor(device)
82
+ #test_mega1500_poselib(model, experiment_name)
83
+ test_mega_8_scenes_poselib(model, experiment_name)
84
+
submodules/RoMa/experiments/roma_indoor.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ import torch
3
+ from argparse import ArgumentParser
4
+
5
+ from torch import nn
6
+ from torch.utils.data import ConcatDataset
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+
10
+ import json
11
+ import wandb
12
+ from tqdm import tqdm
13
+
14
+ from romatch.benchmarks import MegadepthDenseBenchmark
15
+ from romatch.datasets.megadepth import MegadepthBuilder
16
+ from romatch.datasets.scannet import ScanNetBuilder
17
+ from romatch.losses.robust_loss import RobustLosses
18
+ from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
19
+ from romatch.train.train import train_k_steps
20
+ from romatch.models.matcher import *
21
+ from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
22
+ from romatch.models.encoders import *
23
+ from romatch.checkpointing import CheckPoint
24
+
25
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
26
+
27
+ def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
28
+ gp_dim = 512
29
+ feat_dim = 512
30
+ decoder_dim = gp_dim + feat_dim
31
+ cls_to_coord_res = 64
32
+ coordinate_decoder = TransformerDecoder(
33
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
34
+ decoder_dim,
35
+ cls_to_coord_res**2 + 1,
36
+ is_classifier=True,
37
+ amp = True,
38
+ pos_enc = False,)
39
+ dw = True
40
+ hidden_blocks = 8
41
+ kernel_size = 5
42
+ displacement_emb = "linear"
43
+ disable_local_corr_grad = True
44
+
45
+ conv_refiner = nn.ModuleDict(
46
+ {
47
+ "16": ConvRefiner(
48
+ 2 * 512+128+(2*7+1)**2,
49
+ 2 * 512+128+(2*7+1)**2,
50
+ 2 + 1,
51
+ kernel_size=kernel_size,
52
+ dw=dw,
53
+ hidden_blocks=hidden_blocks,
54
+ displacement_emb=displacement_emb,
55
+ displacement_emb_dim=128,
56
+ local_corr_radius = 7,
57
+ corr_in_other = True,
58
+ amp = True,
59
+ disable_local_corr_grad = disable_local_corr_grad,
60
+ bn_momentum = 0.01,
61
+ ),
62
+ "8": ConvRefiner(
63
+ 2 * 512+64+(2*3+1)**2,
64
+ 2 * 512+64+(2*3+1)**2,
65
+ 2 + 1,
66
+ kernel_size=kernel_size,
67
+ dw=dw,
68
+ hidden_blocks=hidden_blocks,
69
+ displacement_emb=displacement_emb,
70
+ displacement_emb_dim=64,
71
+ local_corr_radius = 3,
72
+ corr_in_other = True,
73
+ amp = True,
74
+ disable_local_corr_grad = disable_local_corr_grad,
75
+ bn_momentum = 0.01,
76
+ ),
77
+ "4": ConvRefiner(
78
+ 2 * 256+32+(2*2+1)**2,
79
+ 2 * 256+32+(2*2+1)**2,
80
+ 2 + 1,
81
+ kernel_size=kernel_size,
82
+ dw=dw,
83
+ hidden_blocks=hidden_blocks,
84
+ displacement_emb=displacement_emb,
85
+ displacement_emb_dim=32,
86
+ local_corr_radius = 2,
87
+ corr_in_other = True,
88
+ amp = True,
89
+ disable_local_corr_grad = disable_local_corr_grad,
90
+ bn_momentum = 0.01,
91
+ ),
92
+ "2": ConvRefiner(
93
+ 2 * 64+16,
94
+ 128+16,
95
+ 2 + 1,
96
+ kernel_size=kernel_size,
97
+ dw=dw,
98
+ hidden_blocks=hidden_blocks,
99
+ displacement_emb=displacement_emb,
100
+ displacement_emb_dim=16,
101
+ amp = True,
102
+ disable_local_corr_grad = disable_local_corr_grad,
103
+ bn_momentum = 0.01,
104
+ ),
105
+ "1": ConvRefiner(
106
+ 2 * 9 + 6,
107
+ 24,
108
+ 2 + 1,
109
+ kernel_size=kernel_size,
110
+ dw=dw,
111
+ hidden_blocks = hidden_blocks,
112
+ displacement_emb = displacement_emb,
113
+ displacement_emb_dim = 6,
114
+ amp = True,
115
+ disable_local_corr_grad = disable_local_corr_grad,
116
+ bn_momentum = 0.01,
117
+ ),
118
+ }
119
+ )
120
+ kernel_temperature = 0.2
121
+ learn_temperature = False
122
+ no_cov = True
123
+ kernel = CosKernel
124
+ only_attention = False
125
+ basis = "fourier"
126
+ gp16 = GP(
127
+ kernel,
128
+ T=kernel_temperature,
129
+ learn_temperature=learn_temperature,
130
+ only_attention=only_attention,
131
+ gp_dim=gp_dim,
132
+ basis=basis,
133
+ no_cov=no_cov,
134
+ )
135
+ gps = nn.ModuleDict({"16": gp16})
136
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
137
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
138
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
139
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
140
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
141
+ proj = nn.ModuleDict({
142
+ "16": proj16,
143
+ "8": proj8,
144
+ "4": proj4,
145
+ "2": proj2,
146
+ "1": proj1,
147
+ })
148
+ displacement_dropout_p = 0.0
149
+ gm_warp_dropout_p = 0.0
150
+ decoder = Decoder(coordinate_decoder,
151
+ gps,
152
+ proj,
153
+ conv_refiner,
154
+ detach=True,
155
+ scales=["16", "8", "4", "2", "1"],
156
+ displacement_dropout_p = displacement_dropout_p,
157
+ gm_warp_dropout_p = gm_warp_dropout_p)
158
+ h,w = resolutions[resolution]
159
+ encoder = CNNandDinov2(
160
+ cnn_kwargs = dict(
161
+ pretrained=pretrained_backbone,
162
+ amp = True),
163
+ amp = True,
164
+ use_vgg = True,
165
+ )
166
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
167
+ return matcher
168
+
169
+ def train(args):
170
+ dist.init_process_group('nccl')
171
+ #torch._dynamo.config.verbose=True
172
+ gpus = int(os.environ['WORLD_SIZE'])
173
+ # create model and move it to GPU with id rank
174
+ rank = dist.get_rank()
175
+ print(f"Start running DDP on rank {rank}")
176
+ device_id = rank % torch.cuda.device_count()
177
+ romatch.LOCAL_RANK = device_id
178
+ torch.cuda.set_device(device_id)
179
+
180
+ resolution = args.train_resolution
181
+ wandb_log = not args.dont_log_wandb
182
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
183
+ wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
184
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
185
+ checkpoint_dir = "workspace/checkpoints/"
186
+ h,w = resolutions[resolution]
187
+ model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
188
+ # Num steps
189
+ global_step = 0
190
+ batch_size = args.gpu_batch_size
191
+ step_size = gpus*batch_size
192
+ romatch.STEP_SIZE = step_size
193
+
194
+ N = (32 * 250000) # 250k steps of batch size 32
195
+ # checkpoint every
196
+ k = 25000 // romatch.STEP_SIZE
197
+
198
+ # Data
199
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
200
+ use_horizontal_flip_aug = True
201
+ rot_prob = 0
202
+ depth_interpolation_mode = "bilinear"
203
+ megadepth_train1 = mega.build_scenes(
204
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
205
+ ht=h,wt=w,
206
+ )
207
+ megadepth_train2 = mega.build_scenes(
208
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
209
+ ht=h,wt=w,
210
+ )
211
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
212
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
213
+
214
+ scannet = ScanNetBuilder(data_root="data/scannet")
215
+ scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
216
+ scannet_train = ConcatDataset(scannet_train)
217
+ scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
218
+
219
+ # Loss and optimizer
220
+ depth_loss_scannet = RobustLosses(
221
+ ce_weight=0.0,
222
+ local_dist={1:4, 2:4, 4:8, 8:8},
223
+ local_largest_scale=8,
224
+ depth_interpolation_mode=depth_interpolation_mode,
225
+ alpha = 0.5,
226
+ c = 1e-4,)
227
+ # Loss and optimizer
228
+ depth_loss_mega = RobustLosses(
229
+ ce_weight=0.01,
230
+ local_dist={1:4, 2:4, 4:8, 8:8},
231
+ local_largest_scale=8,
232
+ depth_interpolation_mode=depth_interpolation_mode,
233
+ alpha = 0.5,
234
+ c = 1e-4,)
235
+ parameters = [
236
+ {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
237
+ {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
238
+ ]
239
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
240
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
241
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
242
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
243
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
244
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
245
+ romatch.GLOBAL_STEP = global_step
246
+ ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
247
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
248
+ grad_clip_norm = 0.01
249
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
250
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
251
+ mega_ws, num_samples = batch_size * k, replacement=False
252
+ )
253
+ mega_dataloader = iter(
254
+ torch.utils.data.DataLoader(
255
+ megadepth_train,
256
+ batch_size = batch_size,
257
+ sampler = mega_sampler,
258
+ num_workers = 8,
259
+ )
260
+ )
261
+ scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
262
+ scannet_ws, num_samples=batch_size * k, replacement=False
263
+ )
264
+ scannet_dataloader = iter(
265
+ torch.utils.data.DataLoader(
266
+ scannet_train,
267
+ batch_size=batch_size,
268
+ sampler=scannet_ws_sampler,
269
+ num_workers=gpus * 8,
270
+ )
271
+ )
272
+ for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
273
+ train_k_steps(
274
+ n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
275
+ )
276
+ train_k_steps(
277
+ n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
278
+ )
279
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
280
+ wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
281
+
282
+ def test_scannet(model, name, resolution, sample_mode):
283
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
284
+ scannet_results = scannet_benchmark.benchmark(model)
285
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
286
+
287
+ if __name__ == "__main__":
288
+ import warnings
289
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
290
+ warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
291
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
292
+ os.environ["OMP_NUM_THREADS"] = "16"
293
+
294
+ import romatch
295
+ parser = ArgumentParser()
296
+ parser.add_argument("--test", action='store_true')
297
+ parser.add_argument("--debug_mode", action='store_true')
298
+ parser.add_argument("--dont_log_wandb", action='store_true')
299
+ parser.add_argument("--train_resolution", default='medium')
300
+ parser.add_argument("--gpu_batch_size", default=4, type=int)
301
+ parser.add_argument("--wandb_entity", required = False)
302
+
303
+ args, _ = parser.parse_known_args()
304
+ romatch.DEBUG_MODE = args.debug_mode
305
+ if not args.test:
306
+ train(args)
307
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
308
+ checkpoint_dir = "workspace/"
309
+ checkpoint_name = checkpoint_dir + experiment_name + ".pth"
310
+ test_resolution = "medium"
311
+ sample_mode = "threshold_balanced"
312
+ symmetric = True
313
+ upsample_preds = False
314
+ attenuate_cert = True
315
+
316
+ model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
317
+ model = model.cuda()
318
+ states = torch.load(checkpoint_name)
319
+ model.load_state_dict(states["model"])
320
+ test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
submodules/RoMa/experiments/train_roma_outdoor.py ADDED
@@ -0,0 +1,307 @@
1
+ import os
2
+ import torch
3
+ from argparse import ArgumentParser
4
+
5
+ from torch import nn
6
+ from torch.utils.data import ConcatDataset
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ import json
10
+ import wandb
11
+
12
+ from romatch.benchmarks import MegadepthDenseBenchmark
13
+ from romatch.datasets.megadepth import MegadepthBuilder
14
+ from romatch.losses.robust_loss import RobustLosses
15
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
16
+
17
+ from romatch.train.train import train_k_steps
18
+ from romatch.models.matcher import *
19
+ from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
20
+ from romatch.models.encoders import *
21
+ from romatch.checkpointing import CheckPoint
22
+
23
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
24
+
25
+ def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
26
+ import warnings
27
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
28
+ gp_dim = 512
29
+ feat_dim = 512
30
+ decoder_dim = gp_dim + feat_dim
31
+ cls_to_coord_res = 64
32
+ coordinate_decoder = TransformerDecoder(
33
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
34
+ decoder_dim,
35
+ cls_to_coord_res**2 + 1,
36
+ is_classifier=True,
37
+ amp = True,
38
+ pos_enc = False,)
39
+ dw = True
40
+ hidden_blocks = 8
41
+ kernel_size = 5
42
+ displacement_emb = "linear"
43
+ disable_local_corr_grad = True
44
+
45
+ conv_refiner = nn.ModuleDict(
46
+ {
47
+ "16": ConvRefiner(
48
+ 2 * 512+128+(2*7+1)**2,
49
+ 2 * 512+128+(2*7+1)**2,
50
+ 2 + 1,
51
+ kernel_size=kernel_size,
52
+ dw=dw,
53
+ hidden_blocks=hidden_blocks,
54
+ displacement_emb=displacement_emb,
55
+ displacement_emb_dim=128,
56
+ local_corr_radius = 7,
57
+ corr_in_other = True,
58
+ amp = True,
59
+ disable_local_corr_grad = disable_local_corr_grad,
60
+ bn_momentum = 0.01,
61
+ ),
62
+ "8": ConvRefiner(
63
+ 2 * 512+64+(2*3+1)**2,
64
+ 2 * 512+64+(2*3+1)**2,
65
+ 2 + 1,
66
+ kernel_size=kernel_size,
67
+ dw=dw,
68
+ hidden_blocks=hidden_blocks,
69
+ displacement_emb=displacement_emb,
70
+ displacement_emb_dim=64,
71
+ local_corr_radius = 3,
72
+ corr_in_other = True,
73
+ amp = True,
74
+ disable_local_corr_grad = disable_local_corr_grad,
75
+ bn_momentum = 0.01,
76
+ ),
77
+ "4": ConvRefiner(
78
+ 2 * 256+32+(2*2+1)**2,
79
+ 2 * 256+32+(2*2+1)**2,
80
+ 2 + 1,
81
+ kernel_size=kernel_size,
82
+ dw=dw,
83
+ hidden_blocks=hidden_blocks,
84
+ displacement_emb=displacement_emb,
85
+ displacement_emb_dim=32,
86
+ local_corr_radius = 2,
87
+ corr_in_other = True,
88
+ amp = True,
89
+ disable_local_corr_grad = disable_local_corr_grad,
90
+ bn_momentum = 0.01,
91
+ ),
92
+ "2": ConvRefiner(
93
+ 2 * 64+16,
94
+ 128+16,
95
+ 2 + 1,
96
+ kernel_size=kernel_size,
97
+ dw=dw,
98
+ hidden_blocks=hidden_blocks,
99
+ displacement_emb=displacement_emb,
100
+ displacement_emb_dim=16,
101
+ amp = True,
102
+ disable_local_corr_grad = disable_local_corr_grad,
103
+ bn_momentum = 0.01,
104
+ ),
105
+ "1": ConvRefiner(
106
+ 2 * 9 + 6,
107
+ 24,
108
+ 2 + 1,
109
+ kernel_size=kernel_size,
110
+ dw=dw,
111
+ hidden_blocks = hidden_blocks,
112
+ displacement_emb = displacement_emb,
113
+ displacement_emb_dim = 6,
114
+ amp = True,
115
+ disable_local_corr_grad = disable_local_corr_grad,
116
+ bn_momentum = 0.01,
117
+ ),
118
+ }
119
+ )
120
+ kernel_temperature = 0.2
121
+ learn_temperature = False
122
+ no_cov = True
123
+ kernel = CosKernel
124
+ only_attention = False
125
+ basis = "fourier"
126
+ gp16 = GP(
127
+ kernel,
128
+ T=kernel_temperature,
129
+ learn_temperature=learn_temperature,
130
+ only_attention=only_attention,
131
+ gp_dim=gp_dim,
132
+ basis=basis,
133
+ no_cov=no_cov,
134
+ )
135
+ gps = nn.ModuleDict({"16": gp16})
136
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
137
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
138
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
139
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
140
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
141
+ proj = nn.ModuleDict({
142
+ "16": proj16,
143
+ "8": proj8,
144
+ "4": proj4,
145
+ "2": proj2,
146
+ "1": proj1,
147
+ })
148
+ displacement_dropout_p = 0.0
149
+ gm_warp_dropout_p = 0.0
150
+ decoder = Decoder(coordinate_decoder,
151
+ gps,
152
+ proj,
153
+ conv_refiner,
154
+ detach=True,
155
+ scales=["16", "8", "4", "2", "1"],
156
+ displacement_dropout_p = displacement_dropout_p,
157
+ gm_warp_dropout_p = gm_warp_dropout_p)
158
+ h,w = resolutions[resolution]
159
+ encoder = CNNandDinov2(
160
+ cnn_kwargs = dict(
161
+ pretrained=pretrained_backbone,
162
+ amp = True),
163
+ amp = True,
164
+ use_vgg = True,
165
+ )
166
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
167
+ return matcher
168
+
169
+ def train(args):
170
+ dist.init_process_group('nccl')
171
+ #torch._dynamo.config.verbose=True
172
+ gpus = int(os.environ['WORLD_SIZE'])
173
+ # create model and move it to GPU with id rank
174
+ rank = dist.get_rank()
175
+ print(f"Start running DDP on rank {rank}")
176
+ device_id = rank % torch.cuda.device_count()
177
+ romatch.LOCAL_RANK = device_id
178
+ torch.cuda.set_device(device_id)
179
+
180
+ resolution = args.train_resolution
181
+ wandb_log = not args.dont_log_wandb
182
+ experiment_name = os.path.splitext(os.path.basename(__file__))[0]
183
+ wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
184
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
185
+ checkpoint_dir = "workspace/checkpoints/"
186
+ h,w = resolutions[resolution]
187
+ model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
188
+ # Num steps
189
+ global_step = 0
190
+ batch_size = args.gpu_batch_size
191
+ step_size = gpus*batch_size
192
+ romatch.STEP_SIZE = step_size
193
+
194
+ N = (32 * 250000) # 250k steps of batch size 32
195
+ # checkpoint every
196
+ k = 25000 // romatch.STEP_SIZE
197
+
198
+ # Data
199
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
200
+ use_horizontal_flip_aug = True
201
+ rot_prob = 0
202
+ depth_interpolation_mode = "bilinear"
203
+ megadepth_train1 = mega.build_scenes(
204
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
205
+ ht=h,wt=w,
206
+ )
207
+ megadepth_train2 = mega.build_scenes(
208
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
209
+ ht=h,wt=w,
210
+ )
211
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
212
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
213
+ # Loss and optimizer
214
+ depth_loss = RobustLosses(
215
+ ce_weight=0.01,
216
+ local_dist={1:4, 2:4, 4:8, 8:8},
217
+ local_largest_scale=8,
218
+ depth_interpolation_mode=depth_interpolation_mode,
219
+ alpha = 0.5,
220
+ c = 1e-4,)
221
+ parameters = [
222
+ {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
223
+ {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
224
+ ]
225
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
226
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
227
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
228
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
229
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
230
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
231
+ romatch.GLOBAL_STEP = global_step
232
+ ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
233
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
234
+ grad_clip_norm = 0.01
235
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
236
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
237
+ mega_ws, num_samples = batch_size * k, replacement=False
238
+ )
239
+ mega_dataloader = iter(
240
+ torch.utils.data.DataLoader(
241
+ megadepth_train,
242
+ batch_size = batch_size,
243
+ sampler = mega_sampler,
244
+ num_workers = 8,
245
+ )
246
+ )
247
+ train_k_steps(
248
+ n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
249
+ )
250
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
251
+ wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
252
+
253
+ def test_mega_8_scenes(model, name):
254
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
255
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
256
+ 'mega_8_scenes_0025_0.1_0.3.npz',
257
+ 'mega_8_scenes_0021_0.1_0.3.npz',
258
+ 'mega_8_scenes_0008_0.1_0.3.npz',
259
+ 'mega_8_scenes_0032_0.1_0.3.npz',
260
+ 'mega_8_scenes_1589_0.1_0.3.npz',
261
+ 'mega_8_scenes_0063_0.1_0.3.npz',
262
+ 'mega_8_scenes_0024_0.1_0.3.npz',
263
+ 'mega_8_scenes_0019_0.3_0.5.npz',
264
+ 'mega_8_scenes_0025_0.3_0.5.npz',
265
+ 'mega_8_scenes_0021_0.3_0.5.npz',
266
+ 'mega_8_scenes_0008_0.3_0.5.npz',
267
+ 'mega_8_scenes_0032_0.3_0.5.npz',
268
+ 'mega_8_scenes_1589_0.3_0.5.npz',
269
+ 'mega_8_scenes_0063_0.3_0.5.npz',
270
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
271
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
272
+ print(mega_8_scenes_results)
273
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
274
+
275
+ def test_mega1500(model, name):
276
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
277
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
278
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
279
+
280
+ def test_mega_dense(model, name):
281
+ megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
282
+ megadense_results = megadense_benchmark.benchmark(model)
283
+ json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
284
+
285
+ def test_hpatches(model, name):
286
+ hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
287
+ hpatches_results = hpatches_benchmark.benchmark(model)
288
+ json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
289
+
290
+
291
+ if __name__ == "__main__":
292
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
293
+ os.environ["OMP_NUM_THREADS"] = "16"
294
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
295
+ import romatch
296
+ parser = ArgumentParser()
297
+ parser.add_argument("--only_test", action='store_true')
298
+ parser.add_argument("--debug_mode", action='store_true')
299
+ parser.add_argument("--dont_log_wandb", action='store_true')
300
+ parser.add_argument("--train_resolution", default='medium')
301
+ parser.add_argument("--gpu_batch_size", default=8, type=int)
302
+ parser.add_argument("--wandb_entity", required = False)
303
+
304
+ args, _ = parser.parse_known_args()
305
+ romatch.DEBUG_MODE = args.debug_mode
306
+ if not args.only_test:
307
+ train(args)
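
For orientation, a minimal sketch of the step bookkeeping in the train() loop above. The GPU count and per-GPU batch size here are assumptions; the script derives them at runtime from WORLD_SIZE and --gpu_batch_size.

    # Hypothetical launch: 4 GPUs, per-GPU batch size 8.
    gpus, batch_size = 4, 8
    STEP_SIZE = gpus * batch_size              # 32 image pairs per optimizer step
    N = 32 * 250_000                           # 8,000,000 pairs total, i.e. 250k steps at batch 32
    k = 25_000 // STEP_SIZE                    # 781 optimizer steps between checkpoints/benchmarks
    pairs_per_cycle = k * STEP_SIZE            # 24,992 pairs consumed per outer-loop iteration
    outer_cycles = N // pairs_per_cycle        # ~320 checkpoint/benchmark cycles
    lr_milestone = (9 * N / STEP_SIZE) // 10   # MultiStepLR drop at step 225,000 (90% of training)
    print(STEP_SIZE, k, pairs_per_cycle, outer_cycles, lr_milestone)
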
submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py ADDED
@@ -0,0 +1,498 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import torch
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ import math
9
+ import numpy as np
10
+
11
+ from torch import nn
12
+ from torch.utils.data import ConcatDataset
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ import json
16
+ import wandb
17
+ from PIL import Image
18
+ from torchvision.transforms import ToTensor
19
+
20
+ from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
21
+ from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
22
+ from romatch.datasets.megadepth import MegadepthBuilder
23
+ from romatch.losses.robust_loss_tiny_roma import RobustLosses
24
+ from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
25
+ from romatch.train.train import train_k_steps
26
+ from romatch.checkpointing import CheckPoint
27
+
28
+ resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
29
+
30
+ def kde(x, std = 0.1):
31
+ # use a gaussian kernel to estimate density
32
+ x = x.half() # Do it in half precision TODO: remove hardcoding
33
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
34
+ density = scores.sum(dim=-1)
35
+ return density
36
+
37
+ class BasicLayer(nn.Module):
38
+ """
39
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
40
+ """
41
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
42
+ super().__init__()
43
+ self.layer = nn.Sequential(
44
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
45
+ nn.BatchNorm2d(out_channels, affine=False),
46
+ nn.ReLU(inplace = True) if relu else nn.Identity()
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.layer(x)
51
+
52
+ class XFeatModel(nn.Module):
53
+ """
54
+ Implementation of architecture described in
55
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
56
+ """
57
+
58
+ def __init__(self, xfeat = None,
59
+ freeze_xfeat = True,
60
+ sample_mode = "threshold_balanced",
61
+ symmetric = False,
62
+ exact_softmax = False):
63
+ super().__init__()
64
+ if xfeat is None:
65
+ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
66
+ del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
67
+ if freeze_xfeat:
68
+ xfeat.train(False)
69
+ self.xfeat = [xfeat] # wrapped in a plain list so the frozen params are hidden from DDP
70
+ else:
71
+ self.xfeat = nn.ModuleList([xfeat])
72
+ self.freeze_xfeat = freeze_xfeat
73
+ match_dim = 256
74
+ self.coarse_matcher = nn.Sequential(
75
+ BasicLayer(64+64+2, match_dim,),
76
+ BasicLayer(match_dim, match_dim,),
77
+ BasicLayer(match_dim, match_dim,),
78
+ BasicLayer(match_dim, match_dim,),
79
+ nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
80
+ fine_match_dim = 64
81
+ self.fine_matcher = nn.Sequential(
82
+ BasicLayer(24+24+2, fine_match_dim,),
83
+ BasicLayer(fine_match_dim, fine_match_dim,),
84
+ BasicLayer(fine_match_dim, fine_match_dim,),
85
+ BasicLayer(fine_match_dim, fine_match_dim,),
86
+ nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
87
+ self.sample_mode = sample_mode
88
+ self.sample_thresh = 0.2
89
+ self.symmetric = symmetric
90
+ self.exact_softmax = exact_softmax
91
+
92
+ @property
93
+ def device(self):
94
+ return self.fine_matcher[-1].weight.device
95
+
96
+ def preprocess_tensor(self, x):
97
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
98
+ H, W = x.shape[-2:]
99
+ _H, _W = (H//32) * 32, (W//32) * 32
100
+ rh, rw = H/_H, W/_W
101
+
102
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
103
+ return x, rh, rw
104
+
105
+ def forward_single(self, x):
106
+ with torch.inference_mode(self.freeze_xfeat or not self.training):
107
+ xfeat = self.xfeat[0]
108
+ with torch.no_grad():
109
+ x = x.mean(dim=1, keepdim = True)
110
+ x = xfeat.norm(x)
111
+
112
+ #main backbone
113
+ x1 = xfeat.block1(x)
114
+ x2 = xfeat.block2(x1 + xfeat.skip1(x))
115
+ x3 = xfeat.block3(x2)
116
+ x4 = xfeat.block4(x3)
117
+ x5 = xfeat.block5(x4)
118
+ x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
119
+ x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
120
+ feats = xfeat.block_fusion( x3 + x4 + x5 )
121
+ if self.freeze_xfeat:
122
+ return x2.clone(), feats.clone()
123
+ return x2, feats
124
+
125
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
126
+ if coords.shape[-1] == 2:
127
+ return self._to_pixel_coordinates(coords, H_A, W_A)
128
+
129
+ if isinstance(coords, (list, tuple)):
130
+ kpts_A, kpts_B = coords[0], coords[1]
131
+ else:
132
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
133
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
134
+
135
+ def _to_pixel_coordinates(self, coords, H, W):
136
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
137
+ return kpts
138
+
139
+ def pos_embed(self, corr_volume: torch.Tensor):
140
+ B, H1, W1, H0, W0 = corr_volume.shape
141
+ grid = torch.stack(
142
+ torch.meshgrid(
143
+ torch.linspace(-1+1/W1,1-1/W1, W1),
144
+ torch.linspace(-1+1/H1,1-1/H1, H1),
145
+ indexing = "xy"),
146
+ dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
147
+ down = 4
148
+ if not self.training and not self.exact_softmax:
149
+ grid_lr = torch.stack(
150
+ torch.meshgrid(
151
+ torch.linspace(-1+down/W1,1-down/W1, W1//down),
152
+ torch.linspace(-1+down/H1,1-down/H1, H1//down),
153
+ indexing = "xy"),
154
+ dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
155
+ cv = corr_volume
156
+ best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
157
+ P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
158
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
159
+ pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
160
+ else:
161
+ P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
162
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
163
+ return pos_embeddings
164
+
165
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
166
+ im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
167
+ device = warp.device
168
+ H,W2,_ = warp.shape
169
+ W = W2//2 if symmetric else W2
170
+ if im_A is None:
171
+ from PIL import Image
172
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
173
+ if not isinstance(im_A, torch.Tensor):
174
+ im_A = im_A.resize((W,H))
175
+ im_B = im_B.resize((W,H))
176
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
177
+ if symmetric:
178
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
179
+ else:
180
+ if symmetric:
181
+ x_A = im_A
182
+ x_B = im_B
183
+ im_A_transfer_rgb = F.grid_sample(
184
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
185
+ )[0]
186
+ if symmetric:
187
+ im_B_transfer_rgb = F.grid_sample(
188
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
189
+ )[0]
190
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
191
+ white_im = torch.ones((H,2*W),device=device)
192
+ else:
193
+ warp_im = im_A_transfer_rgb
194
+ white_im = torch.ones((H, W), device = device)
195
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
196
+ if save_path is not None:
197
+ from romatch.utils import tensor_to_pil
198
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
199
+ return vis_im
200
+
201
+ def corr_volume(self, feat0, feat1):
202
+ """
203
+ input:
204
+ feat0 -> torch.Tensor(B, C, H, W)
205
+ feat1 -> torch.Tensor(B, C, H, W)
206
+ return:
207
+ corr_volume -> torch.Tensor(B, H, W, H, W)
208
+ """
209
+ B, C, H0, W0 = feat0.shape
210
+ B, C, H1, W1 = feat1.shape
211
+ feat0 = feat0.view(B, C, H0*W0)
212
+ feat1 = feat1.view(B, C, H1*W1)
213
+ corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
214
+ return corr_volume
215
+
216
+ @torch.inference_mode()
217
+ def match_from_path(self, im0_path, im1_path):
218
+ device = self.device
219
+ im0 = ToTensor()(Image.open(im0_path))[None].to(device)
220
+ im1 = ToTensor()(Image.open(im1_path))[None].to(device)
221
+ return self.match(im0, im1, batched = False)
222
+
223
+ @torch.inference_mode()
224
+ def match(self, im0, im1, *args, batched = True):
225
+ # stupid
226
+ if isinstance(im0, (str, Path)):
227
+ return self.match_from_path(im0, im1)
228
+ elif isinstance(im0, Image.Image):
229
+ batched = False
230
+ device = self.device
231
+ im0 = ToTensor()(im0)[None].to(device)
232
+ im1 = ToTensor()(im1)[None].to(device)
233
+
234
+ B,C,H0,W0 = im0.shape
235
+ B,C,H1,W1 = im1.shape
236
+ self.train(False)
237
+ corresps = self.forward({"im_A":im0, "im_B":im1})
238
+ #return 1,1
239
+ flow = F.interpolate(
240
+ corresps[4]["flow"],
241
+ size = (H0, W0),
242
+ mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
243
+ grid = torch.stack(
244
+ torch.meshgrid(
245
+ torch.linspace(-1+1/W0,1-1/W0, W0),
246
+ torch.linspace(-1+1/H0,1-1/H0, H0),
247
+ indexing = "xy"),
248
+ dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
249
+
250
+ certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
251
+ warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
252
+ if batched:
253
+ return warp, cert
254
+ else:
255
+ return warp[0], cert[0]
256
+
257
+ def sample(
258
+ self,
259
+ matches,
260
+ certainty,
261
+ num=10000,
262
+ ):
263
+ if "threshold" in self.sample_mode:
264
+ upper_thresh = self.sample_thresh
265
+ certainty = certainty.clone()
266
+ certainty[certainty > upper_thresh] = 1
267
+ matches, certainty = (
268
+ matches.reshape(-1, 4),
269
+ certainty.reshape(-1),
270
+ )
271
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
272
+ good_samples = torch.multinomial(certainty,
273
+ num_samples = min(expansion_factor*num, len(certainty)),
274
+ replacement=False)
275
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
276
+ if "balanced" not in self.sample_mode:
277
+ return good_matches, good_certainty
278
+ density = kde(good_matches, std=0.1)
279
+ p = 1 / (density+1)
280
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
281
+ balanced_samples = torch.multinomial(p,
282
+ num_samples = min(num,len(good_certainty)),
283
+ replacement=False)
284
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
285
+
286
+ def forward(self, batch):
287
+ """
288
+ input:
289
+ x -> torch.Tensor(B, C, H, W) grayscale or rgb images
290
+ return:
291
+
292
+ """
293
+ im0 = batch["im_A"]
294
+ im1 = batch["im_B"]
295
+ corresps = {}
296
+ im0, rh0, rw0 = self.preprocess_tensor(im0)
297
+ im1, rh1, rw1 = self.preprocess_tensor(im1)
298
+ B, C, H0, W0 = im0.shape
299
+ B, C, H1, W1 = im1.shape
300
+ to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
301
+
302
+ if im0.shape[-2:] == im1.shape[-2:]:
303
+ x = torch.cat([im0, im1], dim=0)
304
+ x = self.forward_single(x)
305
+ feats_x0_c, feats_x1_c = x[1].chunk(2)
306
+ feats_x0_f, feats_x1_f = x[0].chunk(2)
307
+ else:
308
+ feats_x0_f, feats_x0_c = self.forward_single(im0)
309
+ feats_x1_f, feats_x1_c = self.forward_single(im1)
310
+ corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
311
+ coarse_warp = self.pos_embed(corr_volume)
312
+ coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
313
+ feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
314
+ coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
315
+ coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
316
+ corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
317
+ coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
318
+ coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
319
+ feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
320
+ fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
321
+ fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
322
+ corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
323
+ return corresps
324
+
325
+
326
+
327
+
328
+
329
+ def train(args):
330
+ rank = 0
331
+ gpus = 1
332
+ device_id = rank % torch.cuda.device_count()
333
+ romatch.LOCAL_RANK = 0
334
+ torch.cuda.set_device(device_id)
335
+
336
+ resolution = "big"
337
+ wandb_log = not args.dont_log_wandb
338
+ experiment_name = Path(__file__).stem
339
+ wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
340
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
341
+ checkpoint_dir = "workspace/checkpoints/"
342
+ h,w = resolutions[resolution]
343
+ model = XFeatModel(freeze_xfeat = False).to(device_id)
344
+ # Num steps
345
+ global_step = 0
346
+ batch_size = args.gpu_batch_size
347
+ step_size = gpus*batch_size
348
+ romatch.STEP_SIZE = step_size
349
+
350
+ N = 2_000_000 # 2M pairs
351
+ # checkpoint every
352
+ k = 25000 // romatch.STEP_SIZE
353
+
354
+ # Data
355
+ mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
356
+ use_horizontal_flip_aug = True
357
+ normalize = False # do not apply ImageNet normalization
358
+ rot_prob = 0
359
+ depth_interpolation_mode = "bilinear"
360
+ megadepth_train1 = mega.build_scenes(
361
+ split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
362
+ ht=h,wt=w, normalize = normalize
363
+ )
364
+ megadepth_train2 = mega.build_scenes(
365
+ split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
366
+ ht=h,wt=w, normalize = normalize
367
+ )
368
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
369
+ mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
370
+ # Loss and optimizer
371
+ depth_loss = RobustLosses(
372
+ ce_weight=0.01,
373
+ local_dist={4:4},
374
+ depth_interpolation_mode=depth_interpolation_mode,
375
+ alpha = {4:0.15, 8:0.15},
376
+ c = 1e-4,
377
+ epe_mask_prob_th = 0.001,
378
+ )
379
+ parameters = [
380
+ {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
381
+ ]
382
+ optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
383
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
384
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
385
+ #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
386
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
387
+
388
+ checkpointer = CheckPoint(checkpoint_dir, experiment_name)
389
+ model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
390
+ romatch.GLOBAL_STEP = global_step
391
+ grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
392
+ grad_clip_norm = 0.01
393
+ #megadense_benchmark.benchmark(model)
394
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
395
+ mega_sampler = torch.utils.data.WeightedRandomSampler(
396
+ mega_ws, num_samples = batch_size * k, replacement=False
397
+ )
398
+ mega_dataloader = iter(
399
+ torch.utils.data.DataLoader(
400
+ megadepth_train,
401
+ batch_size = batch_size,
402
+ sampler = mega_sampler,
403
+ num_workers = 8,
404
+ )
405
+ )
406
+ train_k_steps(
407
+ n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
408
+ )
409
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
410
+ wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
411
+
412
+ def test_mega_8_scenes(model, name):
413
+ mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
414
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
415
+ 'mega_8_scenes_0025_0.1_0.3.npz',
416
+ 'mega_8_scenes_0021_0.1_0.3.npz',
417
+ 'mega_8_scenes_0008_0.1_0.3.npz',
418
+ 'mega_8_scenes_0032_0.1_0.3.npz',
419
+ 'mega_8_scenes_1589_0.1_0.3.npz',
420
+ 'mega_8_scenes_0063_0.1_0.3.npz',
421
+ 'mega_8_scenes_0024_0.1_0.3.npz',
422
+ 'mega_8_scenes_0019_0.3_0.5.npz',
423
+ 'mega_8_scenes_0025_0.3_0.5.npz',
424
+ 'mega_8_scenes_0021_0.3_0.5.npz',
425
+ 'mega_8_scenes_0008_0.3_0.5.npz',
426
+ 'mega_8_scenes_0032_0.3_0.5.npz',
427
+ 'mega_8_scenes_1589_0.3_0.5.npz',
428
+ 'mega_8_scenes_0063_0.3_0.5.npz',
429
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
430
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
431
+ print(mega_8_scenes_results)
432
+ json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
433
+
434
+ def test_mega1500(model, name):
435
+ mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
436
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
437
+ json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
438
+
439
+ def test_mega1500_poselib(model, name):
440
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
441
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
442
+ json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
443
+
444
+ def test_mega_8_scenes_poselib(model, name):
445
+ mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
446
+ scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
447
+ 'mega_8_scenes_0025_0.1_0.3.npz',
448
+ 'mega_8_scenes_0021_0.1_0.3.npz',
449
+ 'mega_8_scenes_0008_0.1_0.3.npz',
450
+ 'mega_8_scenes_0032_0.1_0.3.npz',
451
+ 'mega_8_scenes_1589_0.1_0.3.npz',
452
+ 'mega_8_scenes_0063_0.1_0.3.npz',
453
+ 'mega_8_scenes_0024_0.1_0.3.npz',
454
+ 'mega_8_scenes_0019_0.3_0.5.npz',
455
+ 'mega_8_scenes_0025_0.3_0.5.npz',
456
+ 'mega_8_scenes_0021_0.3_0.5.npz',
457
+ 'mega_8_scenes_0008_0.3_0.5.npz',
458
+ 'mega_8_scenes_0032_0.3_0.5.npz',
459
+ 'mega_8_scenes_1589_0.3_0.5.npz',
460
+ 'mega_8_scenes_0063_0.3_0.5.npz',
461
+ 'mega_8_scenes_0024_0.3_0.5.npz'])
462
+ mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
463
+ json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
464
+
465
+ def test_scannet_poselib(model, name):
466
+ scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
467
+ scannet_results = scannet_benchmark.benchmark(model)
468
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
469
+
470
+ def test_scannet(model, name):
471
+ scannet_benchmark = ScanNetBenchmark("data/scannet")
472
+ scannet_results = scannet_benchmark.benchmark(model)
473
+ json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
474
+
475
+ if __name__ == "__main__":
476
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
477
+ os.environ["OMP_NUM_THREADS"] = "16"
478
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
479
+ import romatch
480
+ parser = ArgumentParser()
481
+ parser.add_argument("--only_test", action='store_true')
482
+ parser.add_argument("--debug_mode", action='store_true')
483
+ parser.add_argument("--dont_log_wandb", action='store_true')
484
+ parser.add_argument("--train_resolution", default='medium')
485
+ parser.add_argument("--gpu_batch_size", default=8, type=int)
486
+ parser.add_argument("--wandb_entity", required = False)
487
+
488
+ args, _ = parser.parse_known_args()
489
+ romatch.DEBUG_MODE = args.debug_mode
490
+ if not args.only_test:
491
+ train(args)
492
+
493
+ experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
494
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
495
+ model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
496
+ model.load_state_dict(torch.load(f"{experiment_name}.pth"))
497
+ test_mega1500_poselib(model, experiment_name)
498
+
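
The tiny model above can be exercised outside the training loop through the methods it defines (match_from_path, sample, to_pixel_coordinates). A minimal inference sketch, with placeholder image paths and the checkpoint name used in the __main__ block:

    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
    model.load_state_dict(torch.load("tiny_roma_v1_outdoor.pth", map_location=device))
    model.eval()

    # Dense warp (H, W, 4) and per-pixel certainty (H, W), mapping image A onto image B.
    warp, certainty = model.match_from_path("im_A.jpg", "im_B.jpg")  # placeholder paths

    # Sample sparse correspondences and convert them to pixel coordinates.
    matches, confidence = model.sample(warp, certainty, num=2000)
    W_A, H_A = Image.open("im_A.jpg").size
    W_B, H_B = Image.open("im_B.jpg").size
    kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
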
submodules/RoMa/requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ einops
3
+ torchvision
4
+ opencv-python
5
+ kornia
6
+ albumentations
7
+ loguru
8
+ tqdm
9
+ matplotlib
10
+ h5py
11
+ wandb
12
+ timm
13
+ poselib
14
+ #xformers # Optional, used for memory-efficient attention
submodules/RoMa/romatch/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import os
2
+ from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
3
+
4
+ DEBUG_MODE = False
5
+ RANK = int(os.environ.get('RANK', default = 0))
6
+ GLOBAL_STEP = 0
7
+ STEP_SIZE = 1
8
+ LOCAL_RANK = -1
submodules/RoMa/romatch/benchmarks/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
2
+ from .scannet_benchmark import ScanNetBenchmark
3
+ from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
4
+ from .megadepth_dense_benchmark import MegadepthDenseBenchmark
5
+ from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
6
+ #from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py ADDED
@@ -0,0 +1,113 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import os
5
+
6
+ from tqdm import tqdm
7
+ from romatch.utils import pose_auc
8
+ import cv2
9
+
10
+
11
+ class HpatchesHomogBenchmark:
12
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
13
+
14
+ def __init__(self, dataset_path) -> None:
15
+ seqs_dir = "hpatches-sequences-release"
16
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
17
+ self.seq_names = sorted(os.listdir(self.seqs_path))
18
+ # Ignored sequences are the same as in LoFTR.
19
+ self.ignore_seqs = set(
20
+ [
21
+ "i_contruction",
22
+ "i_crownnight",
23
+ "i_dc",
24
+ "i_pencils",
25
+ "i_whitebuilding",
26
+ "v_artisans",
27
+ "v_astronautis",
28
+ "v_talent",
29
+ ]
30
+ )
31
+
32
+ def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
33
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
34
+ im_A_coords = (
35
+ np.stack(
36
+ (
37
+ wq * (im_A_coords[..., 0] + 1) / 2,
38
+ hq * (im_A_coords[..., 1] + 1) / 2,
39
+ ),
40
+ axis=-1,
41
+ )
42
+ - offset
43
+ )
44
+ im_A_to_im_B = (
45
+ np.stack(
46
+ (
47
+ wsup * (im_A_to_im_B[..., 0] + 1) / 2,
48
+ hsup * (im_A_to_im_B[..., 1] + 1) / 2,
49
+ ),
50
+ axis=-1,
51
+ )
52
+ - offset
53
+ )
54
+ return im_A_coords, im_A_to_im_B
55
+
56
+ def benchmark(self, model, model_name = None):
57
+ n_matches = []
58
+ homog_dists = []
59
+ for seq_idx, seq_name in tqdm(
60
+ enumerate(self.seq_names), total=len(self.seq_names)
61
+ ):
62
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
63
+ im_A = Image.open(im_A_path)
64
+ w1, h1 = im_A.size
65
+ for im_idx in range(2, 7):
66
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
67
+ im_B = Image.open(im_B_path)
68
+ w2, h2 = im_B.size
69
+ H = np.loadtxt(
70
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
71
+ )
72
+ dense_matches, dense_certainty = model.match(
73
+ im_A_path, im_B_path
74
+ )
75
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
76
+ pos_a, pos_b = self.convert_coordinates(
77
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
78
+ )
79
+ try:
80
+ H_pred, inliers = cv2.findHomography(
81
+ pos_a,
82
+ pos_b,
83
+ method = cv2.RANSAC,
84
+ confidence = 0.99999,
85
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
86
+ )
87
+ except Exception:
88
+ H_pred = None
89
+ if H_pred is None:
90
+ H_pred = np.zeros((3, 3))
91
+ H_pred[2, 2] = 1.0
92
+ corners = np.array(
93
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
94
+ )
95
+ real_warped_corners = np.dot(corners, np.transpose(H))
96
+ real_warped_corners = (
97
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
98
+ )
99
+ warped_corners = np.dot(corners, np.transpose(H_pred))
100
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
101
+ mean_dist = np.mean(
102
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
103
+ ) / (min(w2, h2) / 480.0)
104
+ homog_dists.append(mean_dist)
105
+
106
+ n_matches = np.array(n_matches)
107
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
108
+ auc = pose_auc(np.array(homog_dists), thresholds)
109
+ return {
110
+ "hpatches_homog_auc_3": auc[2],
111
+ "hpatches_homog_auc_5": auc[4],
112
+ "hpatches_homog_auc_10": auc[9],
113
+ }
submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ import numpy as np
3
+ import tqdm
4
+ from romatch.datasets import MegadepthBuilder
5
+ from romatch.utils import warp_kpts
6
+ from torch.utils.data import ConcatDataset
7
+ import romatch
8
+
9
+ class MegadepthDenseBenchmark:
10
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
11
+ mega = MegadepthBuilder(data_root=data_root)
12
+ self.dataset = ConcatDataset(
13
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
14
+ ) # fixed resolution of 384,512
15
+ self.num_samples = num_samples
16
+
17
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
18
+ b, h1, w1, d = dense_matches.shape
19
+ with torch.no_grad():
20
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
21
+ mask, x2 = warp_kpts(
22
+ x1.double(),
23
+ depth1.double(),
24
+ depth2.double(),
25
+ T_1to2.double(),
26
+ K1.double(),
27
+ K2.double(),
28
+ )
29
+ x2 = torch.stack(
30
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
31
+ )
32
+ prob = mask.float().reshape(b, h1, w1)
33
+ x2_hat = dense_matches[..., 2:]
34
+ x2_hat = torch.stack(
35
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
36
+ )
37
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
38
+ gd = gd[prob == 1]
39
+ pck_1 = (gd < 1.0).float().mean()
40
+ pck_3 = (gd < 3.0).float().mean()
41
+ pck_5 = (gd < 5.0).float().mean()
42
+ return gd, pck_1, pck_3, pck_5, prob
43
+
44
+ def benchmark(self, model, batch_size=8):
45
+ model.train(False)
46
+ with torch.no_grad():
47
+ gd_tot = 0.0
48
+ pck_1_tot = 0.0
49
+ pck_3_tot = 0.0
50
+ pck_5_tot = 0.0
51
+ sampler = torch.utils.data.WeightedRandomSampler(
52
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
53
+ )
54
+ B = batch_size
55
+ dataloader = torch.utils.data.DataLoader(
56
+ self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
57
+ )
58
+ for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
59
+ im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
60
+ data["im_A"].cuda(),
61
+ data["im_B"].cuda(),
62
+ data["im_A_depth"].cuda(),
63
+ data["im_B_depth"].cuda(),
64
+ data["T_1to2"].cuda(),
65
+ data["K1"].cuda(),
66
+ data["K2"].cuda(),
67
+ )
68
+ matches, certainty = model.match(im_A, im_B, batched=True)
69
+ gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
70
+ depth1, depth2, T_1to2, K1, K2, matches
71
+ )
72
+ if romatch.DEBUG_MODE:
73
+ from romatch.utils.utils import tensor_to_pil
74
+ import torch.nn.functional as F
75
+ path = "vis"
76
+ H, W = model.get_output_resolution()
77
+ white_im = torch.ones((B,1,H,W),device="cuda")
78
+ im_B_transfer_rgb = F.grid_sample(
79
+ im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
80
+ )
81
+ warp_im = im_B_transfer_rgb
82
+ c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
83
+ vis_im = c_b * warp_im + (1 - c_b) * white_im
84
+ for b in range(B):
85
+ import os
86
+ os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
87
+ tensor_to_pil(vis_im[b], unnormalize=True).save(
88
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
89
+ tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
90
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
91
+ tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
92
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
93
+
94
+
95
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
96
+ gd_tot + gd.mean(),
97
+ pck_1_tot + pck_1,
98
+ pck_3_tot + pck_3,
99
+ pck_5_tot + pck_5,
100
+ )
101
+ return {
102
+ "epe": gd_tot.item() / len(dataloader),
103
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
104
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
105
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
106
+ }
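
The PCK values computed by geometric_dist above are simply the fraction of valid ground-truth correspondences whose predicted match lands within 1, 3 or 5 pixels; a tiny sketch with made-up end-point errors:

    import torch

    gd = torch.tensor([0.4, 2.1, 0.9, 6.3, 3.8])  # hypothetical end-point errors in pixels
    pck_1 = (gd < 1.0).float().mean()  # 0.40
    pck_3 = (gd < 3.0).float().mean()  # 0.60
    pck_5 = (gd < 5.0).float().mean()  # 0.80
    print(pck_1.item(), pck_3.item(), pck_5.item())
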
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py ADDED
@@ -0,0 +1,118 @@
1
+ import numpy as np
2
+ import torch
3
+ from romatch.utils import *
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+ import romatch
8
+ import kornia.geometry.epipolar as kepi
9
+
10
+ class MegaDepthPoseEstimationBenchmark:
11
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
12
+ if scene_names is None:
13
+ self.scene_names = [
14
+ "0015_0.1_0.3.npz",
15
+ "0015_0.3_0.5.npz",
16
+ "0022_0.1_0.3.npz",
17
+ "0022_0.3_0.5.npz",
18
+ "0022_0.5_0.7.npz",
19
+ ]
20
+ else:
21
+ self.scene_names = scene_names
22
+ self.scenes = [
23
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
24
+ for scene in self.scene_names
25
+ ]
26
+ self.data_root = data_root
27
+
28
+ def benchmark(self, model, model_name = None):
29
+ with torch.no_grad():
30
+ data_root = self.data_root
31
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
32
+ thresholds = [5, 10, 20]
33
+ for scene_ind in range(len(self.scenes)):
34
+ import os
35
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
36
+ scene = self.scenes[scene_ind]
37
+ pairs = scene["pair_infos"]
38
+ intrinsics = scene["intrinsics"]
39
+ poses = scene["poses"]
40
+ im_paths = scene["image_paths"]
41
+ pair_inds = range(len(pairs))
42
+ for pairind in tqdm(pair_inds):
43
+ idx1, idx2 = pairs[pairind][0]
44
+ K1 = intrinsics[idx1].copy()
45
+ T1 = poses[idx1].copy()
46
+ R1, t1 = T1[:3, :3], T1[:3, 3]
47
+ K2 = intrinsics[idx2].copy()
48
+ T2 = poses[idx2].copy()
49
+ R2, t2 = T2[:3, :3], T2[:3, 3]
50
+ R, t = compute_relative_pose(R1, t1, R2, t2)
51
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
52
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
53
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
54
+ dense_matches, dense_certainty = model.match(
55
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
56
+ )
57
+ sparse_matches,_ = model.sample(
58
+ dense_matches, dense_certainty, 5_000
59
+ )
60
+
61
+ im_A = Image.open(im_A_path)
62
+ w1, h1 = im_A.size
63
+ im_B = Image.open(im_B_path)
64
+ w2, h2 = im_B.size
65
+ if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
66
+ scale1 = 1200 / max(w1, h1)
67
+ scale2 = 1200 / max(w2, h2)
68
+ w1, h1 = scale1 * w1, scale1 * h1
69
+ w2, h2 = scale2 * w2, scale2 * h2
70
+ K1, K2 = K1.copy(), K2.copy()
71
+ K1[:2] = K1[:2] * scale1
72
+ K2[:2] = K2[:2] * scale2
73
+
74
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
75
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
76
+ for _ in range(5):
77
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
78
+ kpts1 = kpts1[shuffling]
79
+ kpts2 = kpts2[shuffling]
80
+ try:
81
+ threshold = 0.5
82
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
83
+ R_est, t_est, mask = estimate_pose(
84
+ kpts1,
85
+ kpts2,
86
+ K1,
87
+ K2,
88
+ norm_threshold,
89
+ conf=0.99999,
90
+ )
91
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
92
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
93
+ e_pose = max(e_t, e_R)
94
+ except Exception as e:
95
+ print(repr(e))
96
+ e_t, e_R = 90, 90
97
+ e_pose = max(e_t, e_R)
98
+ tot_e_t.append(e_t)
99
+ tot_e_R.append(e_R)
100
+ tot_e_pose.append(e_pose)
101
+ tot_e_pose = np.array(tot_e_pose)
102
+ auc = pose_auc(tot_e_pose, thresholds)
103
+ acc_5 = (tot_e_pose < 5).mean()
104
+ acc_10 = (tot_e_pose < 10).mean()
105
+ acc_15 = (tot_e_pose < 15).mean()
106
+ acc_20 = (tot_e_pose < 20).mean()
107
+ map_5 = acc_5
108
+ map_10 = np.mean([acc_5, acc_10])
109
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
110
+ print(f"{model_name} auc: {auc}")
111
+ return {
112
+ "auc_5": auc[0],
113
+ "auc_10": auc[1],
114
+ "auc_20": auc[2],
115
+ "map_5": map_5,
116
+ "map_10": map_10,
117
+ "map_20": map_20,
118
+ }
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py ADDED
@@ -0,0 +1,119 @@
1
+ import numpy as np
2
+ import torch
3
+ from romatch.utils import *
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import torch.nn.functional as F
7
+ import romatch
8
+ import kornia.geometry.epipolar as kepi
9
+
10
+ # wrapped because poselib is still in development
11
+ # it will be added to the dependencies later
12
+ import poselib
13
+
14
+ class Mega1500PoseLibBenchmark:
15
+ def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
16
+ if scene_names is None:
17
+ self.scene_names = [
18
+ "0015_0.1_0.3.npz",
19
+ "0015_0.3_0.5.npz",
20
+ "0022_0.1_0.3.npz",
21
+ "0022_0.3_0.5.npz",
22
+ "0022_0.5_0.7.npz",
23
+ ]
24
+ else:
25
+ self.scene_names = scene_names
26
+ self.scenes = [
27
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
28
+ for scene in self.scene_names
29
+ ]
30
+ self.data_root = data_root
31
+ self.num_ransac_iter = num_ransac_iter
32
+ self.test_every = test_every
33
+
34
+ def benchmark(self, model, model_name = None):
35
+ with torch.no_grad():
36
+ data_root = self.data_root
37
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
38
+ thresholds = [5, 10, 20]
39
+ for scene_ind in range(len(self.scenes)):
40
+ import os
41
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
42
+ scene = self.scenes[scene_ind]
43
+ pairs = scene["pair_infos"]
44
+ intrinsics = scene["intrinsics"]
45
+ poses = scene["poses"]
46
+ im_paths = scene["image_paths"]
47
+ pair_inds = range(len(pairs))[::self.test_every]
48
+ for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
49
+ idx1, idx2 = pairs[pairind][0]
50
+ K1 = intrinsics[idx1].copy()
51
+ T1 = poses[idx1].copy()
52
+ R1, t1 = T1[:3, :3], T1[:3, 3]
53
+ K2 = intrinsics[idx2].copy()
54
+ T2 = poses[idx2].copy()
55
+ R2, t2 = T2[:3, :3], T2[:3, 3]
56
+ R, t = compute_relative_pose(R1, t1, R2, t2)
57
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
58
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
59
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
60
+ dense_matches, dense_certainty = model.match(
61
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
62
+ )
63
+ sparse_matches,_ = model.sample(
64
+ dense_matches, dense_certainty, 5_000
65
+ )
66
+
67
+ im_A = Image.open(im_A_path)
68
+ w1, h1 = im_A.size
69
+ im_B = Image.open(im_B_path)
70
+ w2, h2 = im_B.size
71
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
72
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
73
+ for _ in range(self.num_ransac_iter):
74
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
75
+ kpts1 = kpts1[shuffling]
76
+ kpts2 = kpts2[shuffling]
77
+ try:
78
+ threshold = 1
79
+ camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
80
+ camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
81
+ relpose, res = poselib.estimate_relative_pose(
82
+ kpts1,
83
+ kpts2,
84
+ camera1,
85
+ camera2,
86
+ ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
87
+ )
88
+ Rt_est = relpose.Rt
89
+ R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
90
+ mask = np.array(res['inliers']).astype(np.float32)
91
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
92
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
93
+ e_pose = max(e_t, e_R)
94
+ except Exception as e:
95
+ print(repr(e))
96
+ e_t, e_R = 90, 90
97
+ e_pose = max(e_t, e_R)
98
+ tot_e_t.append(e_t)
99
+ tot_e_R.append(e_R)
100
+ tot_e_pose.append(e_pose)
101
+ pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
102
+ tot_e_pose = np.array(tot_e_pose)
103
+ auc = pose_auc(tot_e_pose, thresholds)
104
+ acc_5 = (tot_e_pose < 5).mean()
105
+ acc_10 = (tot_e_pose < 10).mean()
106
+ acc_15 = (tot_e_pose < 15).mean()
107
+ acc_20 = (tot_e_pose < 20).mean()
108
+ map_5 = acc_5
109
+ map_10 = np.mean([acc_5, acc_10])
110
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
111
+ print(f"{model_name} auc: {auc}")
112
+ return {
113
+ "auc_5": auc[0],
114
+ "auc_10": auc[1],
115
+ "auc_20": auc[2],
116
+ "map_5": map_5,
117
+ "map_10": map_10,
118
+ "map_20": map_20,
119
+ }
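
The indexing K[[0, 1, 0, 1], [0, 1, 2, 2]] used above packs a 3x3 intrinsic matrix into poselib's PINHOLE parameter order [fx, fy, cx, cy]; a small worked example:

    import numpy as np

    # Example intrinsics: fx=1000, fy=1100, cx=320, cy=240.
    K = np.array([[1000.,    0., 320.],
                  [   0., 1100., 240.],
                  [   0.,    0.,   1.]])

    # Rows [0,1,0,1] with columns [0,1,2,2] select K[0,0], K[1,1], K[0,2], K[1,2].
    params = K[[0, 1, 0, 1], [0, 1, 2, 2]]
    print(params)  # [1000. 1100.  320.  240.]
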
submodules/RoMa/romatch/benchmarks/scannet_benchmark.py ADDED
@@ -0,0 +1,143 @@
1
+ import os.path as osp
2
+ import numpy as np
3
+ import torch
4
+ from romatch.utils import *
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+
9
+ class ScanNetBenchmark:
10
+ def __init__(self, data_root="data/scannet") -> None:
11
+ self.data_root = data_root
12
+
13
+ def benchmark(self, model, model_name = None):
14
+ model.train(False)
15
+ with torch.no_grad():
16
+ data_root = self.data_root
17
+ tmp = np.load(osp.join(data_root, "test.npz"))
18
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
19
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
20
+ pair_inds = np.random.choice(
21
+ range(len(pairs)), size=len(pairs), replace=False
22
+ )
23
+ for pairind in tqdm(pair_inds, smoothing=0.9):
24
+ scene = pairs[pairind]
25
+ scene_name = f"scene0{scene[0]}_00"
26
+ im_A_path = osp.join(
27
+ self.data_root,
28
+ "scans_test",
29
+ scene_name,
30
+ "color",
31
+ f"{scene[2]}.jpg",
32
+ )
33
+ im_A = Image.open(im_A_path)
34
+ im_B_path = osp.join(
35
+ self.data_root,
36
+ "scans_test",
37
+ scene_name,
38
+ "color",
39
+ f"{scene[3]}.jpg",
40
+ )
41
+ im_B = Image.open(im_B_path)
42
+ T_gt = rel_pose[pairind].reshape(3, 4)
43
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
44
+ K = np.stack(
45
+ [
46
+ np.array([float(i) for i in r.split()])
47
+ for r in open(
48
+ osp.join(
49
+ self.data_root,
50
+ "scans_test",
51
+ scene_name,
52
+ "intrinsic",
53
+ "intrinsic_color.txt",
54
+ ),
55
+ "r",
56
+ )
57
+ .read()
58
+ .split("\n")
59
+ if r
60
+ ]
61
+ )
62
+ w1, h1 = im_A.size
63
+ w2, h2 = im_B.size
64
+ K1 = K.copy()
65
+ K2 = K.copy()
66
+ dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
67
+ sparse_matches, sparse_certainty = model.sample(
68
+ dense_matches, dense_certainty, 5000
69
+ )
70
+ scale1 = 480 / min(w1, h1)
71
+ scale2 = 480 / min(w2, h2)
72
+ w1, h1 = scale1 * w1, scale1 * h1
73
+ w2, h2 = scale2 * w2, scale2 * h2
74
+ K1 = K1 * scale1
75
+ K2 = K2 * scale2
76
+
77
+ offset = 0.5
78
+ kpts1 = sparse_matches[:, :2]
79
+ kpts1 = (
80
+ np.stack(
81
+ (
82
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
83
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
84
+ ),
85
+ axis=-1,
86
+ )
87
+ )
88
+ kpts2 = sparse_matches[:, 2:]
89
+ kpts2 = (
90
+ np.stack(
91
+ (
92
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
93
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
94
+ ),
95
+ axis=-1,
96
+ )
97
+ )
98
+ for _ in range(5):
99
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
100
+ kpts1 = kpts1[shuffling]
101
+ kpts2 = kpts2[shuffling]
102
+ try:
103
+ norm_threshold = 0.5 / (
104
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
105
+ R_est, t_est, mask = estimate_pose(
106
+ kpts1,
107
+ kpts2,
108
+ K1,
109
+ K2,
110
+ norm_threshold,
111
+ conf=0.99999,
112
+ )
113
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
114
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
115
+ e_pose = max(e_t, e_R)
116
+ except Exception as e:
117
+ print(repr(e))
118
+ e_t, e_R = 90, 90
119
+ e_pose = max(e_t, e_R)
120
+ tot_e_t.append(e_t)
121
+ tot_e_R.append(e_R)
122
+ tot_e_pose.append(e_pose)
123
+ tot_e_t.append(e_t)
124
+ tot_e_R.append(e_R)
125
+ tot_e_pose.append(e_pose)
126
+ tot_e_pose = np.array(tot_e_pose)
127
+ thresholds = [5, 10, 20]
128
+ auc = pose_auc(tot_e_pose, thresholds)
129
+ acc_5 = (tot_e_pose < 5).mean()
130
+ acc_10 = (tot_e_pose < 10).mean()
131
+ acc_15 = (tot_e_pose < 15).mean()
132
+ acc_20 = (tot_e_pose < 20).mean()
133
+ map_5 = acc_5
134
+ map_10 = np.mean([acc_5, acc_10])
135
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
136
+ return {
137
+ "auc_5": auc[0],
138
+ "auc_10": auc[1],
139
+ "auc_20": auc[2],
140
+ "map_5": map_5,
141
+ "map_10": map_10,
142
+ "map_20": map_20,
143
+ }
submodules/RoMa/romatch/checkpointing/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .checkpoint import CheckPoint
submodules/RoMa/romatch/checkpointing/checkpoint.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import torch
3
+ from torch.nn.parallel.data_parallel import DataParallel
4
+ from torch.nn.parallel.distributed import DistributedDataParallel
5
+ from loguru import logger
6
+ import gc
7
+
8
+ import romatch
9
+
10
+ class CheckPoint:
11
+ def __init__(self, dir=None, name="tmp"):
12
+ self.name = name
13
+ self.dir = dir
14
+ os.makedirs(self.dir, exist_ok=True)
15
+
16
+ def save(
17
+ self,
18
+ model,
19
+ optimizer,
20
+ lr_scheduler,
21
+ n,
22
+ ):
23
+ if romatch.RANK == 0:
24
+ assert model is not None
25
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
26
+ model = model.module
27
+ states = {
28
+ "model": model.state_dict(),
29
+ "n": n,
30
+ "optimizer": optimizer.state_dict(),
31
+ "lr_scheduler": lr_scheduler.state_dict(),
32
+ }
33
+ torch.save(states, self.dir + self.name + f"_latest.pth")
34
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
35
+
36
+ def load(
37
+ self,
38
+ model,
39
+ optimizer,
40
+ lr_scheduler,
41
+ n,
42
+ ):
43
+ if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
44
+ states = torch.load(self.dir + self.name + f"_latest.pth")
45
+ if "model" in states:
46
+ model.load_state_dict(states["model"])
47
+ if "n" in states:
48
+ n = states["n"] if states["n"] else n
49
+ if "optimizer" in states:
50
+ try:
51
+ optimizer.load_state_dict(states["optimizer"])
52
+ except Exception as e:
53
+ print(f"Failed to load states for optimizer, with error {e}")
54
+ if "lr_scheduler" in states:
55
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
56
+ print(f"Loaded states {list(states.keys())}, at step {n}")
57
+ del states
58
+ gc.collect()
59
+ torch.cuda.empty_cache()
60
+ return model, optimizer, lr_scheduler, n
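
A minimal sketch of the save/load round trip provided by CheckPoint, assuming the romatch package is importable and using a toy model in place of the matcher (RANK defaults to 0 outside a distributed launch, so save() actually writes):

    import torch
    from romatch.checkpointing import CheckPoint

    model = torch.nn.Linear(4, 4)  # toy stand-in for the matcher
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100])

    checkpointer = CheckPoint("workspace/checkpoints/", name="toy_experiment")
    checkpointer.save(model, optimizer, lr_scheduler, n=0)

    # Later, or after a restart: restores the weights, optimizer, scheduler and step counter.
    model, optimizer, lr_scheduler, n = checkpointer.load(model, optimizer, lr_scheduler, n=0)
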