import os
import json
from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import ConcatDataset

import wandb
from tqdm import tqdm

from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.datasets.scannet import ScanNetBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint

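# Training resolutions. The "medium"/"high" sides are multiples of 14, matching
# the DINOv2 ViT patch size (14*8*5 = 560, 14*8*6 = 672).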
resolutions = {"low": (448, 448), "medium": (14 * 8 * 5, 14 * 8 * 5), "high": (14 * 8 * 6, 14 * 8 * 6)}


def get_model(pretrained_backbone=True, resolution="medium", **kwargs):
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
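    # Coarse "global matcher" head: a transformer decoder that, with
    # is_classifier=True, outputs logits over a cls_to_coord_res x
    # cls_to_coord_res grid of coordinate bins (plus one extra class).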
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

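    # Per-scale refiners. Input channels are 2 * (projected feature dim)
    # + displacement-embedding dim, plus a (2r+1)**2 local-correlation volume
    # at strides 16/8/4 (r = local_corr_radius). Outputs are 2 flow channels
    # plus 1 certainty channel ("2 + 1").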
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )

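    # Gaussian-process match encoder, used only at the coarsest scale ("16");
    # the finer scales rely on the conv refiners above.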
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    # 1x1 convs projecting the encoder features at each stride down to the
    # channel counts the refiners above expect.
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
    })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
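    # Coarse-to-fine decoder over strides 16 -> 1 (detach=True presumably stops
    # gradients flowing through the estimate handed from one scale to the next).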
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )
    h, w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs=dict(
            pretrained=pretrained_backbone,
            amp=True,
        ),
        amp=True,
        use_vgg=True,
    )
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0, **kwargs)
    return matcher


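# Entry point for DDP training; expects the WORLD_SIZE/RANK environment
# variables that a launcher such as torchrun sets.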
def train(args):
    dist.init_process_group('nccl')
    gpus = int(os.environ['WORLD_SIZE'])
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    # Only rank 0 logs to wandb; all other ranks run in disabled mode.
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert=False).to(device_id)

    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus * batch_size  # global batch size across all ranks
    romatch.STEP_SIZE = step_size

    N = 32 * 250000  # total training pairs: 250k steps at a reference global batch size of 32
    # Length of one checkpoint/benchmark round, in steps (~25k pairs per round)
    k = 25000 // romatch.STEP_SIZE

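    # MegaDepth: concatenate a hard pair set (min_overlap 0.01) with an easier
    # one (min_overlap 0.35); scenes are later sampled with weights (alpha=0.75).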
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32,
        use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32,
        use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    scannet = ScanNetBuilder(data_root="data/scannet")
    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug=use_horizontal_flip_aug)
    scannet_train = ConcatDataset(scannet_train)
    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)

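    # Robust regression losses; identical except that ScanNet drops the
    # cross-entropy term for the coarse matcher (ce_weight=0.0 vs 0.01).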
    depth_loss_scannet = RobustLosses(
        ce_weight=0.0,
        local_dist={1: 4, 2: 4, 4: 8, 8: 8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha=0.5,
        c=1e-4,
    )
    depth_loss_mega = RobustLosses(
        ce_weight=0.01,
        local_dist={1: 4, 2: 4, 4: 8, 8: 8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha=0.5,
        c=1e-4,
    )
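    # Linear LR scaling: base learning rates (encoder 5e-6, decoder 1e-4) are
    # defined for a global batch size of 8 and scaled by STEP_SIZE / 8.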
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9 * N / romatch.STEP_SIZE) // 10]
    )
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000, h=h, w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=False, gradient_as_bucket_view=True)
    # The very large growth_interval effectively freezes the loss scale.
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
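    # Each outer round: draw k weighted batches per dataset, alternate one
    # MegaDepth step with one ScanNet step, then checkpoint and run the dense
    # MegaDepth benchmark.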
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=batch_size,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
            scannet_ws, num_samples=batch_size * k, replacement=False
        )
        scannet_dataloader = iter(
            torch.utils.data.DataLoader(
                scannet_train,
                batch_size=batch_size,
                sampler=scannet_ws_sampler,
                num_workers=gpus * 8,
            )
        )
        for n_k in tqdm(range(n, n + 2 * k, 2), disable=romatch.RANK > 0):
            train_k_steps(
                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer,
                lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm, progress_bar=False,
            )
            train_k_steps(
                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer,
                lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm, progress_bar=False,
            )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step=romatch.GLOBAL_STEP)


def test_scannet(model, name, resolution, sample_mode):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    os.makedirs("results", exist_ok=True)  # ensure the output directory exists
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)


if __name__ == "__main__":
    import warnings
    # Blanket filter; this also covers the TypedStorage deprecation warning.
    warnings.filterwarnings('ignore')
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
    os.environ["OMP_NUM_THREADS"] = "16"

    import romatch  # imported here, after the environment variables are set

    parser = ArgumentParser()
    parser.add_argument("--test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=4, type=int)
    parser.add_argument("--wandb_entity", required=False)

    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
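    # Train unless --test is given; the ScanNet evaluation below runs either way.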
    if not args.test:
        train(args)

    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    checkpoint_dir = "workspace/"
    checkpoint_name = checkpoint_dir + experiment_name + ".pth"  # adjust if your checkpointer writes elsewhere
    test_resolution = "medium"
    sample_mode = "threshold_balanced"
    symmetric = True
    upsample_preds = False
    attenuate_cert = True

    model = get_model(
        pretrained_backbone=False, resolution=test_resolution, sample_mode=sample_mode,
        upsample_preds=upsample_preds, symmetric=symmetric, name=experiment_name,
        attenuate_cert=attenuate_cert,
    )
    model = model.cuda()
    states = torch.load(checkpoint_name)
    model.load_state_dict(states["model"])
    test_scannet(model, experiment_name, resolution=test_resolution, sample_mode=sample_mode)