| |
| |
|
|
|
|
| import os |
| import sys |
| import glob |
| import h5py |
| import copy |
| import math |
| import json |
| import numpy as np |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .. ops import transform_functions as transform |
| from .. utils import Transformer, Identity, knn, get_graph_feature |
|
|
| from sklearn.metrics import r2_score |
|
|
# Default compute device for this module: CUDA when available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
def pairwise_distance(src, tgt):
    """Return the Euclidean distance between every source/target point pair.

    The squared distance is expanded as ``|a|^2 - 2 a.b + |b|^2`` so that it
    can be computed with a single batched matrix product.

    Args:
        src: Tensor indexed as (batch, channels, num_src_points).
        tgt: Tensor indexed as (batch, channels, num_tgt_points).

    Returns:
        Tensor indexed as (batch, num_src_points, num_tgt_points).
    """
    inner = -2 * torch.matmul(src.transpose(2, 1).contiguous(), tgt)
    xx = torch.sum(src ** 2, dim=1, keepdim=True)
    yy = torch.sum(tgt ** 2, dim=1, keepdim=True)
    squared = xx.transpose(2, 1).contiguous() + inner + yy
    # Floating-point cancellation can leave the squared distance of
    # (near-)identical points slightly negative; sqrt would then yield NaN.
    return torch.sqrt(torch.clamp(squared, min=0.0))
|
|
def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
    """Penalise forward/backward transform estimates that do not cancel out.

    The loss is the sum of two mean-squared errors: the product
    ``rotation_ab @ rotation_ba`` against the identity, and
    ``translation_ab`` against ``-translation_ba``.

    Returns:
        Scalar tensor holding the summed loss.
    """
    num_samples = rotation_ab.size(0)
    eye = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(num_samples, 1, 1)

    composed_rotation = torch.matmul(rotation_ab, rotation_ba)
    rotation_term = F.mse_loss(composed_rotation, eye)
    translation_term = F.mse_loss(translation_ab, -translation_ba)
    return rotation_term + translation_term
