| import torch | |
| import torch.nn as nn | |
class MeanVFE(nn.Module):
    """Voxel feature encoder that averages the point features of each voxel.

    The module holds no parameters; ``forward`` reads the voxelized points
    from ``batch_dict`` and writes one mean feature vector per voxel back
    into the same dictionary.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batch_dict, **kwargs):
        """Compute per-voxel mean features and store them in ``batch_dict``.

        Args:
            batch_dict: dictionary that must contain
                voxels: (num_voxels, max_points_per_voxel, C) point features.
                voxel_num_points: (num_voxels) number of points per voxel.
                    This key is read unconditionally, so it has to be
                    present.
            **kwargs: accepted but not used.

        Returns:
            The same ``batch_dict`` object, with an added entry
                voxel_features: (num_voxels, C) contiguous tensor of the
                    per-voxel means, in the dtype of ``voxels``.
        """
        voxels = batch_dict['voxels']
        counts = batch_dict['voxel_num_points']

        # Every slot along dim 1 is summed, padded or not.
        # NOTE(review): this presumably relies on unused slots being
        # zero-filled by the voxelizer -- confirm against the caller.
        feature_sum = torch.sum(voxels, dim=1)

        # Clamp the counts to at least 1 so voxels reporting zero points do
        # not cause a division by zero; the column shape broadcasts over C.
        divisor = torch.clamp_min(counts.view(-1, 1), min=1.0).type_as(voxels)

        batch_dict['voxel_features'] = (feature_sum / divisor).contiguous()
        return batch_dict