File size: 5,095 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import numpy as np
import torch

class ClassificationData:
	"""Wrap user-supplied point clouds and labels for a classification task.

	``data_dict`` must provide the keys ``'pcs'`` and ``'labels'``. A 2-D
	``pcs`` array is reshaped to ``(1, -1, 3)`` (a single cloud of 3-D points),
	and 1-D labels are given a batch axis, so both are indexed per sample.
	"""

	def __init__(self, data_dict):
		self.data_dict = data_dict
		self.pcs = self.find_attribute('pcs')
		self.labels = self.find_attribute('labels')
		self.check_data()

	def find_attribute(self, attribute):
		"""Return ``self.data_dict[attribute]``.

		Raises:
			KeyError: if ``data_dict`` has no such key.
		"""
		# Only a missing key is turned into a descriptive error; anything else
		# propagates unchanged instead of being swallowed by a bare ``except``.
		try:
			return self.data_dict[attribute]
		except KeyError:
			raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute)) from None

	def check_data(self):
		"""Validate ranks, add a batch axis where needed, and check sample counts agree."""
		assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
		assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)

		if len(self.pcs.shape) == 2:
			self.pcs = self.pcs.reshape(1, -1, 3)
		if len(self.labels.shape) == 1:
			if self.labels.shape[0] == self.pcs.shape[0]:
				# One label per cloud: (N,) -> (N, 1).
				self.labels = self.labels.reshape(-1, 1)
			else:
				# Labels for a single cloud: (k,) -> (1, k).
				self.labels = self.labels.reshape(1, -1)

		assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"

	def __len__(self):
		"""Return the number of point clouds (size of the leading axis)."""
		return self.pcs.shape[0]

	def __getitem__(self, index):
		"""Return ``(points, label)`` for sample ``index`` as float32 / int64 tensors.

		``torch.from_numpy`` requires ``labels`` to be a numpy array.
		"""
		points = torch.tensor(self.pcs[index]).float()
		label = torch.from_numpy(self.labels[index]).type(torch.LongTensor)
		return points, label


class RegistrationData:
	"""Wrap user-supplied template/source point clouds and transformations.

	``data_dict`` must provide ``'template'``, ``'source'`` and
	``'transformation'``. 2-D inputs are treated as a single sample and given a
	leading batch axis so every attribute is indexed per sample.
	"""

	def __init__(self, data_dict):
		self.data_dict = data_dict
		self.template = self.find_attribute('template')
		self.source = self.find_attribute('source')
		self.transformation = self.find_attribute('transformation')
		self.check_data()

	def find_attribute(self, attribute):
		"""Return ``self.data_dict[attribute]``.

		Raises:
			KeyError: if ``data_dict`` has no such key.
		"""
		# Fail here with a clear message; returning None only postponed the
		# failure to an unrelated AttributeError on ``None.shape`` in check_data.
		if attribute not in self.data_dict:
			raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute))
		return self.data_dict[attribute]

	def check_data(self):
		"""Validate ranks, add a batch axis to single samples, and check batch sizes agree."""
		assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
		assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
		assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)

		if len(self.template.shape) == 2:
			self.template = self.template.reshape(1, -1, 3)
		if len(self.source.shape) == 2:
			self.source = self.source.reshape(1, -1, 3)
		if len(self.transformation.shape) == 2:
			# A 2-D transformation is taken as one 4x4 matrix; reshape raises
			# unless it holds exactly 16 values.
			self.transformation = self.transformation.reshape(1, 4, 4)

		assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
		assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"

	def __len__(self):
		"""Return the number of samples (size of the leading axis)."""
		return self.template.shape[0]

	def __getitem__(self, index):
		"""Return ``(template, source, transformation)`` for sample ``index`` as float32 tensors."""
		return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()


class FlowData:
	"""Wrap two point-cloud frames and a per-point flow array supplied by the user.

	``data_dict`` must provide ``'frame1'``, ``'frame2'`` and ``'flow'``. Any
	2-D input is reshaped to ``(1, -1, 3)``, i.e. treated as a single sample of
	3-component rows, so every attribute is indexed per sample.
	"""

	def __init__(self, data_dict):
		self.data_dict = data_dict
		self.frame1 = self.find_attribute('frame1')
		self.frame2 = self.find_attribute('frame2')
		self.flow = self.find_attribute('flow')
		self.check_data()

	def find_attribute(self, attribute):
		"""Return ``self.data_dict[attribute]``.

		Raises:
			KeyError: if ``data_dict`` has no such key.
		"""
		# Only a missing key is turned into a descriptive error; anything else
		# propagates unchanged instead of being swallowed by a bare ``except``.
		try:
			return self.data_dict[attribute]
		except KeyError:
			raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute)) from None

	def check_data(self):
		"""Validate ranks, add a batch axis to single samples, and check batch sizes agree."""
		assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
		assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
		assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)

		if len(self.frame1.shape) == 2:
			self.frame1 = self.frame1.reshape(1, -1, 3)
		if len(self.frame2.shape) == 2:
			self.frame2 = self.frame2.reshape(1, -1, 3)
		if len(self.flow.shape) == 2:
			self.flow = self.flow.reshape(1, -1, 3)

		assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
		assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"

	def __len__(self):
		"""Return the number of samples (size of the leading axis)."""
		return self.frame1.shape[0]

	def __getitem__(self, index):
		"""Return ``(frame1, frame2, flow)`` for sample ``index`` as float32 tensors."""
		return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()


class UserData:
	"""Dispatch to the dataset wrapper selected by ``application``.

	``application`` selects ClassificationData (``'classification'``),
	RegistrationData (``'registration'``) or FlowData (``'flow_estimation'``);
	``data_dict`` is forwarded to that wrapper unchanged. Length and indexing
	are delegated to the wrapper.

	Raises:
		ValueError: if ``application`` is not one of the names above.
	"""

	def __init__(self, application, data_dict):
		self.application = application

		if application == 'classification':
			self.data_class = ClassificationData(data_dict)
		elif application == 'registration':
			self.data_class = RegistrationData(data_dict)
		elif application == 'flow_estimation':
			self.data_class = FlowData(data_dict)
		else:
			# Fail at construction; otherwise ``data_class`` stays unset and the
			# error only appears later as an AttributeError in __len__/__getitem__.
			raise ValueError(
				"Unknown application \"{}\"; expected 'classification', 'registration' or 'flow_estimation'".format(application))

	def __len__(self):
		"""Return the length of the wrapped dataset."""
		return len(self.data_class)

	def __getitem__(self, index):
		"""Return item ``index`` of the wrapped dataset."""
		return self.data_class[index]