| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .pooling import Pooling |
|
|
|
|
| |
class Mish(nn.Module):
    """Mish activation, applied element-wise: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # tanh(softplus(x)) acts as a smooth gate on the raw input.
        gate = torch.tanh(F.softplus(x))
        return x * gate
|
|
|
|
| |
class BasicConv1D(nn.Module):
    """Bias-free ``Conv1d`` followed by ``BatchNorm1d`` and an optional Mish.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution (and normalised).
        kernel_size: convolution kernel size; the default of 1 makes the block
            a shared per-position linear layer.
        stride: convolution stride.
        active: when equal to ``True``, a ``Mish`` activation is applied after
            the batch norm; otherwise the normalised output is returned as-is.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, active=True):
        super().__init__()
        self.active = active
        # Submodules are registered in the same order as before (bn, activation,
        # conv) so parameters()/state_dict() ordering does not change.
        self.bn = nn.BatchNorm1d(out_channels)
        if active == True:  # noqa: E712 - keep the original equality semantics
            self.activation = Mish()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, bias=False)

    def forward(self, x):
        normed = self.bn(self.conv(x))
        if self.active == True:  # noqa: E712
            return self.activation(normed)
        return normed
|
|
|
|
class Self_Attn(nn.Module):
    """Self-attention across the positions of a (B, C, N) feature map.

    A single 1x1 ``BasicConv1D`` projection (``in_dim`` -> ``out_dim``) is used
    as query, key and value at once. The attended features are scaled by the
    learnable ``beta`` (initialised to zero) and added to the projection, so a
    freshly built layer behaves like the plain projection.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query_conv = BasicConv1D(in_dim, out_dim)
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor laid out as (B, in_dim, N), the layout ``Conv1d`` consumes.

        Returns:
            Tensor of shape (B, out_dim, N) equal to
            ``beta * attended + projection``. The (B, N, N) attention map is
            computed internally but not returned.
        """
        feat = self.query_conv(x)              # (B, out_dim, N)
        feat_t = feat.permute(0, 2, 1)         # (B, N, out_dim)
        # Position-to-position similarity, softmax-normalised along each row.
        attention = self.softmax(torch.bmm(feat_t, feat))        # (B, N, N)
        attended = torch.bmm(feat, attention.permute(0, 2, 1))   # (B, out_dim, N)
        return self.beta * attended + feat
|
|
class PointNet(torch.nn.Module):
    """Per-point feature extractor built from stacked ``Self_Attn`` layers.

    Five attention layers produce 32, 64, 64, 128 and ``emb_dims`` channels.
    Their outputs are concatenated, so each point gets
    ``32 + 64 + 64 + 128 + emb_dims`` channels (512 for the default
    ``emb_dims=224``).

    Args:
        emb_dims: output channels of the last attention layer.
        input_shape: ``"bnc"`` for (batch, points, 3) input or ``"bcn"`` for
            (batch, 3, points) input.
        use_bn: stored on the instance; not otherwise used by this class.
        global_feat: if True, ``forward`` returns the concatenated per-point
            features. If False, a max-pooled descriptor is broadcast to every
            point and concatenated in front of those features.
            NOTE(review): the flag name reads inverted relative to its effect;
            confirm intent before renaming.

    Raises:
        ValueError: if ``input_shape`` is not ``"bcn"`` or ``"bnc"``.
    """

    def __init__(self, emb_dims=224, input_shape="bnc", use_bn=False, global_feat=True):
        super().__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.use_bn = use_bn
        self.global_feat = global_feat
        if not self.global_feat:
            self.pooling = Pooling('max')

        self.conv1 = Self_Attn(3, 32)
        self.conv2 = Self_Attn(32, 64)
        self.conv3 = Self_Attn(64, 64)
        self.conv4 = Self_Attn(64, 128)
        self.conv5 = Self_Attn(128, self.emb_dims)

    def forward(self, input_data):
        """Compute per-point features.

        Args:
            input_data: point cloud in the layout selected by ``input_shape``,
                with 3 coordinates per point.

        Returns:
            (B, F, N) per-point features, where F = 288 + emb_dims. When
            ``global_feat`` is False the result is (B, 2 * F, N): the pooled
            descriptor, repeated over the points, followed by the per-point
            features.

        Raises:
            RuntimeError: if the input does not carry exactly 3 coordinates.
        """
        if self.input_shape == "bnc":
            num_points = input_data.shape[1]
            input_data = input_data.permute(0, 2, 1)
        else:
            num_points = input_data.shape[2]
        # Both layouts are now (B, C, N); validate the coordinate axis.
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        x1 = self.conv1(input_data)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3 + x2)  # skip connection around conv3
        x5 = self.conv5(x4)
        point_feature = torch.cat([x1, x2, x3, x4, x5], dim=1)

        if self.global_feat:
            return point_feature

        # Presumably Pooling('max') reduces over the point axis (TODO confirm).
        # The concatenated features have 288 + emb_dims channels, not emb_dims,
        # so reshape by batch size instead of the old view(-1, emb_dims, 1),
        # which scrambled batch and channel dims (or raised).
        pooled = self.pooling(point_feature)
        pooled = pooled.reshape(point_feature.shape[0], -1, 1).repeat(1, 1, num_points)
        return torch.cat([pooled, point_feature], 1)
