YasiiKB's picture
initial commit
97aa5af verified
'''
We thank the author of DeepGMR paper to open-source their code.
Modified by Vinit Sarode.
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. ops import transform_functions as transform
def gmm_params(gamma, pts):
'''
Inputs:
gamma: B x N x J
pts: B x N x 3
'''
# pi: B x J
pi = gamma.mean(dim=1)
Npi = pi * gamma.shape[1]
# mu: B x J x 3
mu = gamma.transpose(1, 2) @ pts / Npi.unsqueeze(2)
# diff: B x N x J x 3
diff = pts.unsqueeze(2) - mu.unsqueeze(1)
# sigma: B x J x 3 x 3
eye = torch.eye(3).unsqueeze(0).unsqueeze(1).to(gamma.device)
sigma = (
((diff.unsqueeze(3) @ diff.unsqueeze(4)).squeeze() * gamma).sum(dim=1) / Npi
).unsqueeze(2).unsqueeze(3) * eye
return pi, mu, sigma
def gmm_register(pi_s, mu_s, mu_t, sigma_t):
'''
Inputs:
pi: B x J
mu: B x J x 3
sigma: B x J x 3 x 3
'''
c_s = pi_s.unsqueeze(1) @ mu_s
c_t = pi_s.unsqueeze(1) @ mu_t
Ms = torch.sum((pi_s.unsqueeze(2) * (mu_s - c_s)).unsqueeze(3) @
(mu_t - c_t).unsqueeze(2) @ sigma_t.inverse(), dim=1)
U, _, V = torch.svd(Ms.cpu())
U = U.cuda() if torch.cuda.is_available() else U
V = V.cuda() if torch.cuda.is_available() else V
S = torch.eye(3).unsqueeze(0).repeat(U.shape[0], 1, 1).to(U.device)
S[:, 2, 2] = torch.det(V @ U.transpose(1, 2))
R = V @ S @ U.transpose(1, 2)
t = c_t.transpose(1, 2) - R @ c_s.transpose(1, 2)
bot_row = torch.Tensor([[[0, 0, 0, 1]]]).repeat(R.shape[0], 1, 1).to(R.device)
T = torch.cat([torch.cat([R, t], dim=2), bot_row], dim=1)
return T
class Conv1dBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes):
super(Conv1dBNReLU, self).__init__(
nn.Conv1d(in_planes, out_planes, kernel_size=1, bias=False),
nn.BatchNorm1d(out_planes),
nn.ReLU(inplace=True))
class FCBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes):
super(FCBNReLU, self).__init__(
nn.Linear(in_planes, out_planes, bias=False),
nn.BatchNorm1d(out_planes),
nn.ReLU(inplace=True))
class TNet(nn.Module):
def __init__(self):
super(TNet, self).__init__()
self.encoder = nn.Sequential(
Conv1dBNReLU(3, 64),
Conv1dBNReLU(64, 128),
Conv1dBNReLU(128, 256))
self.decoder = nn.Sequential(
FCBNReLU(256, 128),
FCBNReLU(128, 64),
nn.Linear(64, 6))
@staticmethod
def f2R(f):
r1 = F.normalize(f[:, :3])
proj = (r1.unsqueeze(1) @ f[:, 3:].unsqueeze(2)).squeeze(2)
r2 = F.normalize(f[:, 3:] - proj * r1)
r3 = r1.cross(r2)
return torch.stack([r1, r2, r3], dim=2)
def forward(self, pts):
f = self.encoder(pts)
f, _ = f.max(dim=2)
f = self.decoder(f)
R = self.f2R(f)
return R @ pts
class PointNet(nn.Module):
def __init__(self, use_rri, use_tnet=False, nearest_neighbors=20):
super(PointNet, self).__init__()
self.use_tnet = use_tnet
self.tnet = TNet() if self.use_tnet else None
d_input = nearest_neighbors * 4 if use_rri else 3
self.encoder = nn.Sequential(
Conv1dBNReLU(d_input, 64),
Conv1dBNReLU(64, 128),
Conv1dBNReLU(128, 256),
Conv1dBNReLU(256, args.d_model))
self.decoder = nn.Sequential(
Conv1dBNReLU(args.d_model * 2, 512),
Conv1dBNReLU(512, 256),
Conv1dBNReLU(256, 128),
nn.Conv1d(128, args.n_clusters, kernel_size=1))
def forward(self, pts):
pts = self.tnet(pts) if self.use_tnet else pts
f_loc = self.encoder(pts)
f_glob, _ = f_loc.max(dim=2)
f_glob = f_glob.unsqueeze(2).expand_as(f_loc)
y = self.decoder(torch.cat([f_loc, f_glob], dim=1))
return y.transpose(1, 2)
class DeepGMR(nn.Module):
def __init__(self, use_rri=True, feature_model=None, nearest_neighbors=20):
super(DeepGMR, self).__init__()
self.backbone = feature_model if not None else PointNet(use_rri=use_rri, nearest_neighbors=nearest_neighbors)
self.use_rri = use_rri
def forward(self, template, source):
if self.use_rri:
self.template = template[..., :3]
self.source = source[..., :3]
template_features = template[..., 3:].transpose(1, 2)
source_features = source[..., 3:].transpose(1, 2)
else:
self.template = template
self.source = source
template_features = (template - template.mean(dim=2, keepdim=True)).transpose(1, 2)
source_features = (source - source.mean(dim=2, keepdim=True)).transpose(1, 2)
self.template_gamma = F.softmax(self.backbone(template_features), dim=2)
self.template_pi, self.template_mu, self.template_sigma = gmm_params(self.template_gamma, self.template)
self.source_gamma = F.softmax(self.backbone(source_features), dim=2)
self.source_pi, self.source_mu, self.source_sigma = gmm_params(self.source_gamma, self.source)
self.est_T_inverse = gmm_register(self.template_pi, self.template_mu, self.source_mu, self.source_sigma)
self.est_T = gmm_register(self.source_pi, self.source_mu, self.template_mu, self.template_sigma) # [template = source * est_T]
self.igt = igt # [source = template * igt]
transformed_source = transform.transform_point_cloud(source, est_T[:, :3, :3], est_T[:, :3, 3])
result = {'est_R': est_T[:, :3, :3],
'est_t': est_T[:, :3, 3],
'est_R_inverse': est_T_inverse[:, :3, :3],
'est_t_inverese': est_T_inverse[:, :3, 3],
'est_T': est_T,
'est_T_inverse': est_T_inverse,
'r': template_features - source_features,
'transformed_source': transformed_source}
return result