YasiiKB's picture
initial commit
97aa5af verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooling import Pooling
# Mish Activation Function
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()
def forward(self, x):
return x * torch.tanh(F.softplus(x))
# Basic Convolution Block
class BasicConv1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, active = True):
super(BasicConv1D, self).__init__()
self.active = active
self.bn = nn.BatchNorm1d( out_channels)
if self.active == True:
self.activation = Mish()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, bias=False)
#self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.active == True:
x = self.activation(x)
return x
class Self_Attn(nn.Module):
""" Self attention Layer"""
def __init__(self, in_dim, out_dim):
super(Self_Attn,self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Query Convolution
self.query_conv =BasicConv1D(in_dim, out_dim)
self.beta = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x):
"""
inputs :
x : input feature maps( B X C X N) 32, 1024, 64
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
proj_query = self.query_conv(x).permute(0,2,1) # B, in_dim, N ---> B, in_dim // 8, N ----> B, N, in_dim // 8
proj_key = proj_query.permute(0,2,1) #B, in_dim, N ---> B, in_dim // 8, N
energy = torch.bmm(proj_query,proj_key) # transpose check B, N, N
attention = self.softmax(energy) # B , N, N
out_x = torch.bmm(proj_key, attention.permute(0,2,1) ) #B, out_dim, N
out = self.beta * out_x + proj_key
return out
class PointNet(torch.nn.Module):
def __init__(self, emb_dims=224, input_shape="bnc", use_bn=False, global_feat=True):
# emb_dims: Embedding Dimensions for PointNet.
# input_shape: Shape of Input Point Cloud (b: batch, n: no of points, c: channels)
super(PointNet, self).__init__()
if input_shape not in ["bcn", "bnc"]:
raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
self.input_shape = input_shape
self.emb_dims = emb_dims
self.use_bn = use_bn
self.global_feat = global_feat
if not self.global_feat: self.pooling = Pooling('max')
self.conv1 = Self_Attn(3, 32)
self.conv2 = Self_Attn(32, 64)
self.conv3 = Self_Attn(64, 64)
self.conv4 = Self_Attn(64, 128)
self.conv5 = Self_Attn(128, self.emb_dims)
def forward(self, input_data):
# input_data: Point Cloud having shape input_shape.
# output: PointNet features (Batch x emb_dims)
if self.input_shape == "bnc":
num_points = input_data.shape[1]
input_data = input_data.permute(0, 2, 1)
else:
num_points = input_data.shape[2]
if input_data.shape[1] != 3:
raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
output = input_data
x1 = self.conv1(output) #32
x2 = self.conv2(x1) #64
x3 = self.conv3(x2) #64
x4 = self.conv4(x3+x2) #128
x5 = self.conv5(x4)
output = torch.cat([ x1, x2, x3, x4, x5], dim=1) #256, x4 x0,
point_feature = output
if self.global_feat:
return output
else:
output = self.pooling(output)
output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points)
return torch.cat([output, point_feature], 1)
# self attention mechanism
class self_attention_fc(nn.Module):
""" Self attention Layer"""
def __init__(self,in_dim, out_dim): #1024
super(self_attention_fc,self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.query_conv = BasicConv1D(in_dim, out_dim)
self.beta = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self,x, y): #B, 1024 , 1
"""
inputs :
x : input feature maps( B X C,1 )
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
proj_query_x = self.query_conv(x) #[B, in_dim, 1]----->[B, out_dim1, 1]
proj_key_y = self.query_conv(y).permute(0,2,1) #[B, 1, out_dim1]
energy_xy = torch.bmm(proj_query_x, proj_key_y) # xi Attention scores for all points in y [B, 64, 64]
attention_xy = self.softmax(energy_xy)
attention_yx = self.softmax(energy_xy.permute(0,2,1))
proj_value_x = proj_query_x # self.value_conv_x(x) # [B, out_dim, 64]
proj_value_y = proj_key_y.permute(0,2,1) # self.value_conv_x(y) # [B, out_dim, 64]
out_x = torch.bmm(attention_xy, proj_value_x) # [B, out_dim]
out_x = self.beta* out_x + proj_value_x # self.kama*
out_y = torch.bmm(attention_yx, proj_value_y ) # [B, out_dim]
out_y = self.beta*out_y + proj_value_y # self.kama *
return out_x, out_y
class PointNetMask(nn.Module):
def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=PointNet()):
super().__init__()
self.feature_model = feature_model
self.pooling_max = Pooling(pool_type='max')
self.pooling_avg = Pooling(pool_type='avg')
input_size = template_feature_size + source_feature_size
self.global_feat_1 = self_attention_fc(1024, 512)
self.global_feat_2 = self_attention_fc(512, 256)
self.global_feat_3 = self_attention_fc(256, 512)
self.h3 = nn.Sequential(BasicConv1D(1024, 512),
BasicConv1D(512, 256),
BasicConv1D(256, 128),
nn.Conv1d(128, 1, 1), nn.Sigmoid())
def find_mask(self, source_features, template_features):
global_source_features_max = self.pooling_max(source_features)
global_template_features_max = self.pooling_max(template_features)
global_source_features_avg = self.pooling_avg(source_features)
global_template_features_avg = self.pooling_avg(template_features)
global_source_features = torch.cat([global_source_features_max, global_source_features_avg], dim=1)
global_template_features = torch.cat([global_template_features_max, global_template_features_avg], dim=1)
shared_feat_1,shared_feat_2 = self.global_feat_1(global_source_features.unsqueeze(2), global_template_features.unsqueeze(2))
shared_feat_1,shared_feat_2 = self.global_feat_2(shared_feat_1, shared_feat_2)
shared_feat_1,shared_feat_2 = self.global_feat_3(shared_feat_1, shared_feat_2)
batch_size, _ , num_points = source_features.size()
global_source_features = shared_feat_1
global_source_features = global_source_features.repeat(1,1,num_points)
x = torch.cat([template_features, global_source_features], dim=1)
x = self.h3(x)
batch_size, _ , num_points = template_features.size()
global_template_features = shared_feat_2
global_template_features = global_template_features.repeat(1,1,num_points)
y = torch.cat([source_features, global_template_features], dim=1)
y = self.h3(y)
return x.view(batch_size, -1), y.view(batch_size, -1)
def forward(self, template, source):
source_features = self.feature_model(source) # [B x C x N]
template_features = self.feature_model(template) # [B x C x N]
template_mask, source_mask = self.find_mask(source_features, template_features)
return template_mask, source_mask
class MaskNet2(nn.Module):
def __init__(self, feature_model=PointNet(use_bn=True), is_training=True):
super().__init__()
self.maskNet = PointNetMask(feature_model=feature_model)
self.is_training = is_training
@staticmethod
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def forward(self, template, source, point_selection='threshold', mask_threshold = 0.5):
template_mask, source_mask = self.maskNet(template, source) #B, N
if not torch.cuda.is_available():
device = 'cpu'
device = torch.device(device)
source_binary_mask = torch.where(source_mask > mask_threshold, torch.ones(source_mask.size()).to(device), torch.zeros(source_mask.size()).to(device))
template_binary_mask = torch.where(template_mask > mask_threshold, torch.ones(template_mask.size()).to(device), torch.zeros(template_mask.size()).to(device))
masked_template = template[:, torch.tensor(template_binary_mask, dtype = torch.bool).squeeze(0), 0:3]
masked_source = source[:, torch.tensor(source_binary_mask, dtype = torch.bool).squeeze(0), 0:3]
return masked_template, masked_source, template_mask, source_mask
if __name__ == '__main__':
template, source = torch.rand(10,1024,3), torch.rand(10,1024,3)
net = MaskNet2()
result = net(template, source)
import ipdb; ipdb.set_trace()