|
|
|
|
| |
class self_attention_fc(nn.Module):
    """Cross-attention between two descriptors that share one projection.

    Both inputs pass through the same ``BasicConv1D`` (``in_dim`` -> ``out_dim``).
    Channel-to-channel affinities between the two projections give one
    attention map per direction. Each stream is refined with its map and added
    back to its own projection, scaled by the shared learnable ``beta``
    (initialised to zero).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query_conv = BasicConv1D(in_dim, out_dim)
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Attend ``x`` and ``y`` against each other.

        Args:
            x, y: tensors laid out as (B, in_dim, L); the callers visible in
                this file pass L == 1.

        Returns:
            ``(out_x, out_y)``, each shaped (B, out_dim, L).
        """
        feat_x = self.query_conv(x)   # (B, out_dim, L)
        feat_y = self.query_conv(y)   # (B, out_dim, L)
        # energy[b, i, j] = sum_l feat_x[b, i, l] * feat_y[b, j, l]
        energy = torch.bmm(feat_x, feat_y.permute(0, 2, 1))   # (B, out_dim, out_dim)
        att_xy = self.softmax(energy)
        att_yx = self.softmax(energy.permute(0, 2, 1))

        out_x = self.beta * torch.bmm(att_xy, feat_x) + feat_x
        out_y = self.beta * torch.bmm(att_yx, feat_y) + feat_y
        return out_x, out_y
|
|
|
|
|
|
class PointNetMask(nn.Module):
    """Predict per-point masks for a template / source point-cloud pair.

    Per-point features come from ``feature_model``. Each cloud is summarised by
    concatenating its max- and average-pooled features. The two summaries are
    refined jointly by three ``self_attention_fc`` layers. Each cloud's
    per-point features are then concatenated with the *other* cloud's refined
    summary (repeated for every point) and scored by the shared ``h3`` head,
    whose final ``Sigmoid`` keeps scores between 0 and 1.

    Args:
        template_feature_size: accepted for API compatibility; unused.
        source_feature_size: accepted for API compatibility; unused.
        feature_model: per-point feature extractor. ``None`` (the default)
            builds a fresh ``PointNet()`` owned by this instance.
    """

    def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=None):
        super().__init__()
        # A PointNet() default in the signature was built once at import time
        # and silently shared (same weights) by every instance using it.
        if feature_model is None:
            feature_model = PointNet()
        self.feature_model = feature_model
        self.pooling_max = Pooling(pool_type='max')
        self.pooling_avg = Pooling(pool_type='avg')

        # NOTE(review): the widths below are fixed. They appear to assume
        # 512-channel per-point features (max+avg pooled -> 1024 inputs to
        # global_feat_1; 512 features + 512 summary -> 1024 inputs to h3).
        self.global_feat_1 = self_attention_fc(1024, 512)
        self.global_feat_2 = self_attention_fc(512, 256)
        self.global_feat_3 = self_attention_fc(256, 512)

        self.h3 = nn.Sequential(BasicConv1D(1024, 512),
                                BasicConv1D(512, 256),
                                BasicConv1D(256, 128),
                                nn.Conv1d(128, 1, 1), nn.Sigmoid())

    def find_mask(self, source_features, template_features):
        """Score every point of both clouds.

        Args:
            source_features: per-point source features, (B, C, N_src).
            template_features: per-point template features, (B, C, N_tmp).

        Returns:
            ``(template_mask, source_mask)`` shaped (B, N_tmp) and (B, N_src).
        """
        # Assumes Pooling reduces (B, C, N) -> (B, C) (TODO confirm); the
        # unsqueeze(2) below re-adds a length-1 axis for the attention layers.
        global_source_features_max = self.pooling_max(source_features)
        global_template_features_max = self.pooling_max(template_features)
        global_source_features_avg = self.pooling_avg(source_features)
        global_template_features_avg = self.pooling_avg(template_features)
        global_source_features = torch.cat([global_source_features_max, global_source_features_avg], dim=1)
        global_template_features = torch.cat([global_template_features_max, global_template_features_avg], dim=1)

        shared_source, shared_template = self.global_feat_1(
            global_source_features.unsqueeze(2), global_template_features.unsqueeze(2))
        shared_source, shared_template = self.global_feat_2(shared_source, shared_template)
        shared_source, shared_template = self.global_feat_3(shared_source, shared_template)

        batch_size = template_features.size(0)
        # Broadcast each summary to the point count of the cloud it is paired
        # with. The old code used the source count for the template branch,
        # which broke torch.cat whenever the two clouds differed in size.
        num_template_points = template_features.size(2)
        template_input = torch.cat(
            [template_features, shared_source.repeat(1, 1, num_template_points)], dim=1)
        template_mask = self.h3(template_input)

        num_source_points = source_features.size(2)
        source_input = torch.cat(
            [source_features, shared_template.repeat(1, 1, num_source_points)], dim=1)
        source_mask = self.h3(source_input)

        return template_mask.view(batch_size, -1), source_mask.view(batch_size, -1)

    def forward(self, template, source):
        """Return ``(template_mask, source_mask)`` for raw point clouds."""
        source_features = self.feature_model(source)
        template_features = self.feature_model(template)
        template_mask, source_mask = self.find_mask(source_features, template_features)
        return template_mask, source_mask
