# OA-CNNs: omni-adaptive sparse-convolution segmentation backbone (model definition).
from functools import partial

import spconv.pytorch as spconv
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch_geometric.nn.pool import voxel_grid
from torch_geometric.utils import scatter

from ..builder import MODELS
from ..utils import offset2batch
class BasicBlock(nn.Module):
    """Omni-adaptive aggregation block.

    Runs ``depth - 1`` grid-pooled attention branches (one per clustering in
    ``clusters``) plus a point-wise projection, fuses them, and finishes with a
    two-layer submanifold sparse-conv residual branch.

    NOTE(review): ``in_channels`` is accepted but never used — the block
    assumes its input already has ``embed_channels`` feature channels; confirm
    against callers (DonwBlock passes in_channels == embed_channels).
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_fn=None,
        indice_key=None,
        depth=4,
        groups=None,
        grid_size=None,
        bias=False,
    ):
        super().__init__()
        assert embed_channels % groups == 0
        self.groups = groups
        self.embed_channels = embed_channels
        self.proj = nn.ModuleList()
        self.grid_size = grid_size
        self.weight = nn.ModuleList()
        self.l_w = nn.ModuleList()
        # proj gets ``depth`` entries in total: this first one plus depth - 1
        # from the loop below.  In forward(), proj[i] serves cluster branch i
        # and proj[-1] serves the point-wise (unpooled) branch.
        self.proj.append(
            nn.Sequential(
                nn.Linear(embed_channels, embed_channels, bias=False),
                norm_fn(embed_channels),
                nn.ReLU(),
            )
        )
        # One (l_w, weight) pair per pooled branch; there are depth - 1 of them.
        for _ in range(depth - 1):
            self.proj.append(
                nn.Sequential(
                    nn.Linear(embed_channels, embed_channels, bias=False),
                    norm_fn(embed_channels),
                    nn.ReLU(),
                )
            )
            self.l_w.append(
                nn.Sequential(
                    nn.Linear(embed_channels, embed_channels, bias=False),
                    norm_fn(embed_channels),
                    nn.ReLU(),
                )
            )
            self.weight.append(nn.Linear(embed_channels, embed_channels, bias=False))
        # Per-point mixing logits over the depth - 1 pooled branches.
        self.adaptive = nn.Linear(embed_channels, depth - 1, bias=False)
        # Fuses [point-wise branch, mixed pooled branches] back to embed_channels.
        self.fuse = nn.Sequential(
            nn.Linear(embed_channels * 2, embed_channels, bias=False),
            norm_fn(embed_channels),
            nn.ReLU(),
        )
        # Residual sparse-conv refinement (two 3x3x3 submanifold convs).
        self.voxel_block = spconv.SparseSequential(
            spconv.SubMConv3d(
                embed_channels,
                embed_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                indice_key=indice_key,
                bias=bias,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
            spconv.SubMConv3d(
                embed_channels,
                embed_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                indice_key=indice_key,
                bias=bias,
            ),
            norm_fn(embed_channels),
        )
        self.act = nn.ReLU()

    def forward(self, x, clusters):
        """Apply adaptive multi-scale aggregation then the sparse-conv residual.

        Args:
            x: spconv.SparseConvTensor holding per-point features.
            clusters: one LongTensor of cluster ids per point for each grid
                scale (length must equal ``depth - 1``).
        Returns:
            spconv.SparseConvTensor with updated features, same indices.
        """
        feat = x.features
        feats = []
        for i, cluster in enumerate(clusters):
            # Attention logits: center features within each cluster, project.
            pw = self.l_w[i](feat)
            pw = pw - scatter(pw, cluster, reduce="mean")[cluster]
            pw = self.weight[i](pw)
            # exp with a global max shift for numerical stability, then
            # normalize so weights sum to ~1 within each cluster.
            pw = torch.exp(pw - pw.max())
            pw = pw / (scatter(pw, cluster, reduce="sum", dim=0)[cluster] + 1e-6)
            # Weighted aggregation per cluster, broadcast back to every point.
            pfeat = self.proj[i](feat) * pw
            pfeat = scatter(pfeat, cluster, reduce="sum")[cluster]
            feats.append(pfeat)
        # Per-point softmax mixture over the pooled scales.
        adp = self.adaptive(feat)
        adp = torch.softmax(adp, dim=1)
        feats = torch.stack(feats, dim=1)
        feats = torch.einsum("l n, l n c -> l c", adp, feats)
        # Concatenate point-wise branch with mixed pooled branches and fuse,
        # with a residual connection from the input features.
        feat = self.proj[-1](feat)
        feat = torch.cat([feat, feats], dim=1)
        feat = self.fuse(feat) + x.features
        res = feat
        # Sparse-conv refinement with a second residual connection.
        x = x.replace_feature(feat)
        x = self.voxel_block(x)
        x = x.replace_feature(self.act(x.features + res))
        return x
class DonwBlock(nn.Module):
    """Encoder stage: strided sparse downsampling followed by ``depth``
    BasicBlocks that all reuse the same per-scale voxel-grid clusterings.

    NOTE(review): the class name is presumably a typo for "DownBlock"; kept
    as-is since external code may reference it by this name.
    NOTE(review): ``num_ref`` is stored but never used in this block.
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        depth,
        sp_indice_key,
        point_grid_size,
        num_ref=16,
        groups=None,
        norm_fn=None,
        sub_indice_key=None,
    ):
        super().__init__()
        self.num_ref = num_ref
        self.depth = depth
        self.point_grid_size = point_grid_size
        # Stride-2 sparse conv halves spatial resolution; indice_key is shared
        # with the matching UpBlock's inverse conv.
        self.down = spconv.SparseSequential(
            spconv.SparseConv3d(
                in_channels,
                embed_channels,
                kernel_size=2,
                stride=2,
                indice_key=sp_indice_key,
                bias=False,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
        )
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(
                BasicBlock(
                    in_channels=embed_channels,
                    embed_channels=embed_channels,
                    # One pooled branch per grid size, plus the point-wise one.
                    depth=len(point_grid_size) + 1,
                    groups=groups,
                    grid_size=point_grid_size,
                    norm_fn=norm_fn,
                    indice_key=sub_indice_key,
                )
            )

    def forward(self, x):
        x = self.down(x)
        # Build one clustering per grid size from the downsampled voxel
        # coordinates; every BasicBlock in this stage reuses them.
        coord = x.indices[:, 1:].float()
        batch = x.indices[:, 0]
        clusters = []
        for grid_size in self.point_grid_size:
            cluster = voxel_grid(pos=coord, size=grid_size, batch=batch)
            # Remap arbitrary voxel-grid ids to contiguous cluster indices.
            _, cluster = torch.unique(cluster, return_inverse=True)
            clusters.append(cluster)
        for block in self.blocks:
            x = block(x, clusters)
        return x
class UpBlock(nn.Module):
    """Decoder stage: sparse inverse-conv upsampling followed by residual
    fusion with the matching encoder skip features.

    NOTE(review): ``depth`` is only asserted, ``sub_indice_key`` is unused,
    and ``self.blocks`` is created but never populated or called — confirm
    whether these are vestigial.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        embed_channels,
        depth,
        sp_indice_key,
        norm_fn=None,
        down_ratio=2,
        sub_indice_key=None,
    ):
        super().__init__()
        assert depth > 0
        # Inverse conv restores the indices produced by the encoder conv that
        # shares the same indice_key.
        self.up = spconv.SparseSequential(
            spconv.SparseInverseConv3d(
                in_channels,
                embed_channels,
                kernel_size=down_ratio,
                indice_key=sp_indice_key,
                bias=False,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
        )
        self.blocks = nn.ModuleList()
        # Two-layer MLP fusing [upsampled, skip] features back to embed_channels.
        self.fuse = nn.Sequential(
            nn.Linear(skip_channels + embed_channels, embed_channels),
            norm_fn(embed_channels),
            nn.ReLU(),
            nn.Linear(embed_channels, embed_channels),
            norm_fn(embed_channels),
            nn.ReLU(),
        )

    def forward(self, x, skip_x):
        x = self.up(x)
        # Residual fusion: MLP over concatenated features plus the upsampled
        # features themselves.
        x = x.replace_feature(
            self.fuse(torch.cat([x.features, skip_x.features], dim=1)) + x.features
        )
        return x
class OACNNs(nn.Module):
    """Omni-Adaptive sparse CNN for point-cloud semantic segmentation.

    A sparse-conv U-Net: a SubMConv stem, ``num_stages`` DonwBlock encoder
    stages with multi-scale adaptive aggregation, matching UpBlock decoder
    stages with skip connections, and a 1x1x1 classification head.

    NOTE(review): ``MODELS`` is imported at file level but no
    ``@MODELS.register_module`` decorator appears here — confirm whether the
    registration decorator was lost; the import is otherwise unused.

    Args:
        in_channels: number of input point-feature channels.
        num_classes: number of output semantic classes.
        embed_channels: stem output width.
        enc_num_ref / enc_channels / groups / enc_depth: per-stage encoder
            hyperparameters.
        down_ratio: per-stage downsampling factors (used by the decoder's
            inverse convs).
        dec_channels / dec_depth: per-stage decoder hyperparameters.
        point_grid_size: per-stage list of voxel-grid sizes for the adaptive
            aggregation branches.
    """

    def __init__(
        self,
        in_channels,
        num_classes,
        embed_channels=64,
        enc_num_ref=[16, 16, 16, 16],
        enc_channels=[64, 64, 128, 256],
        groups=[2, 4, 8, 16],
        enc_depth=[2, 3, 6, 4],
        down_ratio=[2, 2, 2, 2],
        dec_channels=[96, 96, 128, 256],
        point_grid_size=[[16, 32, 64], [8, 16, 24], [4, 8, 12], [2, 4, 6]],
        dec_depth=[2, 2, 2, 2],
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_stages = len(enc_channels)
        self.embed_channels = embed_channels
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)

        # Stem: three submanifold convs at full resolution, shared indice_key.
        self.stem = spconv.SparseSequential(
            spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=3,
                padding=1,
                indice_key="stem",
                bias=False,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
            spconv.SubMConv3d(
                embed_channels,
                embed_channels,
                kernel_size=3,
                padding=1,
                indice_key="stem",
                bias=False,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
            spconv.SubMConv3d(
                embed_channels,
                embed_channels,
                kernel_size=3,
                padding=1,
                indice_key="stem",
                bias=False,
            ),
            norm_fn(embed_channels),
            nn.ReLU(),
        )

        self.enc = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(self.num_stages):
            self.enc.append(
                DonwBlock(
                    in_channels=embed_channels if i == 0 else enc_channels[i - 1],
                    embed_channels=enc_channels[i],
                    depth=enc_depth[i],
                    norm_fn=norm_fn,
                    groups=groups[i],
                    point_grid_size=point_grid_size[i],
                    num_ref=enc_num_ref[i],
                    # Shared with the decoder stage i's inverse conv.
                    sp_indice_key=f"spconv{i}",
                    sub_indice_key=f"subm{i + 1}",
                )
            )
            self.dec.append(
                UpBlock(
                    in_channels=(
                        enc_channels[-1]
                        if i == self.num_stages - 1
                        else dec_channels[i + 1]
                    ),
                    # Skip comes from the tensor BEFORE encoder stage i.
                    skip_channels=embed_channels if i == 0 else enc_channels[i - 1],
                    embed_channels=dec_channels[i],
                    depth=dec_depth[i],
                    norm_fn=norm_fn,
                    sp_indice_key=f"spconv{i}",
                    sub_indice_key=f"subm{i}",
                )
            )

        # Per-point classification head.
        self.final = spconv.SubMConv3d(dec_channels[0], num_classes, kernel_size=1)
        self.apply(self._init_weights)

    def forward(self, input_dict):
        """Run the U-Net; returns per-point class logits.

        Expects ``input_dict`` with keys "grid_coord" (N, 3 int voxel
        coordinates), "feat" (N, C features) and "offset" (cumulative batch
        offsets, converted via offset2batch).
        """
        discrete_coord = input_dict["grid_coord"]
        feat = input_dict["feat"]
        offset = input_dict["offset"]
        batch = offset2batch(offset)
        x = spconv.SparseConvTensor(
            features=feat,
            indices=torch.cat([batch.unsqueeze(-1), discrete_coord], dim=1)
            .int()
            .contiguous(),
            # Tight spatial extent: max coordinate per axis + 1.
            spatial_shape=torch.add(
                torch.max(discrete_coord, dim=0).values, 1
            ).tolist(),
            batch_size=batch[-1].tolist() + 1,
        )
        x = self.stem(x)
        skips = [x]
        for i in range(self.num_stages):
            x = self.enc[i](x)
            skips.append(x)
        x = skips.pop(-1)
        for i in reversed(range(self.num_stages)):
            skip = skips.pop(-1)
            x = self.dec[i](x, skip)
        x = self.final(x)
        return x.features

    # FIX: declared as @staticmethod.  Without the decorator,
    # ``self.apply(self._init_weights)`` binds ``self`` so each call becomes
    # _init_weights(self, module) -> TypeError (the function takes a single
    # parameter ``m``).
    @staticmethod
    def _init_weights(m):
        """Truncated-normal init for linear / sparse-conv weights; unit/zero
        affine init for BatchNorm1d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, spconv.SubMConv3d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)