|
|
|
|
class PointNet(nn.Module):
    """Per-point feature extractor built from five 1x1 Conv1d + BatchNorm stages.

    Channel widths are 3 -> 64 -> 64 -> 64 -> 128 -> ``emb_dims``; every stage
    is followed by a ReLU.
    """

    def __init__(self, emb_dims=512):
        super().__init__()
        widths = (3, 64, 64, 64, 128, emb_dims)
        stage_dims = list(zip(widths[:-1], widths[1:]))
        # Register all convolutions first and the batch norms afterwards so the
        # attribute names and registration order stay conv1..conv5, bn1..bn5.
        for idx, (c_in, c_out) in enumerate(stage_dims, start=1):
            setattr(self, 'conv{}'.format(idx), nn.Conv1d(c_in, c_out, kernel_size=1, bias=False))
        for idx, (_, c_out) in enumerate(stage_dims, start=1):
            setattr(self, 'bn{}'.format(idx), nn.BatchNorm1d(c_out))

    def forward(self, x):
        """Run ``x`` (batch, 3, num_points) through the five stages in order."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x
|
|
|
|
class DGCNN(nn.Module):
    """Graph-feature embedding network with four edge stages and a fusion layer.

    Each of the first four stages calls the project helper
    ``get_graph_feature``, applies a 1x1 Conv2d + BatchNorm2d + LeakyReLU(0.2)
    and max-reduces over the last axis. The four stage outputs
    (64 + 64 + 128 + 256 = 512 channels) are concatenated and projected to
    ``emb_dims`` channels.
    """

    def __init__(self, emb_dims=512):
        super(DGCNN, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)

    def _edge_stage(self, x, conv, norm):
        """Apply one graph-feature stage and max-reduce over the last axis."""
        # Use the input's own device rather than a module-level global so the
        # network also works when it lives on a different device than the
        # process-wide default (e.g. CPU inference on a CUDA-capable host).
        # NOTE(review): assumes get_graph_feature uses `device` only to place
        # the tensors it creates -- confirm against the helper.
        x = get_graph_feature(x, device=x.device)
        x = F.leaky_relu(norm(conv(x)), negative_slope=0.2)
        return x.max(dim=-1, keepdim=True)[0]

    def forward(self, x):
        """Embed ``x`` (batch, dims, num_points) to (batch, emb_dims, num_points)."""
        batch_size, _, num_points = x.size()

        x1 = self._edge_stage(x, self.conv1, self.bn1)
        x2 = self._edge_stage(x1, self.conv2, self.bn2)
        x3 = self._edge_stage(x2, self.conv3, self.bn3)
        x4 = self._edge_stage(x3, self.conv4, self.bn4)

        fused = torch.cat((x1, x2, x3, x4), dim=1)
        fused = F.leaky_relu(self.bn5(self.conv5(fused)), negative_slope=0.2)
        return fused.view(batch_size, -1, num_points)
|
|
|
|
class MLPHead(nn.Module):
    """Regress a rotation matrix and a translation from two embeddings.

    The embeddings are concatenated along the channel axis, max-pooled over
    the last axis and passed through a three-layer MLP. Two linear heads then
    produce a 4-vector (normalised and interpreted as a quaternion) and a
    3-vector translation.
    """

    def __init__(self, emb_dims):
        super(MLPHead, self).__init__()
        n_emb_dims = emb_dims
        self.n_emb_dims = n_emb_dims
        self.nn = nn.Sequential(nn.Linear(n_emb_dims * 2, n_emb_dims // 2),
                                nn.BatchNorm1d(n_emb_dims // 2),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims // 2, n_emb_dims // 4),
                                nn.BatchNorm1d(n_emb_dims // 4),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims // 4, n_emb_dims // 8),
                                nn.BatchNorm1d(n_emb_dims // 8),
                                nn.ReLU())
        self.proj_rot = nn.Linear(n_emb_dims // 8, 4)
        self.proj_trans = nn.Linear(n_emb_dims // 8, 3)

    @staticmethod
    def _quat2mat(quat):
        """Convert unit quaternions (batch, 4) to rotation matrices (batch, 3, 3).

        NOTE(review): the component order is taken to be (x, y, z, w); the
        file referenced an undefined ``quat2mat`` so the intended convention
        could not be confirmed -- verify against any pretrained weights.
        """
        x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
        w2, x2, y2, z2 = w * w, x * x, y * y, z * z
        wx, wy, wz = w * x, w * y, w * z
        xy, xz, yz = x * y, x * z, y * z
        rows = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
                            2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
                            2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1)
        return rows.reshape(quat.size(0), 3, 3)

    def forward(self, *input):
        """Return ``(rotation, translation)`` for the given embedding pair.

        Args:
            *input: ``input[0]`` and ``input[1]`` are the source and target
                embeddings, each indexed as (batch, emb_dims, num_points);
                any further positional arguments are ignored.

        Returns:
            Tuple of a (batch, 3, 3) rotation tensor and a (batch, 3)
            translation tensor.
        """
        src_embedding, tgt_embedding = input[0], input[1]
        embedding = torch.cat((src_embedding, tgt_embedding), dim=1)
        embedding = self.nn(embedding.max(dim=-1)[0])
        # F.normalize guards the division against a (near-)zero norm.
        rotation = F.normalize(self.proj_rot(embedding), p=2, dim=1)
        translation = self.proj_trans(embedding)
        return self._quat2mat(rotation), translation
|
|
|
|
class TemperatureNet(nn.Module):
    """Predict a clamped, per-sample temperature from an embedding pair.

    The absolute difference of the two mean-pooled embeddings is fed through
    a small MLP; the output is clamped to
    ``[1 / temp_factor, temp_factor]``.
    """

    def __init__(self, emb_dims, temp_factor):
        super().__init__()
        self.n_emb_dims = emb_dims
        self.temp_factor = temp_factor

        layers = []
        in_features = self.n_emb_dims
        for _ in range(3):
            layers += [nn.Linear(in_features, 128), nn.BatchNorm1d(128), nn.ReLU()]
            in_features = 128
        layers += [nn.Linear(128, 1), nn.ReLU()]
        self.nn = nn.Sequential(*layers)

        # Most recent residual computed by forward(); None until first call.
        self.feature_disparity = None

    def forward(self, *input):
        """Return ``(temperature, residual)`` for two embeddings.

        ``input[0]`` and ``input[1]`` are averaged over axis 2 before the
        absolute difference is taken; the residual is also cached on
        ``self.feature_disparity``.
        """
        src_mean = input[0].mean(dim=2)
        tgt_mean = input[1].mean(dim=2)
        residual = torch.abs(src_mean - tgt_mean)
        self.feature_disparity = residual

        lower = 1.0 / self.temp_factor
        upper = 1.0 * self.temp_factor
        temperature = torch.clamp(self.nn(residual), lower, upper)
        return temperature, residual
|
|
|
|
class SVDHead(nn.Module):
    """Estimate a rigid transform from soft correspondences via SVD.

    Correspondence scores are scaled dot products between source and target
    embeddings, turned into a matching matrix with either a temperature
    softmax or a hard Gumbel-softmax. The matched target points and the
    source points are then aligned in closed form through the SVD of their
    cross-covariance.
    """

    def __init__(self, emb_dims, cat_sampler):
        super(SVDHead, self).__init__()
        self.n_emb_dims = emb_dims
        self.cat_sampler = cat_sampler
        # `reflect` and `temperature` are registered parameters but are not
        # read by forward(); they are kept so the state dict is unchanged.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.temperature = nn.Parameter(torch.ones(1) * 0.5, requires_grad=True)
        # Counts training-mode forward passes (plain tensor, not a buffer).
        self.my_iter = torch.ones(1)

    def forward(self, *input):
        """Return ``(R, t)`` aligning ``src`` to ``tgt``.

        Args:
            *input: ``(src_embedding, tgt_embedding, src, tgt, temperature)``.
                Embeddings are indexed as (batch, emb_dims, num_points),
                point clouds as (batch, 3, num_points); ``temperature`` holds
                one value per batch element.

        Returns:
            Tuple of a (batch, 3, 3) rotation tensor and a (batch, 3)
            translation tensor, both on ``src``'s device.

        Raises:
            NotImplementedError: If ``cat_sampler`` is not ``'softmax'`` or
                ``'gumbel_softmax'``.
        """
        src_embedding, tgt_embedding, src, tgt = input[0], input[1], input[2], input[3]
        batch_size, _, num_points = src.size()
        temperature = input[4].view(batch_size, 1, 1)

        if self.cat_sampler not in ('softmax', 'gumbel_softmax'):
            raise NotImplementedError('not implemented')

        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        if self.cat_sampler == 'softmax':
            scores = torch.softmax(temperature * scores, dim=2)
        else:
            # One temperature per row of the flattened (batch * points) scores.
            scores = scores.view(batch_size * num_points, num_points)
            tau = temperature.repeat(1, num_points, 1).view(-1, 1)
            scores = F.gumbel_softmax(scores, tau=tau, hard=True)
            scores = scores.view(batch_size, num_points, num_points)

        # Soft-matched target point for every source point.
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_mean = src.mean(dim=2, keepdim=True)
        corr_mean = src_corr.mean(dim=2, keepdim=True)
        src_centered = src - src_mean
        src_corr_centered = src_corr - corr_mean

        # Cross-covariance; the decomposition runs on the CPU as in the
        # original implementation.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()).cpu()

        u, _, v = torch.svd(H)
        u_t = u.transpose(2, 1)
        # The determinant of V U^T is used as a constant sign correction
        # (detached, so no gradient flows through it).
        det = torch.det(torch.matmul(v, u_t)).detach()
        ones = torch.ones_like(det)
        correction = torch.diag_embed(torch.stack((ones, ones, det), dim=1))
        R = torch.matmul(torch.matmul(v, correction), u_t).contiguous()

        # Follow the input's device instead of a module-level global.
        R = R.to(src.device)

        t = torch.matmul(-R, src_mean) + corr_mean
        if self.training:
            self.my_iter += 1
        return R, t.view(batch_size, 3)
|
|
|
|
class KeyPointNet(nn.Module):
    """Keep the ``num_keypoints`` points whose embeddings have the largest norm."""

    def __init__(self, num_keypoints):
        super().__init__()
        self.num_keypoints = num_keypoints

    def _top_indices(self, embedding):
        """Indices (batch, 1, k) of the k largest per-point embedding norms."""
        norms = torch.norm(embedding, dim=1, keepdim=True)
        return torch.topk(norms, k=self.num_keypoints, dim=2, sorted=False)[1]

    def forward(self, *input):
        """Select keypoints and their embeddings for source and target.

        Args:
            *input: ``(src, tgt, src_embedding, tgt_embedding)``; point clouds
                are gathered with a 3-row index, embeddings with one row per
                embedding channel.

        Returns:
            ``(src_keypoints, tgt_keypoints, src_embedding, tgt_embedding)``
            restricted to the selected points along axis 2.
        """
        src, tgt, src_embedding, tgt_embedding = input[0], input[1], input[2], input[3]
        _, num_dims, _ = src_embedding.size()

        src_idx = self._top_indices(src_embedding)
        tgt_idx = self._top_indices(tgt_embedding)

        src_keypoints = torch.gather(src, dim=2, index=src_idx.repeat(1, 3, 1))
        tgt_keypoints = torch.gather(tgt, dim=2, index=tgt_idx.repeat(1, 3, 1))

        src_selected = torch.gather(src_embedding, dim=2, index=src_idx.repeat(1, num_dims, 1))
        tgt_selected = torch.gather(tgt_embedding, dim=2, index=tgt_idx.repeat(1, num_dims, 1))
        return src_keypoints, tgt_keypoints, src_selected, tgt_selected
|
|
|
|
class PRNet(nn.Module):
    """Iterative point-cloud registration network.

    Each iteration embeds both clouds, optionally refines the embeddings with
    an attention module, optionally keeps only the strongest keypoints,
    predicts a temperature, and estimates the rigid transform in both
    directions with the configured head. The source cloud is moved by the
    estimated transform before the next iteration, and the per-iteration
    estimates are composed into the final transform.

    Args:
        emb_nn: ``'pointnet'`` or ``'dgcnn'``.
        attention: ``'identity'`` or ``'transformer'``.
        head: ``'mlp'`` or ``'svd'``.
        emb_dims: Embedding width.
        num_keypoints: Number of keypoints kept per cloud.
        num_subsampled_points: When equal to ``num_keypoints`` the keypoint
            selection is replaced by ``Identity``.
        num_iters: Number of refinement iterations.
        cycle_consistency_loss: Weight of the cycle-consistency term.
        feature_alignment_loss: Weight of the feature-alignment term.
        discount_factor: Per-iteration loss discount (``discount ** i``).
        input_shape: ``'bnc'`` when inputs are (batch, points, 3); any other
            value leaves the inputs untouched.

    Raises:
        Exception: If ``emb_nn``, ``attention`` or ``head`` is not one of the
            supported names.
    """

    def __init__(self, emb_nn='dgcnn', attention='transformer', head='svd',
                 emb_dims=512, num_keypoints=512, num_subsampled_points=768,
                 num_iters=3, cycle_consistency_loss=0.1,
                 feature_alignment_loss=0.1, discount_factor=0.9,
                 input_shape='bnc'):
        super(PRNet, self).__init__()
        self.emb_dims = emb_dims
        self.num_keypoints = num_keypoints
        self.num_subsampled_points = num_subsampled_points
        self.num_iters = num_iters
        self.discount_factor = discount_factor
        self.feature_alignment_loss = feature_alignment_loss
        self.cycle_consistency_loss = cycle_consistency_loss
        self.input_shape = input_shape

        if emb_nn == 'pointnet':
            self.emb_nn = PointNet(emb_dims=self.emb_dims)
        elif emb_nn == 'dgcnn':
            self.emb_nn = DGCNN(emb_dims=self.emb_dims)
        else:
            raise Exception('Not implemented')

        if attention == 'identity':
            self.attention = Identity()
        elif attention == 'transformer':
            self.attention = Transformer(emb_dims=self.emb_dims, n_blocks=1, dropout=0.0, ff_dims=1024, n_heads=4)
        else:
            raise Exception("Not implemented")

        self.temp_net = TemperatureNet(emb_dims=self.emb_dims, temp_factor=100)

        if head == 'mlp':
            self.head = MLPHead(emb_dims=self.emb_dims)
        elif head == 'svd':
            self.head = SVDHead(emb_dims=self.emb_dims, cat_sampler='softmax')
        else:
            raise Exception('Not implemented')

        if self.num_keypoints != self.num_subsampled_points:
            self.keypointnet = KeyPointNet(num_keypoints=self.num_keypoints)
        else:
            self.keypointnet = Identity()

    def predict_embedding(self, *input):
        """Embed both clouds and derive keypoints and temperature.

        Args:
            *input: ``(src, tgt)`` point clouds in channel-first layout.

        Returns:
            ``(src, tgt, src_embedding, tgt_embedding, temperature,
            feature_disparity)`` after attention (added residually),
            keypoint selection and temperature prediction.
        """
        src, tgt = input[0], input[1]
        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        src_embedding_p, tgt_embedding_p = self.attention(src_embedding, tgt_embedding)

        # The attention output is used as a residual on top of the embeddings.
        src_embedding = src_embedding + src_embedding_p
        tgt_embedding = tgt_embedding + tgt_embedding_p

        src, tgt, src_embedding, tgt_embedding = self.keypointnet(src, tgt, src_embedding, tgt_embedding)

        temperature, feature_disparity = self.temp_net(src_embedding, tgt_embedding)

        return src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity

    def spam(self, *input):
        """Run one registration step in both directions.

        Returns:
            ``(rotation_ab, translation_ab, rotation_ba, translation_ba,
            feature_disparity)``; the same temperature is used for both the
            source->target and target->source estimates.
        """
        src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity = self.predict_embedding(*input)
        rotation_ab, translation_ab = self.head(src_embedding, tgt_embedding, src, tgt, temperature)
        rotation_ba, translation_ba = self.head(tgt_embedding, src_embedding, tgt, src, temperature)
        return rotation_ab, translation_ab, rotation_ba, translation_ba, feature_disparity

    def predict_keypoint_correspondence(self, *input):
        """Return keypoints and a hard (Gumbel-softmax) correspondence matrix.

        Returns:
            ``(src, tgt, scores)`` where ``scores`` is indexed as
            (batch, num_points, num_points).
        """
        src, tgt, src_embedding, tgt_embedding, temperature, _ = self.predict_embedding(*input)
        batch_size, _, num_points = src.size()
        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        scores = scores.view(batch_size * num_points, num_points)
        # Reshape to (batch, 1, 1) before repeating so that every row of the
        # flattened scores receives the temperature of its own batch element;
        # repeating the raw 2-D tensor would tile the batch axis instead.
        temperature = temperature.view(batch_size, 1, 1).repeat(1, num_points, 1).view(-1, 1)
        scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
        scores = scores.view(batch_size, num_points, num_points)
        return src, tgt, scores

    def forward(self, *input):
        """Register ``src`` onto ``tgt`` and optionally compute the loss.

        Args:
            *input: One of
                ``(src, tgt)`` -- inference only;
                ``(src, tgt, T)`` -- ``T`` is a ground-truth transform whose
                upper-left 3x3 block is the rotation and whose last column is
                the translation;
                ``(src, tgt, rotation_ab, translation_ab)`` -- ground truth
                given separately.

        Returns:
            Dict with ``'est_R'``, ``'est_t'``, ``'est_T'`` and
            ``'transformed_source'``; ``'loss'`` is added when ground truth
            was supplied.

        Raises:
            ValueError: If the number of positional arguments is not 2, 3 or 4.
        """
        num_inputs = len(input)
        if num_inputs == 2:
            src, tgt = input[0], input[1]
            calculate_loss = False
        elif num_inputs == 3:
            src, tgt = input[0], input[1]
            rotation_ab = input[2][:, :3, :3]
            translation_ab = input[2][:, :3, 3].view(-1, 3)
            calculate_loss = True
        elif num_inputs == 4:
            src, tgt, rotation_ab, translation_ab = input[0], input[1], input[2], input[3]
            calculate_loss = True
        else:
            raise ValueError('PRNet.forward expects 2, 3 or 4 inputs, got {}'.format(num_inputs))

        # Internally the network works on channel-first (batch, 3, points).
        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)

        rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        rotation_ba_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ba_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        total_loss = 0

        for i in range(self.num_iters):
            (rotation_ab_pred_i, translation_ab_pred_i,
             rotation_ba_pred_i, translation_ba_pred_i,
             feature_disparity) = self.spam(src, tgt)

            # Compose this iteration's estimate onto the running transform.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) + translation_ab_pred_i

            rotation_ba_pred = torch.matmul(rotation_ba_pred_i, rotation_ba_pred)
            translation_ba_pred = torch.matmul(rotation_ba_pred_i, translation_ba_pred.unsqueeze(2)).squeeze(2) + translation_ba_pred_i

            if calculate_loss:
                discount = self.discount_factor ** i
                loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)
                        + F.mse_loss(translation_ab_pred, translation_ab)) * discount

                feature_alignment_loss = feature_disparity.mean() * self.feature_alignment_loss * discount
                cycle_consistency_loss = cycle_consistency(rotation_ab_pred_i, translation_ab_pred_i,
                                                           rotation_ba_pred_i, translation_ba_pred_i) \
                    * self.cycle_consistency_loss * discount

                total_loss = total_loss + loss + feature_alignment_loss + cycle_consistency_loss

            # Move the source cloud by this iteration's estimate.
            # NOTE(review): the 'bnc' branch hands transform_point_cloud a
            # (batch, points, 3) tensor while the other branch hands it the
            # channel-first tensor -- confirm which layout the helper expects.
            if self.input_shape == 'bnc':
                src = transform.transform_point_cloud(src.permute(0, 2, 1), rotation_ab_pred_i, translation_ab_pred_i).permute(0, 2, 1)
            else:
                src = transform.transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)

        # Return the transformed source in the caller's layout.
        if self.input_shape == 'bnc':
            src = src.permute(0, 2, 1)

        result = {'est_R': rotation_ab_pred,
                  'est_t': translation_ab_pred,
                  'est_T': transform.convert2transformation(rotation_ab_pred, translation_ab_pred),
                  'transformed_source': src}

        if calculate_loss:
            result['loss'] = total_loss
        return result
|
|
|
|
if __name__ == '__main__':
    # Smoke test: one forward pass with random clouds and a ground-truth
    # transform. Because of the relative imports above, run this as a module
    # inside its package (python -m ...), not as a standalone script.
    model = PRNet().to(device)
    src = torch.rand(10, 1024, 3, device=device)
    tgt = torch.rand(10, 768, 3, device=device)
    rotation_ab = torch.eye(3, device=device).unsqueeze(0).repeat(10, 1, 1)
    translation_ab = torch.rand(10, 3, device=device)
    # forward() returns a dict, not a tuple.
    result = model(src, tgt, rotation_ab, translation_ab)
    rotation_ab_pred, translation_ab_pred, loss = result['est_R'], result['est_t'], result['loss']