|
|
class MaskNet2(nn.Module):
    """Wrap ``PointNetMask`` and turn its scores into selected point subsets.

    Args:
        feature_model: per-point feature extractor handed to ``PointNetMask``.
            ``None`` (the default) builds a fresh ``PointNet(use_bn=True)`` for
            this instance instead of sharing one built at import time.
        is_training: stored on the instance; not used by this class.
    """

    def __init__(self, feature_model=None, is_training=True):
        super().__init__()
        if feature_model is None:
            feature_model = PointNet(use_bn=True)
        self.maskNet = PointNetMask(feature_model=feature_model)
        self.is_training = is_training

    @staticmethod
    def index_points(points, idx):
        """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, S]
        Return:
            new_points:, indexed points data, [B, S, C]
        """
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        # batch_indices has idx's shape with batch_indices[b, ...] == b, so
        # points[batch_indices, idx] selects points[b, idx[b, ...]].
        batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).repeat(repeat_shape)
        return points[batch_indices, idx, :]

    def forward(self, template, source, point_selection='threshold', mask_threshold=0.5):
        """Score both clouds and keep the points whose score exceeds a threshold.

        Args:
            template, source: point clouds indexed as (B, N, >=3); the first
                three channels of the last axis are kept. Only B == 1 works:
                each boolean mask is squeezed to (N,) and applied along axis 1.
            point_selection: accepted for API compatibility; currently unused.
            mask_threshold: points with score strictly greater than this are kept.

        Returns:
            ``(masked_template, masked_source, template_mask, source_mask)``:
            the selected points, shaped (1, K, 3) with K varying per cloud,
            and the raw score maps returned by ``self.maskNet``.
        """
        template_mask, source_mask = self.maskNet(template, source)
        # Compare directly: the result is a bool tensor on the masks' own
        # device. The old code only defined `device` when CUDA was unavailable
        # (UnboundLocalError on CUDA machines) and round-tripped through float
        # masks plus torch.tensor(tensor), which copies and warns.
        template_keep = (template_mask > mask_threshold).squeeze(0)
        source_keep = (source_mask > mask_threshold).squeeze(0)

        masked_template = template[:, template_keep, 0:3]
        masked_source = source[:, source_keep, 0:3]
        return masked_template, masked_source, template_mask, source_mask
|
|
|
|
if __name__ == '__main__':
    # Smoke test.
    # - Batch size 1: MaskNet2.forward squeezes each mask to (N,) before
    #   indexing, so larger batches fail.
    # - eval(): the global-feature branch feeds (1, C, 1) tensors through
    #   BatchNorm1d, which rejects a single value per channel in training mode.
    template, source = torch.rand(1, 1024, 3), torch.rand(1, 1024, 3)
    net = MaskNet2()
    net.eval()
    with torch.no_grad():
        masked_template, masked_source, template_mask, source_mask = net(template, source)
    print('template mask:', tuple(template_mask.shape), '-> kept', tuple(masked_template.shape))
    print('source mask:', tuple(source_mask.shape), '-> kept', tuple(masked_source.shape))